Fix 113 hardening issues across entire Go backend

Security:
- Replace all binding: tags with validate: + c.Validate() in admin handlers
- Add rate limiting to auth endpoints (login, register, password reset)
- Add security headers (HSTS, XSS protection, nosniff, frame options)
- Wire Google Pub/Sub token verification into webhook handler
- Replace ParseUnverified with proper OIDC/JWKS key verification
- Verify inner Apple JWS signatures in webhook handler
- Add io.LimitReader (1MB) to all webhook body reads
- Add ownership verification to file deletion
- Move hardcoded admin credentials to env vars
- Add uniqueIndex to User.Email
- Hide ConfirmationCode from JSON serialization
- Mask confirmation codes in admin responses
- Use http.DetectContentType for upload validation
- Fix path traversal in storage service
- Replace os.Getenv with Viper in stripe service
- Sanitize Redis URLs before logging
- Separate DEBUG_FIXED_CODES from DEBUG flag
- Reject weak SECRET_KEY in production
- Add host check on /_next/* proxy routes
- Use explicit localhost CORS origins in debug mode
- Replace err.Error() with generic messages in all admin error responses

Critical fixes:
- Rewrite FCM to HTTP v1 API with OAuth 2.0 service account auth
- Fix user_customuser -> auth_user table names in raw SQL
- Fix dashboard verified query to use UserProfile model
- Add escapeLikeWildcards() to prevent SQL wildcard injection

Bug fixes:
- Add bounds checks for days/expiring_soon query params (1-3650)
- Add receipt_data/transaction_id empty-check to RestoreSubscription
- Change Active bool -> *bool in device handler
- Check all unchecked GORM/FindByIDWithProfile errors
- Add validation for notification hour fields (0-23)
- Add max=10000 validation on task description updates

Transactions & data integrity:
- Wrap registration flow in transaction
- Wrap QuickComplete in transaction
- Move image creation inside completion transaction
- Wrap SetSpecialties in transaction
- Wrap GetOrCreateToken in transaction
- Wrap completion+image deletion in transaction

Performance:
- Batch completion summaries (2 queries vs 2N)
- Reuse single http.Client in IAP validation
- Cache dashboard counts (30s TTL)
- Batch COUNT queries in admin user list
- Add Limit(500) to document queries
- Add reminder_stage+due_date filters to reminder queries
- Parse AllowedTypes once at init
- In-memory user cache in auth middleware (30s TTL)
- Timezone change detection cache
- Optimize P95 with per-endpoint sorted buffers
- Replace crypto/md5 with hash/fnv for ETags

Code quality:
- Add sync.Once to all monitoring Stop()/Close() methods
- Replace 8 fmt.Printf with zerolog in auth service
- Log previously discarded errors
- Standardize delete response shapes
- Route hardcoded English through i18n
- Remove FileURL from DocumentResponse (keep MediaURL only)
- Thread user timezone through kanban board responses
- Initialize empty slices to prevent null JSON
- Extract shared field map for task Update/UpdateTx
- Delete unused SoftDeleteModel, min(), formatCron, legacy handlers

Worker & jobs:
- Wire Asynq email infrastructure into worker
- Register HandleReminderLogCleanup with daily 3AM cron
- Use per-user timezone in HandleSmartReminder
- Replace direct DB queries with repository calls
- Delete legacy reminder handlers (~200 lines)
- Delete unused task type constants

Dependencies:
- Replace archived jung-kurt/gofpdf with go-pdf/fpdf
- Replace unmaintained gomail.v2 with wneessen/go-mail
- Add TODO for Echo jwt v3 transitive dep removal

Test infrastructure:
- Fix MakeRequest/SeedLookupData error handling
- Replace os.Exit(0) with t.Skip() in scope/consistency tests
- Add 11 new FCM v1 tests

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-03-18 23:14:13 -05:00
parent 3b86d0aae1
commit 42a5533a56
95 changed files with 2892 additions and 1783 deletions

View File

@@ -47,7 +47,7 @@ func main() {
Int("db_port", cfg.Database.Port).
Str("db_name", cfg.Database.Database).
Str("db_user", cfg.Database.User).
Str("redis_url", cfg.Redis.URL).
Str("redis_url", config.MaskURLCredentials(cfg.Redis.URL)).
Msg("Starting HoneyDue API server")
// Connect to database (retry with backoff)

View File

@@ -2,7 +2,6 @@ package main
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
@@ -115,12 +114,6 @@ func main() {
}
}
// Check worker kill switch
if !cfg.Features.WorkerEnabled {
log.Warn().Msg("Worker disabled by FEATURE_WORKER_ENABLED=false, exiting")
return
}
// Create Asynq server
srv := asynq.NewServer(
redisOpt,
@@ -151,6 +144,13 @@ func main() {
mux.HandleFunc(jobs.TypeSendEmail, jobHandler.HandleSendEmail)
mux.HandleFunc(jobs.TypeSendPush, jobHandler.HandleSendPush)
mux.HandleFunc(jobs.TypeOnboardingEmails, jobHandler.HandleOnboardingEmails)
mux.HandleFunc(jobs.TypeReminderLogCleanup, jobHandler.HandleReminderLogCleanup)
// Register email job handlers (welcome, verification, password reset, password changed)
if emailService != nil {
emailJobHandler := jobs.NewEmailJobHandler(emailService)
emailJobHandler.RegisterHandlers(mux)
}
// Start scheduler for periodic tasks
scheduler := asynq.NewScheduler(redisOpt, nil)
@@ -177,6 +177,13 @@ func main() {
}
log.Info().Str("cron", "0 10 * * *").Msg("Registered onboarding emails job (runs daily at 10:00 AM UTC)")
// Schedule reminder log cleanup (runs daily at 3:00 AM UTC)
// Removes reminder logs older than 90 days to prevent table bloat
if _, err := scheduler.Register("0 3 * * *", asynq.NewTask(jobs.TypeReminderLogCleanup, nil)); err != nil {
log.Fatal().Err(err).Msg("Failed to register reminder log cleanup job")
}
log.Info().Str("cron", "0 3 * * *").Msg("Registered reminder log cleanup job (runs daily at 3:00 AM UTC)")
// Handle graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
@@ -205,8 +212,3 @@ func main() {
log.Info().Msg("Worker stopped")
}
// formatCron builds a cron expression that fires at minute zero of the given hour.
func formatCron(hour int) string {
	hh := fmt.Sprintf("%02d", hour)
	return "0 " + hh + " * * *"
}

13
go.mod
View File

@@ -3,12 +3,12 @@ module github.com/treytartt/honeydue-api
go 1.24.0
require (
github.com/go-pdf/fpdf v0.9.0
github.com/go-playground/validator/v10 v10.23.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/hibiken/asynq v0.25.1
github.com/jung-kurt/gofpdf v1.16.2
github.com/labstack/echo/v4 v4.11.4
github.com/nicksnyder/go-i18n/v2 v2.6.0
github.com/redis/go-redis/v9 v9.17.1
@@ -18,11 +18,14 @@ require (
github.com/sideshow/apns2 v0.25.0
github.com/spf13/viper v1.20.1
github.com/stretchr/testify v1.11.1
github.com/stripe/stripe-go/v81 v81.4.0
github.com/wneessen/go-mail v0.7.2
golang.org/x/crypto v0.45.0
golang.org/x/oauth2 v0.34.0
golang.org/x/text v0.31.0
golang.org/x/time v0.14.0
google.golang.org/api v0.257.0
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.31.1
@@ -44,7 +47,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect; TODO(S-19): pulled in by echo/v4 middleware. Upgrade Echo to v4.12+, which removes the built-in JWT middleware (uses echo-jwt/v4 with jwt/v5 instead), eliminating this vulnerable transitive dep
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
@@ -71,7 +74,6 @@ require (
github.com/spf13/afero v1.14.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/stripe/stripe-go/v81 v81.4.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
@@ -86,10 +88,7 @@ require (
golang.org/x/net v0.47.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/time v0.14.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect
google.golang.org/grpc v1.77.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

18
go.sum
View File

@@ -8,7 +8,6 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20201120081800-1786d5ef83d4/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -36,6 +35,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-pdf/fpdf v0.9.0 h1:PPvSaUuo1iMi9KkaAn90NuKi+P4gwMedWPHhj8YlJQw=
github.com/go-pdf/fpdf v0.9.0/go.mod h1:oO8N111TkmKb9D7VvWGLvLJlaZUQVPM+6V42pp3iV4Y=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -83,9 +84,6 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/jung-kurt/gofpdf v1.16.2 h1:jgbatWHfRlPYiK85qgevsZTHviWXKwB1TTiKdz5PtRc=
github.com/jung-kurt/gofpdf v1.16.2/go.mod h1:1hl7y57EsiPAkLbOwzpzqgx1A30nQCk/YmFV8S2vmK0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -110,8 +108,6 @@ github.com/nicksnyder/go-i18n/v2 v2.6.0 h1:C/m2NNWNiTB6SK4Ao8df5EWm3JETSTIGNXBpM
github.com/nicksnyder/go-i18n/v2 v2.6.0/go.mod h1:88sRqr0C6OPyJn0/KRNaEz1uWorjxIKP7rUUcvycecE=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/phpdave11/gofpdi v1.0.7/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -126,7 +122,6 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k=
github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk=
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
@@ -150,7 +145,6 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -168,6 +162,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
github.com/wneessen/go-mail v0.7.2 h1:xxPnhZ6IZLSgxShebmZ6DPKh1b6OJcoHfzy7UjOkzS8=
github.com/wneessen/go-mail v0.7.2/go.mod h1:+TkW6QP3EVkgTEqHtVmnAE/1MRhmzb8Y9/W3pweuS+k=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
@@ -189,7 +185,6 @@ go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/crypto v0.0.0-20170512130425-ab89591268e0/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220403103023-749bd193bc2b/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
@@ -213,7 +208,6 @@ golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
@@ -237,13 +231,9 @@ google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHh
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE=
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -1,6 +1,9 @@
package dto
import "github.com/treytartt/honeydue-api/internal/middleware"
import (
"github.com/shopspring/decimal"
"github.com/treytartt/honeydue-api/internal/middleware"
)
// PaginationParams holds pagination query parameters
type PaginationParams struct {
@@ -115,9 +118,9 @@ type UpdateResidenceRequest struct {
YearBuilt *int `json:"year_built"`
Description *string `json:"description"`
PurchaseDate *string `json:"purchase_date"`
PurchasePrice *float64 `json:"purchase_price"`
IsActive *bool `json:"is_active"`
IsPrimary *bool `json:"is_primary"`
PurchasePrice *decimal.Decimal `json:"purchase_price"`
IsActive *bool `json:"is_active"`
IsPrimary *bool `json:"is_primary"`
}
// TaskFilters holds task-specific filter parameters
@@ -144,8 +147,8 @@ type UpdateTaskRequest struct {
InProgress *bool `json:"in_progress"`
DueDate *string `json:"due_date"`
NextDueDate *string `json:"next_due_date"`
EstimatedCost *float64 `json:"estimated_cost"`
ActualCost *float64 `json:"actual_cost"`
EstimatedCost *decimal.Decimal `json:"estimated_cost"`
ActualCost *decimal.Decimal `json:"actual_cost"`
ContractorID *uint `json:"contractor_id"`
ParentTaskID *uint `json:"parent_task_id"`
IsCancelled *bool `json:"is_cancelled"`
@@ -201,8 +204,8 @@ type UpdateDocumentRequest struct {
MimeType *string `json:"mime_type" validate:"omitempty,max=100"`
PurchaseDate *string `json:"purchase_date"`
ExpiryDate *string `json:"expiry_date"`
PurchasePrice *float64 `json:"purchase_price"`
Vendor *string `json:"vendor" validate:"omitempty,max=200"`
PurchasePrice *decimal.Decimal `json:"purchase_price"`
Vendor *string `json:"vendor" validate:"omitempty,max=200"`
SerialNumber *string `json:"serial_number" validate:"omitempty,max=100"`
ModelNumber *string `json:"model_number" validate:"omitempty,max=100"`
Provider *string `json:"provider" validate:"omitempty,max=200"`
@@ -292,9 +295,9 @@ type CreateTaskRequest struct {
FrequencyID *uint `json:"frequency_id"`
InProgress bool `json:"in_progress"`
AssignedToID *uint `json:"assigned_to_id"`
DueDate *string `json:"due_date"`
EstimatedCost *float64 `json:"estimated_cost"`
ContractorID *uint `json:"contractor_id"`
DueDate *string `json:"due_date"`
EstimatedCost *decimal.Decimal `json:"estimated_cost"`
ContractorID *uint `json:"contractor_id"`
}
// CreateContractorRequest for creating a new contractor
@@ -328,8 +331,8 @@ type CreateDocumentRequest struct {
MimeType string `json:"mime_type" validate:"max=100"`
PurchaseDate *string `json:"purchase_date"`
ExpiryDate *string `json:"expiry_date"`
PurchasePrice *float64 `json:"purchase_price"`
Vendor string `json:"vendor" validate:"max=200"`
PurchasePrice *decimal.Decimal `json:"purchase_price"`
Vendor string `json:"vendor" validate:"max=200"`
SerialNumber string `json:"serial_number" validate:"max=100"`
ModelNumber string `json:"model_number" validate:"max=100"`
TaskID *uint `json:"task_id"`

View File

@@ -33,21 +33,21 @@ type AdminUserFilters struct {
// CreateAdminUserRequest for creating a new admin user
type CreateAdminUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=8"`
FirstName string `json:"first_name" binding:"max=100"`
LastName string `json:"last_name" binding:"max=100"`
Role string `json:"role" binding:"omitempty,oneof=admin super_admin"`
Email string `json:"email" validate:"required,email"`
Password string `json:"password" validate:"required,min=8"`
FirstName string `json:"first_name" validate:"max=100"`
LastName string `json:"last_name" validate:"max=100"`
Role string `json:"role" validate:"omitempty,oneof=admin super_admin"`
IsActive *bool `json:"is_active"`
}
// UpdateAdminUserRequest for updating an admin user
type UpdateAdminUserRequest struct {
Email *string `json:"email" binding:"omitempty,email"`
Password *string `json:"password" binding:"omitempty,min=8"`
FirstName *string `json:"first_name" binding:"omitempty,max=100"`
LastName *string `json:"last_name" binding:"omitempty,max=100"`
Role *string `json:"role" binding:"omitempty,oneof=admin super_admin"`
Email *string `json:"email" validate:"omitempty,email"`
Password *string `json:"password" validate:"omitempty,min=8"`
FirstName *string `json:"first_name" validate:"omitempty,max=100"`
LastName *string `json:"last_name" validate:"omitempty,max=100"`
Role *string `json:"role" validate:"omitempty,oneof=admin super_admin"`
IsActive *bool `json:"is_active"`
}
@@ -55,7 +55,7 @@ type UpdateAdminUserRequest struct {
func (h *AdminUserManagementHandler) List(c echo.Context) error {
var filters AdminUserFilters
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var adminUsers []models.AdminUser
@@ -134,7 +134,10 @@ func (h *AdminUserManagementHandler) Create(c echo.Context) error {
var req CreateAdminUserRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
// Check if email already exists
@@ -199,7 +202,10 @@ func (h *AdminUserManagementHandler) Update(c echo.Context) error {
var req UpdateAdminUserRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
if req.Email != nil {

View File

@@ -44,7 +44,7 @@ type UpdateAppleSocialAuthRequest struct {
func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var entries []models.AppleSocialAuth
@@ -139,7 +139,7 @@ func (h *AdminAppleSocialAuthHandler) Update(c echo.Context) error {
var req UpdateAppleSocialAuthRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if req.Email != nil {
@@ -183,14 +183,15 @@ func (h *AdminAppleSocialAuthHandler) Delete(c echo.Context) error {
func (h *AdminAppleSocialAuthHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := h.db.Where("id IN ?", req.IDs).Delete(&models.AppleSocialAuth{}).Error; err != nil {
result := h.db.Where("id IN ?", req.IDs).Delete(&models.AppleSocialAuth{})
if result.Error != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth entries"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entries deleted successfully", "count": len(req.IDs)})
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entries deleted successfully", "count": result.RowsAffected})
}
// toResponse converts an AppleSocialAuth model to AppleSocialAuthResponse

View File

@@ -27,8 +27,8 @@ func NewAdminAuthHandler(adminRepo *repositories.AdminRepository, cfg *config.Co
// LoginRequest represents the admin login request
type LoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
Email string `json:"email" validate:"required,email"`
Password string `json:"password" validate:"required"`
}
// LoginResponse represents the admin login response
@@ -71,7 +71,10 @@ func NewAdminUserResponse(admin *models.AdminUser) AdminUserResponse {
func (h *AdminAuthHandler) Login(c echo.Context) error {
var req LoginRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request: " + err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
// Find admin by email
@@ -100,7 +103,10 @@ func (h *AdminAuthHandler) Login(c echo.Context) error {
_ = h.adminRepo.UpdateLastLogin(admin.ID)
// Refresh admin data after updating last login
admin, _ = h.adminRepo.FindByID(admin.ID)
admin, err = h.adminRepo.FindByID(admin.ID)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to refresh admin data"})
}
return c.JSON(http.StatusOK, LoginResponse{
Token: token,

View File

@@ -34,7 +34,7 @@ type AuthTokenResponse struct {
func (h *AdminAuthTokenHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var tokens []models.AuthToken
@@ -132,7 +132,7 @@ func (h *AdminAuthTokenHandler) Delete(c echo.Context) error {
func (h *AdminAuthTokenHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
result := h.db.Where("user_id IN ?", req.IDs).Delete(&models.AuthToken{})

View File

@@ -58,7 +58,7 @@ type CompletionFilters struct {
func (h *AdminCompletionHandler) List(c echo.Context) error {
var filters CompletionFilters
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var completions []models.TaskCompletion
@@ -167,7 +167,7 @@ func (h *AdminCompletionHandler) Delete(c echo.Context) error {
func (h *AdminCompletionHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
result := h.db.Where("id IN ?", req.IDs).Delete(&models.TaskCompletion{})
@@ -201,7 +201,7 @@ func (h *AdminCompletionHandler) Update(c echo.Context) error {
var req UpdateCompletionRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if req.Notes != nil {

View File

@@ -35,8 +35,8 @@ type AdminCompletionImageResponse struct {
// CreateCompletionImageRequest represents the request to create a completion image
type CreateCompletionImageRequest struct {
CompletionID uint `json:"completion_id" binding:"required"`
ImageURL string `json:"image_url" binding:"required"`
CompletionID uint `json:"completion_id" validate:"required"`
ImageURL string `json:"image_url" validate:"required"`
Caption string `json:"caption"`
}
@@ -50,7 +50,7 @@ type UpdateCompletionImageRequest struct {
func (h *AdminCompletionImageHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Optional completion_id filter
@@ -91,10 +91,39 @@ func (h *AdminCompletionImageHandler) List(c echo.Context) error {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch completion images"})
}
// Batch-load completion+task info to avoid N+1 queries
completionIDs := make([]uint, 0, len(images))
for _, img := range images {
completionIDs = append(completionIDs, img.CompletionID)
}
completionMap := make(map[uint]models.TaskCompletion)
if len(completionIDs) > 0 {
var completions []models.TaskCompletion
h.db.Preload("Task").Where("id IN ?", completionIDs).Find(&completions)
for _, c := range completions {
completionMap[c.ID] = c
}
}
// Build response with task info
responses := make([]AdminCompletionImageResponse, len(images))
for i, image := range images {
responses[i] = h.toResponse(&image)
response := AdminCompletionImageResponse{
ID: image.ID,
CompletionID: image.CompletionID,
ImageURL: image.ImageURL,
Caption: image.Caption,
CreatedAt: image.CreatedAt.Format("2006-01-02T15:04:05Z"),
UpdatedAt: image.UpdatedAt.Format("2006-01-02T15:04:05Z"),
}
if comp, ok := completionMap[image.CompletionID]; ok {
response.TaskID = comp.TaskID
if comp.Task.ID != 0 {
response.TaskTitle = comp.Task.Title
}
}
responses[i] = response
}
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
@@ -122,7 +151,10 @@ func (h *AdminCompletionImageHandler) Get(c echo.Context) error {
func (h *AdminCompletionImageHandler) Create(c echo.Context) error {
var req CreateCompletionImageRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
// Verify completion exists
@@ -164,7 +196,7 @@ func (h *AdminCompletionImageHandler) Update(c echo.Context) error {
var req UpdateCompletionImageRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if req.ImageURL != nil {
@@ -207,14 +239,15 @@ func (h *AdminCompletionImageHandler) Delete(c echo.Context) error {
func (h *AdminCompletionImageHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := h.db.Where("id IN ?", req.IDs).Delete(&models.TaskCompletionImage{}).Error; err != nil {
result := h.db.Where("id IN ?", req.IDs).Delete(&models.TaskCompletionImage{})
if result.Error != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete completion images"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Completion images deleted successfully", "count": len(req.IDs)})
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Completion images deleted successfully", "count": result.RowsAffected})
}
// toResponse converts a TaskCompletionImage model to AdminCompletionImageResponse

View File

@@ -3,6 +3,7 @@ package handlers
import (
"net/http"
"strconv"
"strings"
"github.com/labstack/echo/v4"
"gorm.io/gorm"
@@ -11,6 +12,14 @@ import (
"github.com/treytartt/honeydue-api/internal/models"
)
// maskCode redacts a confirmation code for display in admin responses,
// revealing only its final 4 characters. Codes of 4 characters or fewer
// (including the empty string) are masked entirely.
func maskCode(code string) string {
	n := len(code)
	if n <= 4 {
		return strings.Repeat("*", n)
	}
	masked := []byte(code)
	for i := 0; i < n-4; i++ {
		masked[i] = '*'
	}
	return string(masked)
}
// AdminConfirmationCodeHandler handles admin confirmation code management endpoints
type AdminConfirmationCodeHandler struct {
db *gorm.DB
@@ -37,7 +46,7 @@ type ConfirmationCodeResponse struct {
func (h *AdminConfirmationCodeHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var codes []models.ConfirmationCode
@@ -79,7 +88,7 @@ func (h *AdminConfirmationCodeHandler) List(c echo.Context) error {
UserID: code.UserID,
Username: code.User.Username,
Email: code.User.Email,
Code: code.Code,
Code: maskCode(code.Code),
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
IsUsed: code.IsUsed,
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
@@ -109,7 +118,7 @@ func (h *AdminConfirmationCodeHandler) Get(c echo.Context) error {
UserID: code.UserID,
Username: code.User.Username,
Email: code.User.Email,
Code: code.Code,
Code: maskCode(code.Code),
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
IsUsed: code.IsUsed,
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
@@ -141,7 +150,7 @@ func (h *AdminConfirmationCodeHandler) Delete(c echo.Context) error {
func (h *AdminConfirmationCodeHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
result := h.db.Where("id IN ?", req.IDs).Delete(&models.ConfirmationCode{})

View File

@@ -25,7 +25,7 @@ func NewAdminContractorHandler(db *gorm.DB) *AdminContractorHandler {
func (h *AdminContractorHandler) List(c echo.Context) error {
var filters dto.ContractorFilters
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var contractors []models.Contractor
@@ -130,7 +130,7 @@ func (h *AdminContractorHandler) Update(c echo.Context) error {
var req dto.UpdateContractorRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Verify residence if changing
@@ -213,7 +213,7 @@ func (h *AdminContractorHandler) Update(c echo.Context) error {
func (h *AdminContractorHandler) Create(c echo.Context) error {
var req dto.CreateContractorRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Verify residence exists if provided
@@ -290,7 +290,7 @@ func (h *AdminContractorHandler) Delete(c echo.Context) error {
func (h *AdminContractorHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Soft delete - deactivate all

View File

@@ -2,6 +2,7 @@ package handlers
import (
"net/http"
"sync"
"time"
"github.com/labstack/echo/v4"
@@ -11,6 +12,18 @@ import (
"github.com/treytartt/honeydue-api/internal/task/scopes"
)
// dashboardCache provides short-lived in-memory caching for dashboard stats
// to avoid expensive COUNT queries on every request.
type dashboardCache struct {
mu sync.RWMutex // guards stats and expiresAt for concurrent handler access
stats *DashboardStats // nil until the first stats computation populates the cache
expiresAt time.Time // instant after which the cached stats are considered stale
}
// dashCache is the process-wide singleton cache instance used by GetStats.
var dashCache = &dashboardCache{}
// dashboardCacheTTL bounds how stale the cached dashboard stats may be.
const dashboardCacheTTL = 30 * time.Second
// AdminDashboardHandler handles admin dashboard endpoints
type AdminDashboardHandler struct {
db *gorm.DB
@@ -94,6 +107,15 @@ type SubscriptionStats struct {
// GetStats handles GET /api/admin/dashboard/stats
func (h *AdminDashboardHandler) GetStats(c echo.Context) error {
// Check cache first
dashCache.mu.RLock()
if dashCache.stats != nil && time.Now().Before(dashCache.expiresAt) {
cached := *dashCache.stats
dashCache.mu.RUnlock()
return c.JSON(http.StatusOK, cached)
}
dashCache.mu.RUnlock()
stats := DashboardStats{}
now := time.Now()
thirtyDaysAgo := now.AddDate(0, 0, -30)
@@ -101,7 +123,7 @@ func (h *AdminDashboardHandler) GetStats(c echo.Context) error {
// User stats
h.db.Model(&models.User{}).Count(&stats.Users.Total)
h.db.Model(&models.User{}).Where("is_active = ?", true).Count(&stats.Users.Active)
h.db.Model(&models.User{}).Where("verified = ?", true).Count(&stats.Users.Verified)
h.db.Model(&models.UserProfile{}).Where("verified = ?", true).Count(&stats.Users.Verified)
h.db.Model(&models.User{}).Where("date_joined >= ?", thirtyDaysAgo).Count(&stats.Users.New30d)
// Residence stats
@@ -164,5 +186,11 @@ func (h *AdminDashboardHandler) GetStats(c echo.Context) error {
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "premium").Count(&stats.Subscriptions.Premium)
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "pro").Count(&stats.Subscriptions.Pro)
// Cache the result
dashCache.mu.Lock()
dashCache.stats = &stats
dashCache.expiresAt = time.Now().Add(dashboardCacheTTL)
dashCache.mu.Unlock()
return c.JSON(http.StatusOK, stats)
}

View File

@@ -50,7 +50,7 @@ type GCMDeviceResponse struct {
func (h *AdminDeviceHandler) ListAPNS(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var devices []models.APNSDevice
@@ -106,7 +106,7 @@ func (h *AdminDeviceHandler) ListAPNS(c echo.Context) error {
func (h *AdminDeviceHandler) ListGCM(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var devices []models.GCMDevice
@@ -174,13 +174,15 @@ func (h *AdminDeviceHandler) UpdateAPNS(c echo.Context) error {
}
var req struct {
Active bool `json:"active"`
Active *bool `json:"active"`
}
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
device.Active = req.Active
if req.Active != nil {
device.Active = *req.Active
}
if err := h.db.Save(&device).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update device"})
}
@@ -204,13 +206,15 @@ func (h *AdminDeviceHandler) UpdateGCM(c echo.Context) error {
}
var req struct {
Active bool `json:"active"`
Active *bool `json:"active"`
}
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
device.Active = req.Active
if req.Active != nil {
device.Active = *req.Active
}
if err := h.db.Save(&device).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update device"})
}
@@ -260,7 +264,7 @@ func (h *AdminDeviceHandler) DeleteGCM(c echo.Context) error {
func (h *AdminDeviceHandler) BulkDeleteAPNS(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
result := h.db.Where("id IN ?", req.IDs).Delete(&models.APNSDevice{})
@@ -275,7 +279,7 @@ func (h *AdminDeviceHandler) BulkDeleteAPNS(c echo.Context) error {
func (h *AdminDeviceHandler) BulkDeleteGCM(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
result := h.db.Where("id IN ?", req.IDs).Delete(&models.GCMDevice{})
@@ -307,10 +311,3 @@ func (h *AdminDeviceHandler) GetStats(c echo.Context) error {
"total": apnsTotal + gcmTotal,
})
}
// min returns the smaller of the two integers a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}

View File

@@ -6,7 +6,6 @@ import (
"time"
"github.com/labstack/echo/v4"
"github.com/shopspring/decimal"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/admin/dto"
@@ -27,7 +26,7 @@ func NewAdminDocumentHandler(db *gorm.DB) *AdminDocumentHandler {
func (h *AdminDocumentHandler) List(c echo.Context) error {
var filters dto.DocumentFilters
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var documents []models.Document
@@ -132,7 +131,7 @@ func (h *AdminDocumentHandler) Update(c echo.Context) error {
var req dto.UpdateDocumentRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Verify residence if changing
@@ -183,8 +182,7 @@ func (h *AdminDocumentHandler) Update(c echo.Context) error {
}
}
if req.PurchasePrice != nil {
d := decimal.NewFromFloat(*req.PurchasePrice)
document.PurchasePrice = &d
document.PurchasePrice = req.PurchasePrice
}
if req.Vendor != nil {
document.Vendor = *req.Vendor
@@ -232,7 +230,7 @@ func (h *AdminDocumentHandler) Update(c echo.Context) error {
func (h *AdminDocumentHandler) Create(c echo.Context) error {
var req dto.CreateDocumentRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Verify residence exists
@@ -282,8 +280,7 @@ func (h *AdminDocumentHandler) Create(c echo.Context) error {
}
}
if req.PurchasePrice != nil {
d := decimal.NewFromFloat(*req.PurchasePrice)
document.PurchasePrice = &d
document.PurchasePrice = req.PurchasePrice
}
if err := h.db.Create(&document).Error; err != nil {
@@ -322,7 +319,7 @@ func (h *AdminDocumentHandler) Delete(c echo.Context) error {
func (h *AdminDocumentHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Soft delete - deactivate all

View File

@@ -36,8 +36,8 @@ type DocumentImageResponse struct {
// CreateDocumentImageRequest represents the request to create a document image
type CreateDocumentImageRequest struct {
DocumentID uint `json:"document_id" binding:"required"`
ImageURL string `json:"image_url" binding:"required"`
DocumentID uint `json:"document_id" validate:"required"`
ImageURL string `json:"image_url" validate:"required"`
Caption string `json:"caption"`
}
@@ -51,7 +51,7 @@ type UpdateDocumentImageRequest struct {
func (h *AdminDocumentImageHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Optional document_id filter
@@ -123,7 +123,10 @@ func (h *AdminDocumentImageHandler) Get(c echo.Context) error {
func (h *AdminDocumentImageHandler) Create(c echo.Context) error {
var req CreateDocumentImageRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
// Verify document exists
@@ -165,7 +168,7 @@ func (h *AdminDocumentImageHandler) Update(c echo.Context) error {
var req UpdateDocumentImageRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if req.ImageURL != nil {
@@ -208,14 +211,15 @@ func (h *AdminDocumentImageHandler) Delete(c echo.Context) error {
func (h *AdminDocumentImageHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := h.db.Where("id IN ?", req.IDs).Delete(&models.DocumentImage{}).Error; err != nil {
result := h.db.Where("id IN ?", req.IDs).Delete(&models.DocumentImage{})
if result.Error != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete document images"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document images deleted successfully", "count": len(req.IDs)})
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document images deleted successfully", "count": result.RowsAffected})
}
// toResponse converts a DocumentImage model to DocumentImageResponse

View File

@@ -37,7 +37,7 @@ type FeatureBenefitResponse struct {
func (h *AdminFeatureBenefitHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var benefits []models.FeatureBenefit
@@ -112,15 +112,18 @@ func (h *AdminFeatureBenefitHandler) Get(c echo.Context) error {
// Create handles POST /api/admin/feature-benefits
func (h *AdminFeatureBenefitHandler) Create(c echo.Context) error {
var req struct {
FeatureName string `json:"feature_name" binding:"required"`
FreeTierText string `json:"free_tier_text" binding:"required"`
ProTierText string `json:"pro_tier_text" binding:"required"`
FeatureName string `json:"feature_name" validate:"required"`
FreeTierText string `json:"free_tier_text" validate:"required"`
ProTierText string `json:"pro_tier_text" validate:"required"`
DisplayOrder int `json:"display_order"`
IsActive *bool `json:"is_active"`
}
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
benefit := models.FeatureBenefit{
@@ -175,7 +178,7 @@ func (h *AdminFeatureBenefitHandler) Update(c echo.Context) error {
}
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if req.FeatureName != nil {

View File

@@ -34,7 +34,9 @@ func (h *AdminLimitationsHandler) GetSettings(c echo.Context) error {
if err == gorm.ErrRecordNotFound {
// Create default settings
settings = models.SubscriptionSettings{ID: 1, EnableLimitations: false}
h.db.Create(&settings)
if err := h.db.Create(&settings).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default settings"})
}
} else {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
}
@@ -54,7 +56,7 @@ type UpdateLimitationsSettingsRequest struct {
func (h *AdminLimitationsHandler) UpdateSettings(c echo.Context) error {
var req UpdateLimitationsSettingsRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var settings models.SubscriptionSettings
@@ -117,8 +119,12 @@ func (h *AdminLimitationsHandler) ListTierLimits(c echo.Context) error {
if len(limits) == 0 {
freeLimits := models.GetDefaultFreeLimits()
proLimits := models.GetDefaultProLimits()
h.db.Create(&freeLimits)
h.db.Create(&proLimits)
if err := h.db.Create(&freeLimits).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default free tier limits"})
}
if err := h.db.Create(&proLimits).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default pro tier limits"})
}
limits = []models.TierLimits{freeLimits, proLimits}
}
@@ -149,7 +155,9 @@ func (h *AdminLimitationsHandler) GetTierLimits(c echo.Context) error {
} else {
limits = models.GetDefaultProLimits()
}
h.db.Create(&limits)
if err := h.db.Create(&limits).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default tier limits"})
}
} else {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch tier limits"})
}
@@ -175,7 +183,7 @@ func (h *AdminLimitationsHandler) UpdateTierLimits(c echo.Context) error {
var req UpdateTierLimitsRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var limits models.TierLimits
@@ -188,13 +196,23 @@ func (h *AdminLimitationsHandler) UpdateTierLimits(c echo.Context) error {
}
}
// Update fields - note: we need to handle nil vs zero difference
// A nil pointer in the request means "don't change"
// The actual limit value can be nil (unlimited) or a number
limits.PropertiesLimit = req.PropertiesLimit
limits.TasksLimit = req.TasksLimit
limits.ContractorsLimit = req.ContractorsLimit
limits.DocumentsLimit = req.DocumentsLimit
// Update fields only when explicitly provided in the request body.
// JSON unmarshaling sets *int to nil when the key is absent, and to
// a non-nil *int (possibly pointing to 0) when the key is present.
// We rely on Bind populating these before calling this handler, so
// a nil pointer here means "don't change".
if req.PropertiesLimit != nil {
limits.PropertiesLimit = req.PropertiesLimit
}
if req.TasksLimit != nil {
limits.TasksLimit = req.TasksLimit
}
if req.ContractorsLimit != nil {
limits.ContractorsLimit = req.ContractorsLimit
}
if req.DocumentsLimit != nil {
limits.DocumentsLimit = req.DocumentsLimit
}
if err := h.db.Save(&limits).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update tier limits"})
@@ -297,9 +315,9 @@ func (h *AdminLimitationsHandler) GetUpgradeTrigger(c echo.Context) error {
// CreateUpgradeTriggerRequest represents the create request
type CreateUpgradeTriggerRequest struct {
TriggerKey string `json:"trigger_key" binding:"required"`
Title string `json:"title" binding:"required"`
Message string `json:"message" binding:"required"`
TriggerKey string `json:"trigger_key" validate:"required"`
Title string `json:"title" validate:"required"`
Message string `json:"message" validate:"required"`
PromoHTML string `json:"promo_html"`
ButtonText string `json:"button_text"`
IsActive *bool `json:"is_active"`
@@ -309,7 +327,10 @@ type CreateUpgradeTriggerRequest struct {
func (h *AdminLimitationsHandler) CreateUpgradeTrigger(c echo.Context) error {
var req CreateUpgradeTriggerRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
// Validate trigger key
@@ -380,7 +401,7 @@ func (h *AdminLimitationsHandler) UpdateUpgradeTrigger(c echo.Context) error {
var req UpdateUpgradeTriggerRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if req.TriggerKey != nil {

View File

@@ -162,10 +162,10 @@ type TaskCategoryResponse struct {
}
type CreateUpdateCategoryRequest struct {
Name string `json:"name" binding:"required,max=50"`
Name string `json:"name" validate:"required,max=50"`
Description string `json:"description"`
Icon string `json:"icon" binding:"max=50"`
Color string `json:"color" binding:"max=7"`
Icon string `json:"icon" validate:"max=50"`
Color string `json:"color" validate:"max=7"`
DisplayOrder *int `json:"display_order"`
}
@@ -193,7 +193,10 @@ func (h *AdminLookupHandler) ListCategories(c echo.Context) error {
func (h *AdminLookupHandler) CreateCategory(c echo.Context) error {
var req CreateUpdateCategoryRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
category := models.TaskCategory{
@@ -239,7 +242,10 @@ func (h *AdminLookupHandler) UpdateCategory(c echo.Context) error {
var req CreateUpdateCategoryRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
category.Name = req.Name
@@ -301,9 +307,9 @@ type TaskPriorityResponse struct {
}
type CreateUpdatePriorityRequest struct {
Name string `json:"name" binding:"required,max=20"`
Level int `json:"level" binding:"required,min=1,max=10"`
Color string `json:"color" binding:"max=7"`
Name string `json:"name" validate:"required,max=20"`
Level int `json:"level" validate:"required,min=1,max=10"`
Color string `json:"color" validate:"max=7"`
DisplayOrder *int `json:"display_order"`
}
@@ -330,7 +336,10 @@ func (h *AdminLookupHandler) ListPriorities(c echo.Context) error {
func (h *AdminLookupHandler) CreatePriority(c echo.Context) error {
var req CreateUpdatePriorityRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
priority := models.TaskPriority{
@@ -374,7 +383,10 @@ func (h *AdminLookupHandler) UpdatePriority(c echo.Context) error {
var req CreateUpdatePriorityRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
priority.Name = req.Name
@@ -434,7 +446,7 @@ type TaskFrequencyResponse struct {
}
type CreateUpdateFrequencyRequest struct {
Name string `json:"name" binding:"required,max=20"`
Name string `json:"name" validate:"required,max=20"`
Days *int `json:"days"`
DisplayOrder *int `json:"display_order"`
}
@@ -480,7 +492,10 @@ func (h *AdminLookupHandler) ListFrequencies(c echo.Context) error {
func (h *AdminLookupHandler) CreateFrequency(c echo.Context) error {
var req CreateUpdateFrequencyRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
frequency := models.TaskFrequency{
@@ -528,7 +543,10 @@ func (h *AdminLookupHandler) UpdateFrequency(c echo.Context) error {
var req CreateUpdateFrequencyRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
frequency.Name = req.Name
@@ -588,7 +606,7 @@ type ResidenceTypeResponse struct {
}
type CreateUpdateResidenceTypeRequest struct {
Name string `json:"name" binding:"required,max=20"`
Name string `json:"name" validate:"required,max=20"`
}
func (h *AdminLookupHandler) ListResidenceTypes(c echo.Context) error {
@@ -611,7 +629,10 @@ func (h *AdminLookupHandler) ListResidenceTypes(c echo.Context) error {
func (h *AdminLookupHandler) CreateResidenceType(c echo.Context) error {
var req CreateUpdateResidenceTypeRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
residenceType := models.ResidenceType{Name: req.Name}
@@ -644,7 +665,10 @@ func (h *AdminLookupHandler) UpdateResidenceType(c echo.Context) error {
var req CreateUpdateResidenceTypeRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
residenceType.Name = req.Name
@@ -694,9 +718,9 @@ type ContractorSpecialtyResponse struct {
}
type CreateUpdateSpecialtyRequest struct {
Name string `json:"name" binding:"required,max=50"`
Name string `json:"name" validate:"required,max=50"`
Description string `json:"description"`
Icon string `json:"icon" binding:"max=50"`
Icon string `json:"icon" validate:"max=50"`
DisplayOrder *int `json:"display_order"`
}
@@ -723,7 +747,10 @@ func (h *AdminLookupHandler) ListSpecialties(c echo.Context) error {
func (h *AdminLookupHandler) CreateSpecialty(c echo.Context) error {
var req CreateUpdateSpecialtyRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
specialty := models.ContractorSpecialty{
@@ -767,7 +794,10 @@ func (h *AdminLookupHandler) UpdateSpecialty(c echo.Context) error {
var req CreateUpdateSpecialtyRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
specialty.Name = req.Name

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/admin/dto"
@@ -36,7 +37,7 @@ func NewAdminNotificationHandler(db *gorm.DB, emailService *services.EmailServic
func (h *AdminNotificationHandler) List(c echo.Context) error {
var filters dto.NotificationFilters
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var notifications []models.Notification
@@ -151,7 +152,7 @@ func (h *AdminNotificationHandler) Update(c echo.Context) error {
var req dto.UpdateNotificationRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
now := time.Now().UTC()
@@ -235,7 +236,7 @@ func (h *AdminNotificationHandler) toNotificationDetailResponse(notif *models.No
func (h *AdminNotificationHandler) SendTestNotification(c echo.Context) error {
var req dto.SendTestNotificationRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Verify user exists
@@ -294,13 +295,10 @@ func (h *AdminNotificationHandler) SendTestNotification(c echo.Context) error {
err := h.pushClient.SendToAll(ctx, iosTokens, androidTokens, req.Title, req.Body, pushData)
if err != nil {
// Update notification with error
h.db.Model(&notification).Updates(map[string]interface{}{
"error": err.Error(),
})
// Log the real error for debugging
log.Error().Err(err).Uint("notification_id", notification.ID).Msg("Failed to send push notification")
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": "Failed to send push notification",
"details": err.Error(),
"error": "Failed to send push notification",
})
}
} else {
@@ -327,7 +325,7 @@ func (h *AdminNotificationHandler) SendTestNotification(c echo.Context) error {
func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error {
var req dto.SendTestEmailRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Verify user exists
@@ -369,9 +367,9 @@ func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error {
err := h.emailService.SendEmail(user.Email, req.Subject, htmlBody, req.Body)
if err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send test email")
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": "Failed to send email",
"details": err.Error(),
"error": "Failed to send email",
})
}
@@ -384,11 +382,14 @@ func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error {
// SendPostVerificationEmail handles POST /api/admin/emails/send-post-verification
func (h *AdminNotificationHandler) SendPostVerificationEmail(c echo.Context) error {
var req struct {
UserID uint `json:"user_id" binding:"required"`
UserID uint `json:"user_id" validate:"required"`
}
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "user_id is required"})
}
if err := c.Validate(&req); err != nil {
return err
}
// Verify user exists
var user models.User
@@ -410,9 +411,9 @@ func (h *AdminNotificationHandler) SendPostVerificationEmail(c echo.Context) err
err := h.emailService.SendPostVerificationEmail(user.Email, user.FirstName)
if err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send post-verification email")
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
"error": "Failed to send email",
"details": err.Error(),
"error": "Failed to send email",
})
}

View File

@@ -55,7 +55,7 @@ type NotificationPrefResponse struct {
func (h *AdminNotificationPrefsHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var prefs []models.NotificationPreference
@@ -212,7 +212,7 @@ func (h *AdminNotificationPrefsHandler) Update(c echo.Context) error {
var req UpdateNotificationPrefRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Apply updates

View File

@@ -240,7 +240,7 @@ func (h *AdminOnboardingHandler) Delete(c echo.Context) error {
// DELETE /api/admin/onboarding-emails/bulk
func (h *AdminOnboardingHandler) BulkDelete(c echo.Context) error {
var req struct {
IDs []uint `json:"ids" binding:"required"`
IDs []uint `json:"ids" validate:"required"`
}
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request"})
@@ -263,8 +263,8 @@ func (h *AdminOnboardingHandler) BulkDelete(c echo.Context) error {
// SendOnboardingEmailRequest represents a request to send an onboarding email
type SendOnboardingEmailRequest struct {
UserID uint `json:"user_id" binding:"required"`
EmailType string `json:"email_type" binding:"required"`
UserID uint `json:"user_id" validate:"required"`
EmailType string `json:"email_type" validate:"required"`
}
// Send sends an onboarding email to a specific user
@@ -278,6 +278,9 @@ func (h *AdminOnboardingHandler) Send(c echo.Context) error {
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request: user_id and email_type are required"})
}
if err := c.Validate(&req); err != nil {
return err
}
// Validate email type
var emailType models.OnboardingEmailType
@@ -301,7 +304,7 @@ func (h *AdminOnboardingHandler) Send(c echo.Context) error {
// Send the email
if err := h.onboardingService.SendOnboardingEmailToUser(req.UserID, emailType); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to send onboarding email"})
}
return c.JSON(http.StatusOK, map[string]interface{}{

View File

@@ -39,7 +39,7 @@ type PasswordResetCodeResponse struct {
func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var codes []models.PasswordResetCode
@@ -147,7 +147,7 @@ func (h *AdminPasswordResetCodeHandler) Delete(c echo.Context) error {
func (h *AdminPasswordResetCodeHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
result := h.db.Where("id IN ?", req.IDs).Delete(&models.PasswordResetCode{})

View File

@@ -41,7 +41,7 @@ type PromotionResponse struct {
func (h *AdminPromotionHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var promotions []models.Promotion
@@ -123,18 +123,21 @@ func (h *AdminPromotionHandler) Get(c echo.Context) error {
// Create handles POST /api/admin/promotions
func (h *AdminPromotionHandler) Create(c echo.Context) error {
var req struct {
PromotionID string `json:"promotion_id" binding:"required"`
Title string `json:"title" binding:"required"`
Message string `json:"message" binding:"required"`
PromotionID string `json:"promotion_id" validate:"required"`
Title string `json:"title" validate:"required"`
Message string `json:"message" validate:"required"`
Link *string `json:"link"`
StartDate string `json:"start_date" binding:"required"`
EndDate string `json:"end_date" binding:"required"`
StartDate string `json:"start_date" validate:"required"`
EndDate string `json:"end_date" validate:"required"`
TargetTier string `json:"target_tier"`
IsActive *bool `json:"is_active"`
}
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
startDate, err := time.Parse("2006-01-02T15:04:05Z", req.StartDate)
@@ -219,7 +222,7 @@ func (h *AdminPromotionHandler) Update(c echo.Context) error {
}
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if req.PromotionID != nil {

View File

@@ -27,7 +27,7 @@ func NewAdminResidenceHandler(db *gorm.DB) *AdminResidenceHandler {
func (h *AdminResidenceHandler) List(c echo.Context) error {
var filters dto.ResidenceFilters
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var residences []models.Residence
@@ -143,7 +143,7 @@ func (h *AdminResidenceHandler) Update(c echo.Context) error {
var req dto.UpdateResidenceRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if req.OwnerID != nil {
@@ -204,8 +204,7 @@ func (h *AdminResidenceHandler) Update(c echo.Context) error {
}
}
if req.PurchasePrice != nil {
d := decimal.NewFromFloat(*req.PurchasePrice)
residence.PurchasePrice = &d
residence.PurchasePrice = req.PurchasePrice
}
if req.IsActive != nil {
residence.IsActive = *req.IsActive
@@ -226,7 +225,7 @@ func (h *AdminResidenceHandler) Update(c echo.Context) error {
func (h *AdminResidenceHandler) Create(c echo.Context) error {
var req dto.CreateResidenceRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Verify owner exists
@@ -300,7 +299,7 @@ func (h *AdminResidenceHandler) Delete(c echo.Context) error {
func (h *AdminResidenceHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Soft delete - deactivate all

View File

@@ -47,7 +47,9 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
TrialEnabled: true,
TrialDurationDays: 14,
}
h.db.Create(&settings)
if err := h.db.Create(&settings).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default settings"})
}
} else {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
}
@@ -73,7 +75,7 @@ type UpdateSettingsRequest struct {
func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
var req UpdateSettingsRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var settings models.SubscriptionSettings
@@ -123,12 +125,12 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
func (h *AdminSettingsHandler) SeedLookups(c echo.Context) error {
// First seed lookup tables
if err := h.runSeedFile("001_lookups.sql"); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed lookups: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed lookups"})
}
// Then seed task templates
if err := h.runSeedFile("003_task_templates.sql"); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates"})
}
// Cache all lookups in Redis
@@ -349,7 +351,7 @@ func parseTags(tags string) []string {
// SeedTestData handles POST /api/admin/settings/seed-test-data
func (h *AdminSettingsHandler) SeedTestData(c echo.Context) error {
if err := h.runSeedFile("002_test_data.sql"); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed test data: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed test data"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Test data seeded successfully"})
@@ -358,7 +360,7 @@ func (h *AdminSettingsHandler) SeedTestData(c echo.Context) error {
// SeedTaskTemplates handles POST /api/admin/settings/seed-task-templates
func (h *AdminSettingsHandler) SeedTaskTemplates(c echo.Context) error {
if err := h.runSeedFile("003_task_templates.sql"); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Task templates seeded successfully"})
@@ -590,38 +592,38 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
// 1. Delete task completion images
if err := tx.Exec("DELETE FROM task_taskcompletionimage").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completion images: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completion images"})
}
// 2. Delete task completions
if err := tx.Exec("DELETE FROM task_taskcompletion").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completions: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completions"})
}
// 3. Delete notifications (must be before tasks since notifications have task_id FK)
if len(preservedUserIDs) > 0 {
if err := tx.Exec("DELETE FROM notifications_notification WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications"})
}
} else {
if err := tx.Exec("DELETE FROM notifications_notification").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications"})
}
}
// 4. Delete document images
if err := tx.Exec("DELETE FROM task_documentimage").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete document images: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete document images"})
}
// 5. Delete documents
if err := tx.Exec("DELETE FROM task_document").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete documents: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete documents"})
}
// 6. Delete task reminder logs (must be before tasks since reminder logs have task_id FK)
@@ -631,64 +633,64 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
if tableExists {
if err := tx.Exec("DELETE FROM task_reminderlog").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task reminder logs: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task reminder logs"})
}
}
// 7. Delete tasks (must be before contractors since tasks reference contractors)
if err := tx.Exec("DELETE FROM task_task").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete tasks: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete tasks"})
}
// 8. Delete contractor specialties (many-to-many)
if err := tx.Exec("DELETE FROM task_contractor_specialties").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractor specialties: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractor specialties"})
}
// 9. Delete contractors
if err := tx.Exec("DELETE FROM task_contractor").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractors: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractors"})
}
// 10. Delete residence_users (many-to-many for shared residences)
if err := tx.Exec("DELETE FROM residence_residence_users").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence users: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence users"})
}
// 11. Delete residence share codes (must be before residences since share codes have residence_id FK)
if err := tx.Exec("DELETE FROM residence_residencesharecode").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence share codes: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence share codes"})
}
// 12. Delete residences
if err := tx.Exec("DELETE FROM residence_residence").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residences: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residences"})
}
// 13. Delete push devices for non-superusers (both APNS and GCM)
if len(preservedUserIDs) > 0 {
if err := tx.Exec("DELETE FROM push_notifications_apnsdevice WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices"})
}
if err := tx.Exec("DELETE FROM push_notifications_gcmdevice WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices"})
}
} else {
if err := tx.Exec("DELETE FROM push_notifications_apnsdevice").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices"})
}
if err := tx.Exec("DELETE FROM push_notifications_gcmdevice").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices"})
}
}
@@ -696,12 +698,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
if len(preservedUserIDs) > 0 {
if err := tx.Exec("DELETE FROM notifications_notificationpreference WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences"})
}
} else {
if err := tx.Exec("DELETE FROM notifications_notificationpreference").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences"})
}
}
@@ -709,12 +711,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
if len(preservedUserIDs) > 0 {
if err := tx.Exec("DELETE FROM subscription_usersubscription WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions"})
}
} else {
if err := tx.Exec("DELETE FROM subscription_usersubscription").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions"})
}
}
@@ -722,12 +724,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
if len(preservedUserIDs) > 0 {
if err := tx.Exec("DELETE FROM user_passwordresetcode WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes"})
}
} else {
if err := tx.Exec("DELETE FROM user_passwordresetcode").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes"})
}
}
@@ -735,12 +737,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
if len(preservedUserIDs) > 0 {
if err := tx.Exec("DELETE FROM user_confirmationcode WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes"})
}
} else {
if err := tx.Exec("DELETE FROM user_confirmationcode").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes"})
}
}
@@ -748,12 +750,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
if len(preservedUserIDs) > 0 {
if err := tx.Exec("DELETE FROM user_authtoken WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens"})
}
} else {
if err := tx.Exec("DELETE FROM user_authtoken").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens"})
}
}
@@ -761,12 +763,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
if len(preservedUserIDs) > 0 {
if err := tx.Exec("DELETE FROM user_applesocialauth WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth"})
}
} else {
if err := tx.Exec("DELETE FROM user_applesocialauth").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth"})
}
}
@@ -774,12 +776,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
if len(preservedUserIDs) > 0 {
if err := tx.Exec("DELETE FROM user_userprofile WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles"})
}
} else {
if err := tx.Exec("DELETE FROM user_userprofile").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles"})
}
}
@@ -790,12 +792,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
if len(preservedUserIDs) > 0 {
if err := tx.Exec("DELETE FROM onboarding_emails WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails"})
}
} else {
if err := tx.Exec("DELETE FROM onboarding_emails").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails"})
}
}
}
@@ -804,12 +806,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
// Always filter by is_superuser to be safe, regardless of preservedUserIDs
if err := tx.Exec("DELETE FROM auth_user WHERE is_superuser = false").Error; err != nil {
tx.Rollback()
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete users: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete users"})
}
// Commit the transaction
if err := tx.Commit().Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to commit transaction: " + err.Error()})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to commit transaction"})
}
return c.JSON(http.StatusOK, ClearAllDataResponse{

View File

@@ -38,7 +38,7 @@ type ShareCodeResponse struct {
func (h *AdminShareCodeHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var codes []models.ResidenceShareCode
@@ -156,7 +156,7 @@ func (h *AdminShareCodeHandler) Update(c echo.Context) error {
IsActive *bool `json:"is_active"`
}
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Only update IsActive when explicitly provided (non-nil).
@@ -216,7 +216,7 @@ func (h *AdminShareCodeHandler) Delete(c echo.Context) error {
func (h *AdminShareCodeHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
result := h.db.Where("id IN ?", req.IDs).Delete(&models.ResidenceShareCode{})

View File

@@ -26,7 +26,7 @@ func NewAdminSubscriptionHandler(db *gorm.DB) *AdminSubscriptionHandler {
func (h *AdminSubscriptionHandler) List(c echo.Context) error {
var filters dto.SubscriptionFilters
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var subscriptions []models.UserSubscription
@@ -38,8 +38,8 @@ func (h *AdminSubscriptionHandler) List(c echo.Context) error {
// Apply search (search by user email)
if filters.Search != "" {
search := "%" + filters.Search + "%"
query = query.Joins("JOIN users ON users.id = subscription_usersubscription.user_id").
Where("users.email ILIKE ? OR users.username ILIKE ?", search, search)
query = query.Joins("JOIN auth_user ON auth_user.id = subscription_usersubscription.user_id").
Where("auth_user.email ILIKE ? OR auth_user.username ILIKE ?", search, search)
}
// Apply filters
@@ -140,7 +140,7 @@ func (h *AdminSubscriptionHandler) Update(c echo.Context) error {
var req dto.UpdateSubscriptionRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if req.Tier != nil {

View File

@@ -6,7 +6,6 @@ import (
"time"
"github.com/labstack/echo/v4"
"github.com/shopspring/decimal"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/admin/dto"
@@ -27,7 +26,7 @@ func NewAdminTaskHandler(db *gorm.DB) *AdminTaskHandler {
func (h *AdminTaskHandler) List(c echo.Context) error {
var filters dto.TaskFilters
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var tasks []models.Task
@@ -149,7 +148,7 @@ func (h *AdminTaskHandler) Update(c echo.Context) error {
var req dto.UpdateTaskRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Verify residence if changing
@@ -216,10 +215,10 @@ func (h *AdminTaskHandler) Update(c echo.Context) error {
}
}
if req.EstimatedCost != nil {
updates["estimated_cost"] = decimal.NewFromFloat(*req.EstimatedCost)
updates["estimated_cost"] = *req.EstimatedCost
}
if req.ActualCost != nil {
updates["actual_cost"] = decimal.NewFromFloat(*req.ActualCost)
updates["actual_cost"] = *req.ActualCost
}
if req.ContractorID != nil {
updates["contractor_id"] = *req.ContractorID
@@ -248,7 +247,7 @@ func (h *AdminTaskHandler) Update(c echo.Context) error {
func (h *AdminTaskHandler) Create(c echo.Context) error {
var req dto.CreateTaskRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Verify residence exists
@@ -285,8 +284,7 @@ func (h *AdminTaskHandler) Create(c echo.Context) error {
}
}
if req.EstimatedCost != nil {
d := decimal.NewFromFloat(*req.EstimatedCost)
task.EstimatedCost = &d
task.EstimatedCost = req.EstimatedCost
}
if err := h.db.Create(&task).Error; err != nil {
@@ -326,7 +324,7 @@ func (h *AdminTaskHandler) Delete(c echo.Context) error {
func (h *AdminTaskHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Soft delete - archive and cancel all

View File

@@ -28,6 +28,8 @@ func NewAdminTaskTemplateHandler(db *gorm.DB) *AdminTaskTemplateHandler {
func (h *AdminTaskTemplateHandler) refreshTaskTemplatesCache(ctx context.Context) {
cache := services.GetCache()
if cache == nil {
log.Warn().Msg("Cache service unavailable, skipping task templates cache refresh")
return
}
var templates []models.TaskTemplate
@@ -68,12 +70,12 @@ type TaskTemplateResponse struct {
// CreateUpdateTaskTemplateRequest represents the request body for creating/updating templates
type CreateUpdateTaskTemplateRequest struct {
Title string `json:"title" binding:"required,max=200"`
Title string `json:"title" validate:"required,max=200"`
Description string `json:"description"`
CategoryID *uint `json:"category_id"`
FrequencyID *uint `json:"frequency_id"`
IconIOS string `json:"icon_ios" binding:"max=100"`
IconAndroid string `json:"icon_android" binding:"max=100"`
IconIOS string `json:"icon_ios" validate:"max=100"`
IconAndroid string `json:"icon_android" validate:"max=100"`
Tags string `json:"tags"`
DisplayOrder *int `json:"display_order"`
IsActive *bool `json:"is_active"`
@@ -140,7 +142,10 @@ func (h *AdminTaskTemplateHandler) GetTemplate(c echo.Context) error {
func (h *AdminTaskTemplateHandler) CreateTemplate(c echo.Context) error {
var req CreateUpdateTaskTemplateRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
template := models.TaskTemplate{
@@ -191,7 +196,10 @@ func (h *AdminTaskTemplateHandler) UpdateTemplate(c echo.Context) error {
var req CreateUpdateTaskTemplateRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
template.Title = req.Title
@@ -271,10 +279,13 @@ func (h *AdminTaskTemplateHandler) ToggleActive(c echo.Context) error {
// BulkCreate handles POST /admin/api/task-templates/bulk/
func (h *AdminTaskTemplateHandler) BulkCreate(c echo.Context) error {
var req struct {
Templates []CreateUpdateTaskTemplateRequest `json:"templates" binding:"required,dive"`
Templates []CreateUpdateTaskTemplateRequest `json:"templates" validate:"required,dive"`
}
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := c.Validate(&req); err != nil {
return err
}
templates := make([]models.TaskTemplate, len(req.Templates))

View File

@@ -3,14 +3,25 @@ package handlers
import (
"net/http"
"strconv"
"strings"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/admin/dto"
"github.com/treytartt/honeydue-api/internal/models"
)
// escapeLikeWildcards escapes SQL LIKE wildcards (%, _) in user input
// to prevent wildcard injection in LIKE queries.
func escapeLikeWildcards(s string) string {
	// Single pass: prefix a backslash before the escape character itself
	// and before each LIKE metacharacter. Equivalent to sequentially
	// escaping `\`, then `%`, then `_`.
	var b strings.Builder
	b.Grow(len(s))
	for _, r := range s {
		switch r {
		case '\\', '%', '_':
			b.WriteByte('\\')
		}
		b.WriteRune(r)
	}
	return b.String()
}
// AdminUserHandler handles admin user management endpoints
type AdminUserHandler struct {
db *gorm.DB
@@ -25,7 +36,7 @@ func NewAdminUserHandler(db *gorm.DB) *AdminUserHandler {
func (h *AdminUserHandler) List(c echo.Context) error {
var filters dto.UserFilters
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var users []models.User
@@ -35,7 +46,7 @@ func (h *AdminUserHandler) List(c echo.Context) error {
// Apply search
if filters.Search != "" {
search := "%" + filters.Search + "%"
search := "%" + escapeLikeWildcards(filters.Search) + "%"
query = query.Where(
"username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?",
search, search, search, search,
@@ -70,10 +81,49 @@ func (h *AdminUserHandler) List(c echo.Context) error {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch users"})
}
// Batch COUNT queries for residence and task counts instead of N+1
userIDs := make([]uint, len(users))
for i, user := range users {
userIDs[i] = user.ID
}
type countResult struct {
OwnerID uint
Count int64
}
residenceCounts := make(map[uint]int64)
taskCounts := make(map[uint]int64)
if len(userIDs) > 0 {
var resCounts []countResult
h.db.Model(&models.Residence{}).
Select("owner_id, COUNT(*) as count").
Where("owner_id IN ?", userIDs).
Group("owner_id").
Scan(&resCounts)
for _, rc := range resCounts {
residenceCounts[rc.OwnerID] = rc.Count
}
var tskCounts []struct {
CreatedByID uint
Count int64
}
h.db.Model(&models.Task{}).
Select("created_by_id, COUNT(*) as count").
Where("created_by_id IN ?", userIDs).
Group("created_by_id").
Scan(&tskCounts)
for _, tc := range tskCounts {
taskCounts[tc.CreatedByID] = tc.Count
}
}
// Build response
responses := make([]dto.UserResponse, len(users))
for i, user := range users {
responses[i] = h.toUserResponse(&user)
responses[i] = h.toUserResponseWithCounts(&user, int(residenceCounts[user.ID]), int(taskCounts[user.ID]))
}
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
@@ -122,7 +172,7 @@ func (h *AdminUserHandler) Get(c echo.Context) error {
func (h *AdminUserHandler) Create(c echo.Context) error {
var req dto.CreateUserRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Check if username exists
@@ -170,7 +220,9 @@ func (h *AdminUserHandler) Create(c echo.Context) error {
UserID: user.ID,
PhoneNumber: req.PhoneNumber,
}
h.db.Create(&profile)
if err := h.db.Create(&profile).Error; err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create user profile")
}
// Reload with profile
h.db.Preload("Profile").First(&user, user.ID)
@@ -194,7 +246,7 @@ func (h *AdminUserHandler) Update(c echo.Context) error {
var req dto.UpdateUserRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Check username uniqueness if changing
@@ -298,7 +350,7 @@ func (h *AdminUserHandler) Delete(c echo.Context) error {
func (h *AdminUserHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
// Soft delete - deactivate all
@@ -309,8 +361,30 @@ func (h *AdminUserHandler) BulkDelete(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Users deactivated successfully", "count": len(req.IDs)})
}
// toUserResponse converts a User model to UserResponse DTO
// toUserResponseWithCounts converts a User model to UserResponse DTO with pre-computed counts
func (h *AdminUserHandler) toUserResponseWithCounts(user *models.User, residenceCount, taskCount int) dto.UserResponse {
	// Counts are supplied by the caller (batched in List) so no extra
	// per-user queries are issued here.
	resp := h.buildUserResponse(user)
	resp.ResidenceCount, resp.TaskCount = residenceCount, taskCount
	return resp
}
// toUserResponse converts a User model to UserResponse DTO (fetches counts individually)
func (h *AdminUserHandler) toUserResponse(user *models.User) dto.UserResponse {
	resp := h.buildUserResponse(user)

	// Two COUNT queries are acceptable for single-user views; list views
	// should prefer toUserResponseWithCounts with batched counts.
	var nResidences, nTasks int64
	h.db.Model(&models.Residence{}).Where("owner_id = ?", user.ID).Count(&nResidences)
	h.db.Model(&models.Task{}).Where("created_by_id = ?", user.ID).Count(&nTasks)

	resp.ResidenceCount = int(nResidences)
	resp.TaskCount = int(nTasks)
	return resp
}
// buildUserResponse creates the base UserResponse without counts
func (h *AdminUserHandler) buildUserResponse(user *models.User) dto.UserResponse {
response := dto.UserResponse{
ID: user.ID,
Username: user.Username,
@@ -335,12 +409,5 @@ func (h *AdminUserHandler) toUserResponse(user *models.User) dto.UserResponse {
}
}
// Get counts
var residenceCount, taskCount int64
h.db.Model(&models.Residence{}).Where("owner_id = ?", user.ID).Count(&residenceCount)
h.db.Model(&models.Task{}).Where("created_by_id = ?", user.ID).Count(&taskCount)
response.ResidenceCount = int(residenceCount)
response.TaskCount = int(taskCount)
return response
}

View File

@@ -50,7 +50,7 @@ type UpdateUserProfileRequest struct {
func (h *AdminUserProfileHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var profiles []models.UserProfile
@@ -144,7 +144,7 @@ func (h *AdminUserProfileHandler) Update(c echo.Context) error {
var req UpdateUserProfileRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if req.Verified != nil {
@@ -205,14 +205,15 @@ func (h *AdminUserProfileHandler) Delete(c echo.Context) error {
func (h *AdminUserProfileHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if err := h.db.Where("id IN ?", req.IDs).Delete(&models.UserProfile{}).Error; err != nil {
result := h.db.Where("id IN ?", req.IDs).Delete(&models.UserProfile{})
if result.Error != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "User profiles deleted successfully", "count": len(req.IDs)})
return c.JSON(http.StatusOK, map[string]interface{}{"message": "User profiles deleted successfully", "count": result.RowsAffected})
}
// toProfileResponse converts a UserProfile model to UserProfileResponse

View File

@@ -379,9 +379,10 @@ func SetupRoutes(router *echo.Echo, db *gorm.DB, cfg *config.Config, deps *Depen
documentImages.DELETE("/:id", documentImageHandler.Delete)
}
// System settings management
// System settings management (super admin only)
settingsHandler := handlers.NewAdminSettingsHandler(db)
settings := protected.Group("/settings")
settings.Use(middleware.RequireSuperAdmin())
{
settings.GET("", settingsHandler.GetSettings)
settings.PUT("", settingsHandler.UpdateSettings)
@@ -500,9 +501,21 @@ func setupAdminProxy(router *echo.Echo) {
})
}
// Proxy Next.js static assets (served from /_next/ regardless of host)
// Proxy Next.js static assets — only when the request comes from a
// known host to prevent SSRF via crafted Host headers.
var allowedProxyHosts []string
if adminHost != "" {
allowedProxyHosts = append(allowedProxyHosts, adminHost)
}
// Also allow localhost variants for development
allowedProxyHosts = append(allowedProxyHosts,
"localhost:3001", "127.0.0.1:3001",
"localhost:8000", "127.0.0.1:8000",
)
hostCheck := middleware.HostCheck(allowedProxyHosts)
router.Any("/_next/*", func(c echo.Context) error {
proxy.ServeHTTP(c.Response(), c.Request())
return nil
})
}, hostCheck)
}

View File

@@ -6,6 +6,7 @@ import (
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/spf13/viper"
@@ -32,6 +33,7 @@ type Config struct {
type ServerConfig struct {
Port int
Debug bool
DebugFixedCodes bool // Separate from Debug: enables fixed confirmation codes for local testing
AllowedHosts []string
CorsAllowedOrigins []string // Comma-separated origins for CORS (production only; debug uses wildcard)
Timezone string
@@ -75,7 +77,12 @@ type PushConfig struct {
APNSSandbox bool
APNSProduction bool // If true, use production APNs; if false, use sandbox
// FCM (Android) - uses direct HTTP to FCM legacy API
// FCM (Android) - uses FCM HTTP v1 API with OAuth 2.0
FCMProjectID string // Firebase project ID (required for v1 API)
FCMServiceAccountPath string // Path to Google service account JSON file
FCMServiceAccountJSON string // Raw service account JSON (alternative to path, e.g. for env var injection)
// Deprecated: FCMServerKey is for the legacy HTTP API. Use FCMProjectID + service account instead.
FCMServerKey string
}
@@ -147,135 +154,166 @@ type FeatureFlags struct {
WorkerEnabled bool // FEATURE_WORKER_ENABLED (default: true)
}
var cfg *Config
var (
cfg *Config
cfgOnce sync.Once
)
// knownWeakSecretKeys contains well-known default or placeholder secret keys
// that must not be used in production.
//
// Keys are stored lowercased; lookups normalize input the same way
// (see isWeakSecretKey), so matching is case-insensitive.
var knownWeakSecretKeys = map[string]bool{
	"secret":   true,
	"changeme": true,
	"change-me": true,
	"password": true,
	// The repository's historical debug default — explicitly rejected.
	"change-me-in-production-secret-key-12345": true,
}
// Load reads configuration from environment variables
func Load() (*Config, error) {
viper.SetEnvPrefix("")
viper.AutomaticEnv()
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
var loadErr error
// Set defaults
setDefaults()
cfgOnce.Do(func() {
viper.SetEnvPrefix("")
viper.AutomaticEnv()
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
// Parse DATABASE_URL if set (Dokku-style)
dbConfig := DatabaseConfig{
Host: viper.GetString("DB_HOST"),
Port: viper.GetInt("DB_PORT"),
User: viper.GetString("POSTGRES_USER"),
Password: viper.GetString("POSTGRES_PASSWORD"),
Database: viper.GetString("POSTGRES_DB"),
SSLMode: viper.GetString("DB_SSLMODE"),
MaxOpenConns: viper.GetInt("DB_MAX_OPEN_CONNS"),
MaxIdleConns: viper.GetInt("DB_MAX_IDLE_CONNS"),
MaxLifetime: viper.GetDuration("DB_MAX_LIFETIME"),
}
// Set defaults
setDefaults()
// Override with DATABASE_URL if present
if databaseURL := viper.GetString("DATABASE_URL"); databaseURL != "" {
parsed, err := parseDatabaseURL(databaseURL)
if err == nil {
dbConfig.Host = parsed.Host
dbConfig.Port = parsed.Port
dbConfig.User = parsed.User
dbConfig.Password = parsed.Password
dbConfig.Database = parsed.Database
if parsed.SSLMode != "" {
dbConfig.SSLMode = parsed.SSLMode
// Parse DATABASE_URL if set (Dokku-style)
dbConfig := DatabaseConfig{
Host: viper.GetString("DB_HOST"),
Port: viper.GetInt("DB_PORT"),
User: viper.GetString("POSTGRES_USER"),
Password: viper.GetString("POSTGRES_PASSWORD"),
Database: viper.GetString("POSTGRES_DB"),
SSLMode: viper.GetString("DB_SSLMODE"),
MaxOpenConns: viper.GetInt("DB_MAX_OPEN_CONNS"),
MaxIdleConns: viper.GetInt("DB_MAX_IDLE_CONNS"),
MaxLifetime: viper.GetDuration("DB_MAX_LIFETIME"),
}
// Override with DATABASE_URL if present (F-16: log warning on parse failure)
if databaseURL := viper.GetString("DATABASE_URL"); databaseURL != "" {
parsed, err := parseDatabaseURL(databaseURL)
if err != nil {
maskedURL := MaskURLCredentials(databaseURL)
fmt.Printf("WARNING: Failed to parse DATABASE_URL (%s): %v — falling back to individual DB_* env vars\n", maskedURL, err)
} else {
dbConfig.Host = parsed.Host
dbConfig.Port = parsed.Port
dbConfig.User = parsed.User
dbConfig.Password = parsed.Password
dbConfig.Database = parsed.Database
if parsed.SSLMode != "" {
dbConfig.SSLMode = parsed.SSLMode
}
}
}
}
cfg = &Config{
Server: ServerConfig{
Port: viper.GetInt("PORT"),
Debug: viper.GetBool("DEBUG"),
AllowedHosts: strings.Split(viper.GetString("ALLOWED_HOSTS"), ","),
CorsAllowedOrigins: parseCorsOrigins(viper.GetString("CORS_ALLOWED_ORIGINS")),
Timezone: viper.GetString("TIMEZONE"),
StaticDir: viper.GetString("STATIC_DIR"),
BaseURL: viper.GetString("BASE_URL"),
},
Database: dbConfig,
Redis: RedisConfig{
URL: viper.GetString("REDIS_URL"),
Password: viper.GetString("REDIS_PASSWORD"),
DB: viper.GetInt("REDIS_DB"),
},
Email: EmailConfig{
Host: viper.GetString("EMAIL_HOST"),
Port: viper.GetInt("EMAIL_PORT"),
User: viper.GetString("EMAIL_HOST_USER"),
Password: viper.GetString("EMAIL_HOST_PASSWORD"),
From: viper.GetString("DEFAULT_FROM_EMAIL"),
UseTLS: viper.GetBool("EMAIL_USE_TLS"),
},
Push: PushConfig{
APNSKeyPath: viper.GetString("APNS_AUTH_KEY_PATH"),
APNSKeyID: viper.GetString("APNS_AUTH_KEY_ID"),
APNSTeamID: viper.GetString("APNS_TEAM_ID"),
APNSTopic: viper.GetString("APNS_TOPIC"),
APNSSandbox: viper.GetBool("APNS_USE_SANDBOX"),
APNSProduction: viper.GetBool("APNS_PRODUCTION"),
FCMServerKey: viper.GetString("FCM_SERVER_KEY"),
},
Worker: WorkerConfig{
TaskReminderHour: viper.GetInt("TASK_REMINDER_HOUR"),
OverdueReminderHour: viper.GetInt("OVERDUE_REMINDER_HOUR"),
DailyNotifHour: viper.GetInt("DAILY_DIGEST_HOUR"),
},
Security: SecurityConfig{
SecretKey: viper.GetString("SECRET_KEY"),
TokenCacheTTL: 5 * time.Minute,
PasswordResetExpiry: 15 * time.Minute,
ConfirmationExpiry: 24 * time.Hour,
MaxPasswordResetRate: 3,
},
Storage: StorageConfig{
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
BaseURL: viper.GetString("STORAGE_BASE_URL"),
MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"),
AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"),
},
AppleAuth: AppleAuthConfig{
ClientID: viper.GetString("APPLE_CLIENT_ID"),
TeamID: viper.GetString("APPLE_TEAM_ID"),
},
GoogleAuth: GoogleAuthConfig{
ClientID: viper.GetString("GOOGLE_CLIENT_ID"),
AndroidClientID: viper.GetString("GOOGLE_ANDROID_CLIENT_ID"),
IOSClientID: viper.GetString("GOOGLE_IOS_CLIENT_ID"),
},
AppleIAP: AppleIAPConfig{
KeyPath: viper.GetString("APPLE_IAP_KEY_PATH"),
KeyID: viper.GetString("APPLE_IAP_KEY_ID"),
IssuerID: viper.GetString("APPLE_IAP_ISSUER_ID"),
BundleID: viper.GetString("APPLE_IAP_BUNDLE_ID"),
Sandbox: viper.GetBool("APPLE_IAP_SANDBOX"),
},
GoogleIAP: GoogleIAPConfig{
ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"),
PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"),
},
Stripe: StripeConfig{
SecretKey: viper.GetString("STRIPE_SECRET_KEY"),
WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"),
PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"),
PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"),
},
Features: FeatureFlags{
PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"),
EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"),
WebhooksEnabled: viper.GetBool("FEATURE_WEBHOOKS_ENABLED"),
OnboardingEmailsEnabled: viper.GetBool("FEATURE_ONBOARDING_EMAILS_ENABLED"),
PDFReportsEnabled: viper.GetBool("FEATURE_PDF_REPORTS_ENABLED"),
WorkerEnabled: viper.GetBool("FEATURE_WORKER_ENABLED"),
},
}
cfg = &Config{
Server: ServerConfig{
Port: viper.GetInt("PORT"),
Debug: viper.GetBool("DEBUG"),
DebugFixedCodes: viper.GetBool("DEBUG_FIXED_CODES"),
AllowedHosts: strings.Split(viper.GetString("ALLOWED_HOSTS"), ","),
CorsAllowedOrigins: parseCorsOrigins(viper.GetString("CORS_ALLOWED_ORIGINS")),
Timezone: viper.GetString("TIMEZONE"),
StaticDir: viper.GetString("STATIC_DIR"),
BaseURL: viper.GetString("BASE_URL"),
},
Database: dbConfig,
Redis: RedisConfig{
URL: viper.GetString("REDIS_URL"),
Password: viper.GetString("REDIS_PASSWORD"),
DB: viper.GetInt("REDIS_DB"),
},
Email: EmailConfig{
Host: viper.GetString("EMAIL_HOST"),
Port: viper.GetInt("EMAIL_PORT"),
User: viper.GetString("EMAIL_HOST_USER"),
Password: viper.GetString("EMAIL_HOST_PASSWORD"),
From: viper.GetString("DEFAULT_FROM_EMAIL"),
UseTLS: viper.GetBool("EMAIL_USE_TLS"),
},
Push: PushConfig{
APNSKeyPath: viper.GetString("APNS_AUTH_KEY_PATH"),
APNSKeyID: viper.GetString("APNS_AUTH_KEY_ID"),
APNSTeamID: viper.GetString("APNS_TEAM_ID"),
APNSTopic: viper.GetString("APNS_TOPIC"),
APNSSandbox: viper.GetBool("APNS_USE_SANDBOX"),
APNSProduction: viper.GetBool("APNS_PRODUCTION"),
FCMProjectID: viper.GetString("FCM_PROJECT_ID"),
FCMServiceAccountPath: viper.GetString("FCM_SERVICE_ACCOUNT_PATH"),
FCMServiceAccountJSON: viper.GetString("FCM_SERVICE_ACCOUNT_JSON"),
FCMServerKey: viper.GetString("FCM_SERVER_KEY"),
},
Worker: WorkerConfig{
TaskReminderHour: viper.GetInt("TASK_REMINDER_HOUR"),
OverdueReminderHour: viper.GetInt("OVERDUE_REMINDER_HOUR"),
DailyNotifHour: viper.GetInt("DAILY_DIGEST_HOUR"),
},
Security: SecurityConfig{
SecretKey: viper.GetString("SECRET_KEY"),
TokenCacheTTL: 5 * time.Minute,
PasswordResetExpiry: 15 * time.Minute,
ConfirmationExpiry: 24 * time.Hour,
MaxPasswordResetRate: 3,
},
Storage: StorageConfig{
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
BaseURL: viper.GetString("STORAGE_BASE_URL"),
MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"),
AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"),
},
AppleAuth: AppleAuthConfig{
ClientID: viper.GetString("APPLE_CLIENT_ID"),
TeamID: viper.GetString("APPLE_TEAM_ID"),
},
GoogleAuth: GoogleAuthConfig{
ClientID: viper.GetString("GOOGLE_CLIENT_ID"),
AndroidClientID: viper.GetString("GOOGLE_ANDROID_CLIENT_ID"),
IOSClientID: viper.GetString("GOOGLE_IOS_CLIENT_ID"),
},
AppleIAP: AppleIAPConfig{
KeyPath: viper.GetString("APPLE_IAP_KEY_PATH"),
KeyID: viper.GetString("APPLE_IAP_KEY_ID"),
IssuerID: viper.GetString("APPLE_IAP_ISSUER_ID"),
BundleID: viper.GetString("APPLE_IAP_BUNDLE_ID"),
Sandbox: viper.GetBool("APPLE_IAP_SANDBOX"),
},
GoogleIAP: GoogleIAPConfig{
ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"),
PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"),
},
Stripe: StripeConfig{
SecretKey: viper.GetString("STRIPE_SECRET_KEY"),
WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"),
PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"),
PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"),
},
Features: FeatureFlags{
PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"),
EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"),
WebhooksEnabled: viper.GetBool("FEATURE_WEBHOOKS_ENABLED"),
OnboardingEmailsEnabled: viper.GetBool("FEATURE_ONBOARDING_EMAILS_ENABLED"),
PDFReportsEnabled: viper.GetBool("FEATURE_PDF_REPORTS_ENABLED"),
WorkerEnabled: viper.GetBool("FEATURE_WORKER_ENABLED"),
},
}
// Validate required fields
if err := validate(cfg); err != nil {
return nil, err
// Validate required fields
if err := validate(cfg); err != nil {
loadErr = err
// Reset so a subsequent call can retry after env is fixed
cfg = nil
cfgOnce = sync.Once{}
}
})
if loadErr != nil {
return nil, loadErr
}
return cfg, nil
@@ -290,6 +328,7 @@ func setDefaults() {
// Server defaults
viper.SetDefault("PORT", 8000)
viper.SetDefault("DEBUG", false)
viper.SetDefault("DEBUG_FIXED_CODES", false) // Separate flag for fixed confirmation codes
viper.SetDefault("ALLOWED_HOSTS", "localhost,127.0.0.1")
viper.SetDefault("TIMEZONE", "UTC")
viper.SetDefault("STATIC_DIR", "/app/static")
@@ -347,7 +386,13 @@ func setDefaults() {
viper.SetDefault("FEATURE_WORKER_ENABLED", true)
}
// isWeakSecretKey checks if the provided key is a known weak/default value.
func isWeakSecretKey(key string) bool {
	// Normalize the same way the blocklist is stored: trimmed + lowercase.
	normalized := strings.ToLower(strings.TrimSpace(key))
	return knownWeakSecretKeys[normalized]
}
func validate(cfg *Config) error {
// S-08: Validate SECRET_KEY against known weak defaults
if cfg.Security.SecretKey == "" {
if cfg.Server.Debug {
// In debug mode, use a default key with a warning for local development
@@ -358,9 +403,12 @@ func validate(cfg *Config) error {
// In production, refuse to start without a proper secret key
return fmt.Errorf("FATAL: SECRET_KEY environment variable is required in production (DEBUG=false)")
}
} else if cfg.Security.SecretKey == "change-me-in-production-secret-key-12345" {
// Warn if someone explicitly set the well-known debug key
fmt.Println("WARNING: SECRET_KEY is set to the well-known debug default. Change it for production use.")
} else if isWeakSecretKey(cfg.Security.SecretKey) {
if cfg.Server.Debug {
fmt.Printf("WARNING: SECRET_KEY is set to a well-known weak value (%q). Change it for production use.\n", cfg.Security.SecretKey)
} else {
return fmt.Errorf("FATAL: SECRET_KEY is set to a well-known weak value (%q). Use a strong, unique secret in production", cfg.Security.SecretKey)
}
}
// Database password might come from DATABASE_URL, don't require it separately
@@ -369,6 +417,21 @@ func validate(cfg *Config) error {
return nil
}
// MaskURLCredentials returns rawURL with any embedded password replaced by
// "xxxxx" (via url.URL.Redacted). If parsing fails, it returns the string
// "<unparseable-url>" to avoid leaking credentials in logs.
//
// Note: Redacted already performs the password masking, so no manual
// Userinfo rewriting is needed (the previous UserPassword(..., "***")
// assignment was dead code — its value was never observable in the output).
func MaskURLCredentials(rawURL string) string {
	u, err := url.Parse(rawURL)
	if err != nil {
		// Never echo an unparseable URL back: it may still embed credentials.
		return "<unparseable-url>"
	}
	// Redacted replaces a set password with "xxxxx"; URLs without
	// credentials are returned unchanged.
	return u.Redacted()
}
// DSN returns the database connection string
func (d *DatabaseConfig) DSN() string {
return fmt.Sprintf(

View File

@@ -5,6 +5,8 @@ import (
"time"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"golang.org/x/crypto/bcrypt"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -376,13 +378,23 @@ func migrateGoAdmin() error {
}
db.Exec(`CREATE INDEX IF NOT EXISTS idx_goadmin_site_key ON goadmin_site(key)`)
// Seed default admin user only on first run (ON CONFLICT DO NOTHING).
// Seed default GoAdmin user only on first run (ON CONFLICT DO NOTHING).
// Password is NOT reset on subsequent migrations to preserve operator changes.
db.Exec(`
INSERT INTO goadmin_users (username, password, name, avatar)
VALUES ('admin', '$2a$10$t.GCU24EqIWLSl7F51Hdz.IkkgFK.Qa9/BzEc5Bi2C/I2bXf1nJgm', 'Administrator', '')
ON CONFLICT DO NOTHING
`)
goAdminUsername := viper.GetString("GOADMIN_ADMIN_USERNAME")
goAdminPassword := viper.GetString("GOADMIN_ADMIN_PASSWORD")
if goAdminUsername == "" || goAdminPassword == "" {
log.Warn().Msg("GOADMIN_ADMIN_USERNAME and/or GOADMIN_ADMIN_PASSWORD not set; skipping GoAdmin admin user seed")
} else {
goAdminHash, err := bcrypt.GenerateFromPassword([]byte(goAdminPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash GoAdmin admin password: %w", err)
}
db.Exec(`
INSERT INTO goadmin_users (username, password, name, avatar)
VALUES (?, ?, 'Administrator', '')
ON CONFLICT DO NOTHING
`, goAdminUsername, string(goAdminHash))
}
// Seed default roles
db.Exec(`INSERT INTO goadmin_roles (name, slug) VALUES ('Administrator', 'administrator') ON CONFLICT DO NOTHING`)
@@ -393,15 +405,17 @@ func migrateGoAdmin() error {
db.Exec(`INSERT INTO goadmin_permissions (name, slug, http_method, http_path) VALUES ('Dashboard', 'dashboard', 'GET', '/') ON CONFLICT DO NOTHING`)
// Assign admin user to administrator role (if not already assigned)
db.Exec(`
INSERT INTO goadmin_role_users (role_id, user_id)
SELECT r.id, u.id FROM goadmin_roles r, goadmin_users u
WHERE r.slug = 'administrator' AND u.username = 'admin'
AND NOT EXISTS (
SELECT 1 FROM goadmin_role_users ru
WHERE ru.role_id = r.id AND ru.user_id = u.id
)
`)
if goAdminUsername != "" {
db.Exec(`
INSERT INTO goadmin_role_users (role_id, user_id)
SELECT r.id, u.id FROM goadmin_roles r, goadmin_users u
WHERE r.slug = 'administrator' AND u.username = ?
AND NOT EXISTS (
SELECT 1 FROM goadmin_role_users ru
WHERE ru.role_id = r.id AND ru.user_id = u.id
)
`, goAdminUsername)
}
// Assign all permissions to administrator role (if not already assigned)
db.Exec(`
@@ -448,15 +462,25 @@ func migrateGoAdmin() error {
// Seed default Next.js admin user only on first run.
// Password is NOT reset on subsequent migrations to preserve operator changes.
var adminCount int64
db.Raw(`SELECT COUNT(*) FROM admin_users WHERE email = 'admin@honeydue.com'`).Scan(&adminCount)
if adminCount == 0 {
log.Info().Msg("Seeding default admin user for Next.js admin panel...")
db.Exec(`
INSERT INTO admin_users (email, password, first_name, last_name, role, is_active, created_at, updated_at)
VALUES ('admin@honeydue.com', '$2a$10$t5hGjdXQLxr9Z0193qx.Tef6hd1vYI3JvrfX/piKx2qS9UvQ41I9O', 'Admin', 'User', 'super_admin', true, NOW(), NOW())
`)
log.Info().Msg("Default admin user created: admin@honeydue.com")
adminEmail := viper.GetString("ADMIN_EMAIL")
adminPassword := viper.GetString("ADMIN_PASSWORD")
if adminEmail == "" || adminPassword == "" {
log.Warn().Msg("ADMIN_EMAIL and/or ADMIN_PASSWORD not set; skipping Next.js admin user seed")
} else {
var adminCount int64
db.Raw(`SELECT COUNT(*) FROM admin_users WHERE email = ?`, adminEmail).Scan(&adminCount)
if adminCount == 0 {
log.Info().Str("email", adminEmail).Msg("Seeding default admin user for Next.js admin panel...")
adminHash, err := bcrypt.GenerateFromPassword([]byte(adminPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash admin password: %w", err)
}
db.Exec(`
INSERT INTO admin_users (email, password, first_name, last_name, role, is_active, created_at, updated_at)
VALUES (?, ?, 'Admin', 'User', 'super_admin', true, NOW(), NOW())
`, adminEmail, string(adminHash))
log.Info().Str("email", adminEmail).Msg("Default admin user created")
}
}
return nil

View File

@@ -71,7 +71,7 @@ type CreateTaskRequest struct {
// UpdateTaskRequest represents the request to update a task
type UpdateTaskRequest struct {
Title *string `json:"title" validate:"omitempty,min=1,max=200"`
Description *string `json:"description"`
Description *string `json:"description" validate:"omitempty,max=10000"`
CategoryID *uint `json:"category_id"`
PriorityID *uint `json:"priority_id"`
FrequencyID *uint `json:"frequency_id"`

View File

@@ -21,8 +21,9 @@ type DocumentUserResponse struct {
type DocumentImageResponse struct {
ID uint `json:"id"`
ImageURL string `json:"image_url"`
MediaURL string `json:"media_url"` // Authenticated endpoint: /api/media/document-image/{id}
MediaURL string `json:"media_url"` // Authenticated endpoint: /api/media/document-image/{id}
Caption string `json:"caption"`
Error string `json:"error,omitempty"` // Non-empty when the image could not be resolved
}
// DocumentResponse represents a document in the API response
@@ -35,7 +36,6 @@ type DocumentResponse struct {
Title string `json:"title"`
Description string `json:"description"`
DocumentType models.DocumentType `json:"document_type"`
FileURL string `json:"file_url"`
MediaURL string `json:"media_url"` // Authenticated endpoint: /api/media/document/{id}
FileName string `json:"file_name"`
FileSize *int64 `json:"file_size"`
@@ -80,7 +80,6 @@ func NewDocumentResponse(d *models.Document) DocumentResponse {
Title: d.Title,
Description: d.Description,
DocumentType: d.DocumentType,
FileURL: d.FileURL,
MediaURL: fmt.Sprintf("/api/media/document/%d", d.ID), // Authenticated endpoint
FileName: d.FileName,
FileSize: d.FileSize,
@@ -104,12 +103,16 @@ func NewDocumentResponse(d *models.Document) DocumentResponse {
// Convert images with authenticated media URLs
for _, img := range d.Images {
resp.Images = append(resp.Images, DocumentImageResponse{
imgResp := DocumentImageResponse{
ID: img.ID,
ImageURL: img.ImageURL,
MediaURL: fmt.Sprintf("/api/media/document-image/%d", img.ID), // Authenticated endpoint
Caption: img.Caption,
})
}
if img.ImageURL == "" {
imgResp.Error = "image source URL is missing"
}
resp.Images = append(resp.Images, imgResp)
}
return resp

View File

@@ -281,13 +281,15 @@ func NewTaskListResponse(tasks []models.Task) []TaskResponse {
return results
}
// NewKanbanBoardResponse creates a KanbanBoardResponse from a KanbanBoard model
func NewKanbanBoardResponse(board *models.KanbanBoard, residenceID uint) KanbanBoardResponse {
// NewKanbanBoardResponse creates a KanbanBoardResponse from a KanbanBoard model.
// The `now` parameter should be the start of day in the user's timezone so that
// individual task kanban columns are categorized consistently with the board query.
func NewKanbanBoardResponse(board *models.KanbanBoard, residenceID uint, now time.Time) KanbanBoardResponse {
columns := make([]KanbanColumnResponse, len(board.Columns))
for i, col := range board.Columns {
tasks := make([]TaskResponse, len(col.Tasks))
for j, t := range col.Tasks {
tasks[j] = NewTaskResponse(&t)
tasks[j] = NewTaskResponseWithTime(&t, board.DaysThreshold, now)
}
columns[i] = KanbanColumnResponse{
Name: col.Name,
@@ -306,13 +308,15 @@ func NewKanbanBoardResponse(board *models.KanbanBoard, residenceID uint) KanbanB
}
}
// NewKanbanBoardResponseForAll creates a KanbanBoardResponse for all residences (no specific residence ID)
func NewKanbanBoardResponseForAll(board *models.KanbanBoard) KanbanBoardResponse {
// NewKanbanBoardResponseForAll creates a KanbanBoardResponse for all residences (no specific residence ID).
// The `now` parameter should be the start of day in the user's timezone so that
// individual task kanban columns are categorized consistently with the board query.
func NewKanbanBoardResponseForAll(board *models.KanbanBoard, now time.Time) KanbanBoardResponse {
columns := make([]KanbanColumnResponse, len(board.Columns))
for i, col := range board.Columns {
tasks := make([]TaskResponse, len(col.Tasks))
for j, t := range col.Tasks {
tasks[j] = NewTaskResponse(&t)
tasks[j] = NewTaskResponseWithTime(&t, board.DaysThreshold, now)
}
columns[i] = KanbanColumnResponse{
Name: col.Name,

View File

@@ -8,6 +8,8 @@ import (
"github.com/treytartt/honeydue-api/internal/apperrors"
"github.com/treytartt/honeydue-api/internal/dto/requests"
"github.com/treytartt/honeydue-api/internal/dto/responses"
"github.com/treytartt/honeydue-api/internal/i18n"
"github.com/treytartt/honeydue-api/internal/middleware"
"github.com/treytartt/honeydue-api/internal/services"
)
@@ -115,7 +117,7 @@ func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
if err != nil {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Contractor deleted successfully"})
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.contractor_deleted")})
}
// ToggleFavorite handles POST /api/contractors/:id/toggle-favorite/

View File

@@ -12,6 +12,8 @@ import (
"github.com/treytartt/honeydue-api/internal/apperrors"
"github.com/treytartt/honeydue-api/internal/dto/requests"
"github.com/treytartt/honeydue-api/internal/dto/responses"
"github.com/treytartt/honeydue-api/internal/i18n"
"github.com/treytartt/honeydue-api/internal/middleware"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/repositories"
@@ -60,6 +62,9 @@ func (h *DocumentHandler) ListDocuments(c echo.Context) error {
}
if es := c.QueryParam("expiring_soon"); es != "" {
if parsed, err := strconv.Atoi(es); err == nil {
if parsed < 1 || parsed > 3650 {
return apperrors.BadRequest("error.days_out_of_range")
}
filter.ExpiringSoon = &parsed
}
}
@@ -192,7 +197,10 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
}
}
if uploadedFile != nil && h.storageService != nil {
if uploadedFile != nil {
if h.storageService == nil {
return apperrors.Internal(nil)
}
result, err := h.storageService.Upload(uploadedFile, "documents")
if err != nil {
return apperrors.BadRequest("error.failed_to_upload_file")
@@ -262,7 +270,7 @@ func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
if err != nil {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document deleted successfully"})
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.document_deleted")})
}
// ActivateDocument handles POST /api/documents/:id/activate/
@@ -280,7 +288,7 @@ func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
if err != nil {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document activated successfully", "document": response})
return c.JSON(http.StatusOK, response)
}
// DeactivateDocument handles POST /api/documents/:id/deactivate/
@@ -298,7 +306,7 @@ func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
if err != nil {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document deactivated successfully", "document": response})
return c.JSON(http.StatusOK, response)
}
// UploadDocumentImage handles POST /api/documents/:id/images/

View File

@@ -7,6 +7,8 @@ import (
"github.com/labstack/echo/v4"
"github.com/treytartt/honeydue-api/internal/apperrors"
"github.com/treytartt/honeydue-api/internal/dto/responses"
"github.com/treytartt/honeydue-api/internal/i18n"
"github.com/treytartt/honeydue-api/internal/middleware"
"github.com/treytartt/honeydue-api/internal/services"
)
@@ -87,7 +89,7 @@ func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.notification_marked_read"})
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.notification_marked_read")})
}
// MarkAllAsRead handles POST /api/notifications/mark-all-read/
@@ -102,7 +104,7 @@ func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.all_notifications_marked_read"})
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.all_notifications_marked_read")})
}
// GetPreferences handles GET /api/notifications/preferences/
@@ -200,7 +202,10 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
return apperrors.BadRequest("error.registration_id_required")
}
if req.Platform == "" {
req.Platform = "ios" // Default to iOS
return apperrors.BadRequest("error.platform_required")
}
if req.Platform != "ios" && req.Platform != "android" {
return apperrors.BadRequest("error.invalid_platform")
}
err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
@@ -208,7 +213,7 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.device_unregistered"})
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.device_removed")})
}
// DeleteDevice handles DELETE /api/notifications/devices/:id/
@@ -225,7 +230,10 @@ func (h *NotificationHandler) DeleteDevice(c echo.Context) error {
platform := c.QueryParam("platform")
if platform == "" {
platform = "ios" // Default to iOS
return apperrors.BadRequest("error.platform_required")
}
if platform != "ios" && platform != "android" {
return apperrors.BadRequest("error.invalid_platform")
}
err = h.notificationService.DeleteDevice(uint(deviceID), platform, user.ID)
@@ -233,5 +241,5 @@ func (h *NotificationHandler) DeleteDevice(c echo.Context) error {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.device_removed"})
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.device_removed")})
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/treytartt/honeydue-api/internal/apperrors"
"github.com/treytartt/honeydue-api/internal/i18n"
"github.com/treytartt/honeydue-api/internal/middleware"
"github.com/treytartt/honeydue-api/internal/services"
)
@@ -139,7 +140,7 @@ func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
}
return c.JSON(http.StatusOK, map[string]interface{}{
"message": "message.subscription_upgraded",
"message": i18n.LocalizedMessage(c, "message.subscription_upgraded"),
"subscription": subscription,
})
}
@@ -157,7 +158,7 @@ func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
}
return c.JSON(http.StatusOK, map[string]interface{}{
"message": "message.subscription_cancelled",
"message": i18n.LocalizedMessage(c, "message.subscription_cancelled"),
"subscription": subscription,
})
}
@@ -182,8 +183,15 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
switch req.Platform {
case "ios":
// B-14: Validate that at least one of receipt_data or transaction_id is provided
if req.ReceiptData == "" && req.TransactionID == "" {
return apperrors.BadRequest("error.receipt_data_required")
}
subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID)
case "android":
if req.PurchaseToken == "" {
return apperrors.BadRequest("error.purchase_token_required")
}
subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID)
default:
return apperrors.BadRequest("error.invalid_platform")
@@ -194,7 +202,7 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
}
return c.JSON(http.StatusOK, map[string]interface{}{
"message": "message.subscription_restored",
"message": i18n.LocalizedMessage(c, "message.subscription_restored"),
"subscription": subscription,
})
}

View File

@@ -2,15 +2,18 @@ package handlers
import (
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
@@ -23,6 +26,11 @@ import (
"github.com/treytartt/honeydue-api/internal/services"
)
// maxWebhookBodySize is the maximum allowed request body size for webhook
// payloads (1 MB). This prevents a malicious or misbehaving sender from
// forcing the server to allocate unbounded memory. Applied to all webhook
// body reads via io.LimitReader.
const maxWebhookBodySize = 1 << 20 // 1 MB
// SubscriptionWebhookHandler handles subscription webhook callbacks
type SubscriptionWebhookHandler struct {
subscriptionRepo *repositories.SubscriptionRepository
@@ -112,7 +120,7 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
}
body, err := io.ReadAll(c.Request().Body)
body, err := io.ReadAll(io.LimitReader(c.Request().Body, maxWebhookBodySize))
if err != nil {
log.Error().Err(err).Msg("Apple Webhook: Failed to read body")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
@@ -211,13 +219,22 @@ func (h *SubscriptionWebhookHandler) decodeAppleSignedPayload(signedPayload stri
return &notification, nil
}
// decodeAppleTransaction decodes a signed transaction info JWS
// decodeAppleTransaction decodes and verifies a signed transaction info JWS.
// The inner JWS signature is verified using the same Apple certificate chain
// validation as the outer notification payload.
func (h *SubscriptionWebhookHandler) decodeAppleTransaction(signedTransaction string) (*AppleTransactionInfo, error) {
parts := strings.Split(signedTransaction, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWS format")
}
// S-16: Verify the inner JWS signature for signedTransactionInfo.
// Apple signs each inner JWS independently with the same x5c certificate
// chain as the outer notification, so the same verification applies.
if err := h.VerifyAppleSignature(signedTransaction); err != nil {
return nil, fmt.Errorf("Apple transaction JWS signature verification failed: %w", err)
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("failed to decode payload: %w", err)
@@ -231,13 +248,20 @@ func (h *SubscriptionWebhookHandler) decodeAppleTransaction(signedTransaction st
return &info, nil
}
// decodeAppleRenewalInfo decodes signed renewal info JWS
// decodeAppleRenewalInfo decodes and verifies a signed renewal info JWS.
// The inner JWS signature is verified using the same Apple certificate chain
// validation as the outer notification payload.
func (h *SubscriptionWebhookHandler) decodeAppleRenewalInfo(signedRenewal string) (*AppleRenewalInfo, error) {
parts := strings.Split(signedRenewal, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWS format")
}
// S-16: Verify the inner JWS signature for signedRenewalInfo.
if err := h.VerifyAppleSignature(signedRenewal); err != nil {
return nil, fmt.Errorf("Apple renewal JWS signature verification failed: %w", err)
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("failed to decode payload: %w", err)
@@ -484,7 +508,12 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
}
body, err := io.ReadAll(c.Request().Body)
// C-01: Verify the Google Pub/Sub push authentication token before processing
if !h.VerifyGooglePubSubToken(c) {
return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "unauthorized"})
}
body, err := io.ReadAll(io.LimitReader(c.Request().Body, maxWebhookBodySize))
if err != nil {
log.Error().Err(err).Msg("Google Webhook: Failed to read body")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
@@ -781,7 +810,7 @@ func (h *SubscriptionWebhookHandler) HandleStripeWebhook(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]interface{}{"status": "not_configured"})
}
body, err := io.ReadAll(c.Request().Body)
body, err := io.ReadAll(io.LimitReader(c.Request().Body, maxWebhookBodySize))
if err != nil {
log.Error().Err(err).Msg("Stripe Webhook: Failed to read body")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
@@ -884,10 +913,109 @@ func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string)
return nil
}
// googleOIDCCertsURL is the endpoint that serves Google's public OAuth2
// certificates used to verify JWTs issued by accounts.google.com (including
// Pub/Sub push tokens).
const googleOIDCCertsURL = "https://www.googleapis.com/oauth2/v3/certs"

// googleJWKSCache caches the fetched Google JWKS keys so we don't hit the
// network on every webhook request. Keys are refreshed when the cache expires.
// All reads and writes of googleJWKSCache and googleJWKSCacheTime must hold
// googleJWKSCacheMu.
var (
	googleJWKSCache     map[string]*rsa.PublicKey // key-id -> verified RSA public key
	googleJWKSCacheMu   sync.RWMutex              // guards the two fields above
	googleJWKSCacheTime time.Time                 // when the cache was last refreshed
	googleJWKSCacheTTL  = 1 * time.Hour           // refresh interval
)
// googleJWKSResponse represents the JSON Web Key Set response from Google's
// OIDC certs endpoint (a JSON object with a single "keys" array).
type googleJWKSResponse struct {
	Keys []googleJWK `json:"keys"`
}

// googleJWK represents a single JSON Web Key from Google's OIDC endpoint.
// Only RSA keys are consumed; other key types are skipped by the fetcher.
type googleJWK struct {
	Kid string `json:"kid"` // Key ID
	Kty string `json:"kty"` // Key type (RSA)
	Alg string `json:"alg"` // Algorithm (RS256)
	N   string `json:"n"`   // RSA modulus (base64url)
	E   string `json:"e"`   // RSA exponent (base64url)
}
// fetchGoogleOIDCKeys fetches Google's public OIDC keys from their well-known
// endpoint, returning a map of key-id to RSA public key. Results are cached
// for googleJWKSCacheTTL to avoid excessive network calls.
//
// Returns an error when the endpoint is unreachable, responds with a non-200
// status, the JWKS body cannot be decoded, or no usable RSA key is present.
func fetchGoogleOIDCKeys() (map[string]*rsa.PublicKey, error) {
	// Fast path: serve from the cache while it is still fresh.
	googleJWKSCacheMu.RLock()
	if googleJWKSCache != nil && time.Since(googleJWKSCacheTime) < googleJWKSCacheTTL {
		cached := googleJWKSCache
		googleJWKSCacheMu.RUnlock()
		return cached, nil
	}
	googleJWKSCacheMu.RUnlock()

	googleJWKSCacheMu.Lock()
	defer googleJWKSCacheMu.Unlock()

	// Double-check after acquiring the write lock: another goroutine may have
	// refreshed the cache while we were waiting.
	if googleJWKSCache != nil && time.Since(googleJWKSCacheTime) < googleJWKSCacheTTL {
		return googleJWKSCache, nil
	}

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(googleOIDCCertsURL)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch Google OIDC keys: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Google OIDC keys endpoint returned status %d", resp.StatusCode)
	}

	var jwks googleJWKSResponse
	// Bound the response size defensively; the real key set is tiny.
	if err := json.NewDecoder(io.LimitReader(resp.Body, maxWebhookBodySize)).Decode(&jwks); err != nil {
		return nil, fmt.Errorf("failed to decode Google OIDC keys: %w", err)
	}

	keys := make(map[string]*rsa.PublicKey, len(jwks.Keys))
	for _, k := range jwks.Keys {
		if k.Kty != "RSA" {
			continue
		}
		nBytes, err := base64.RawURLEncoding.DecodeString(k.N)
		if err != nil {
			log.Warn().Str("kid", k.Kid).Err(err).Msg("Google OIDC: failed to decode modulus")
			continue
		}
		eBytes, err := base64.RawURLEncoding.DecodeString(k.E)
		if err != nil {
			log.Warn().Str("kid", k.Kid).Err(err).Msg("Google OIDC: failed to decode exponent")
			continue
		}
		// Guard against malformed keys: an empty modulus or exponent would
		// produce an invalid public key (N == 0 or E == 0), and an oversized
		// exponent would overflow the int accumulation below. Real-world RSA
		// public exponents fit in at most 4 bytes (typically 65537).
		if len(nBytes) == 0 || len(eBytes) == 0 || len(eBytes) > 4 {
			log.Warn().Str("kid", k.Kid).Msg("Google OIDC: skipping key with invalid modulus/exponent length")
			continue
		}
		n := new(big.Int).SetBytes(nBytes)
		e := 0
		for _, b := range eBytes {
			e = e<<8 + int(b)
		}
		keys[k.Kid] = &rsa.PublicKey{N: n, E: e}
	}
	if len(keys) == 0 {
		return nil, fmt.Errorf("no usable RSA keys found in Google OIDC response")
	}

	googleJWKSCache = keys
	googleJWKSCacheTime = time.Now()
	return keys, nil
}
// VerifyGooglePubSubToken verifies the Pub/Sub push authentication token.
// Returns false (deny) when the Authorization header is missing or the token
// cannot be validated. This prevents unauthenticated callers from injecting
// webhook events.
// The token is a JWT signed by Google (accounts.google.com). This function
// verifies the signature against Google's published OIDC public keys, checks
// the issuer claim, and validates the email claim is a Google service account.
// Returns false (deny) when verification fails for any reason.
func (h *SubscriptionWebhookHandler) VerifyGooglePubSubToken(c echo.Context) bool {
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
@@ -907,12 +1035,52 @@ func (h *SubscriptionWebhookHandler) VerifyGooglePubSubToken(c echo.Context) boo
return false
}
// Parse the token as a JWT. Google Pub/Sub push tokens are signed JWTs
// issued by accounts.google.com. We verify the claims to ensure the
// token was intended for our service.
token, _, err := jwt.NewParser().ParseUnverified(bearerToken, jwt.MapClaims{})
// Fetch Google's OIDC public keys for signature verification
googleKeys, err := fetchGoogleOIDCKeys()
if err != nil {
log.Warn().Err(err).Msg("Google Webhook: failed to parse Bearer token")
log.Error().Err(err).Msg("Google Webhook: failed to fetch OIDC keys, denying request")
return false
}
// Parse and verify the JWT signature against Google's published public keys
token, err := jwt.Parse(bearerToken, func(token *jwt.Token) (interface{}, error) {
// Ensure the signing method is RSA (Google uses RS256)
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
kid, _ := token.Header["kid"].(string)
if kid == "" {
return nil, fmt.Errorf("missing kid header in token")
}
key, ok := googleKeys[kid]
if !ok {
// Key may have rotated; try refreshing once
googleJWKSCacheMu.Lock()
googleJWKSCacheTime = time.Time{} // Force refresh
googleJWKSCacheMu.Unlock()
refreshedKeys, err := fetchGoogleOIDCKeys()
if err != nil {
return nil, fmt.Errorf("failed to refresh Google OIDC keys: %w", err)
}
key, ok = refreshedKeys[kid]
if !ok {
return nil, fmt.Errorf("unknown key ID: %s", kid)
}
}
return key, nil
}, jwt.WithValidMethods([]string{"RS256"}))
if err != nil {
log.Warn().Err(err).Msg("Google Webhook: JWT signature verification failed")
return false
}
if !token.Valid {
log.Warn().Msg("Google Webhook: token is invalid after verification")
return false
}

View File

@@ -38,19 +38,29 @@ func (h *TaskHandler) ListTasks(c echo.Context) error {
userNow := middleware.GetUserNow(c)
// Auto-capture timezone from header for background job calculations (e.g., daily digest)
// Runs synchronously — this is a lightweight DB upsert that should complete quickly
// Only write to DB if the timezone has actually changed from the cached value
if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" {
h.taskService.UpdateUserTimezone(user.ID, tzHeader)
cachedTZ, _ := c.Get("user_timezone").(string)
if cachedTZ != tzHeader {
h.taskService.UpdateUserTimezone(user.ID, tzHeader)
c.Set("user_timezone", tzHeader)
}
}
daysThreshold := 30
// Support "days" param first, fall back to "days_threshold" for backward compatibility
if d := c.QueryParam("days"); d != "" {
if parsed, err := strconv.Atoi(d); err == nil {
if parsed < 1 || parsed > 3650 {
return apperrors.BadRequest("error.days_out_of_range")
}
daysThreshold = parsed
}
} else if d := c.QueryParam("days_threshold"); d != "" {
if parsed, err := strconv.Atoi(d); err == nil {
if parsed < 1 || parsed > 3650 {
return apperrors.BadRequest("error.days_out_of_range")
}
daysThreshold = parsed
}
}
@@ -97,10 +107,16 @@ func (h *TaskHandler) GetTasksByResidence(c echo.Context) error {
// Support "days" param first, fall back to "days_threshold" for backward compatibility
if d := c.QueryParam("days"); d != "" {
if parsed, err := strconv.Atoi(d); err == nil {
if parsed < 1 || parsed > 3650 {
return apperrors.BadRequest("error.days_out_of_range")
}
daysThreshold = parsed
}
} else if d := c.QueryParam("days_threshold"); d != "" {
if parsed, err := strconv.Atoi(d); err == nil {
if parsed < 1 || parsed > 3650 {
return apperrors.BadRequest("error.days_out_of_range")
}
daysThreshold = parsed
}
}

View File

@@ -7,19 +7,31 @@ import (
"github.com/rs/zerolog/log"
"github.com/treytartt/honeydue-api/internal/apperrors"
"github.com/treytartt/honeydue-api/internal/dto/responses"
"github.com/treytartt/honeydue-api/internal/i18n"
"github.com/treytartt/honeydue-api/internal/middleware"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/services"
)
// FileOwnershipChecker verifies whether a user owns a file referenced by URL.
// Implementations should check associated records (e.g., task completion images,
// document files, document images) to determine ownership.
type FileOwnershipChecker interface {
IsFileOwnedByUser(fileURL string, userID uint) (bool, error)
}
// UploadHandler handles file upload endpoints
type UploadHandler struct {
storageService *services.StorageService
storageService *services.StorageService
fileOwnershipChecker FileOwnershipChecker
}
// NewUploadHandler creates a new upload handler
func NewUploadHandler(storageService *services.StorageService) *UploadHandler {
return &UploadHandler{storageService: storageService}
func NewUploadHandler(storageService *services.StorageService, fileOwnershipChecker FileOwnershipChecker) *UploadHandler {
return &UploadHandler{
storageService: storageService,
fileOwnershipChecker: fileOwnershipChecker,
}
}
// UploadImage handles POST /api/uploads/image
@@ -83,13 +95,14 @@ type DeleteFileRequest struct {
// DeleteFile handles DELETE /api/uploads
// Expects JSON body with "url" field.
//
// TODO(SEC-18): Add ownership verification. Currently any authenticated user can delete
// any file if they know the URL. The upload system does not track which user uploaded
// which file, so a proper fix requires adding an uploads table or file ownership metadata.
// For now, deletions are logged with user ID for audit trail, and StorageService.Delete
// enforces path containment to prevent deleting files outside the upload directory.
// Verifies that the requesting user owns the file by checking associated records
// (task completion images, document files/images) before allowing deletion.
func (h *UploadHandler) DeleteFile(c echo.Context) error {
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req DeleteFileRequest
if err := c.Bind(&req); err != nil {
@@ -100,17 +113,28 @@ func (h *UploadHandler) DeleteFile(c echo.Context) error {
return apperrors.BadRequest("error.url_required")
}
// Log the deletion with user ID for audit trail
if user, ok := c.Get(middleware.AuthUserKey).(*models.User); ok {
log.Info().
Uint("user_id", user.ID).
Str("file_url", req.URL).
Msg("File deletion requested")
// Verify ownership: the user must own a record that references this file URL
if h.fileOwnershipChecker != nil {
owned, err := h.fileOwnershipChecker.IsFileOwnedByUser(req.URL, user.ID)
if err != nil {
log.Error().Err(err).Uint("user_id", user.ID).Str("file_url", req.URL).Msg("Failed to check file ownership")
return apperrors.Internal(err)
}
if !owned {
log.Warn().Uint("user_id", user.ID).Str("file_url", req.URL).Msg("Unauthorized file deletion attempt")
return apperrors.Forbidden("error.file_access_denied")
}
}
// Log the deletion with user ID for audit trail
log.Info().
Uint("user_id", user.ID).
Str("file_url", req.URL).
Msg("File deletion requested")
if err := h.storageService.Delete(req.URL); err != nil {
return err
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "File deleted successfully"})
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.file_deleted")})
}

View File

@@ -5,6 +5,7 @@ import (
"testing"
"github.com/treytartt/honeydue-api/internal/i18n"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
@@ -18,12 +19,16 @@ func init() {
func TestDeleteFile_MissingURL_Returns400(t *testing.T) {
// Use a test storage service — DeleteFile won't reach storage since validation fails first
storageSvc := newTestStorageService("/var/uploads")
handler := NewUploadHandler(storageSvc)
handler := NewUploadHandler(storageSvc, nil)
e := testutil.SetupTestRouter()
// Register route
e.DELETE("/api/uploads/", handler.DeleteFile)
// Register route with mock auth middleware
testUser := &models.User{FirstName: "Test", Email: "test@test.com"}
testUser.ID = 1
authGroup := e.Group("/api")
authGroup.Use(testutil.MockAuthMiddleware(testUser))
authGroup.DELETE("/uploads/", handler.DeleteFile)
// Send request with empty JSON body (url field missing)
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{}, "test-token")
@@ -32,10 +37,16 @@ func TestDeleteFile_MissingURL_Returns400(t *testing.T) {
func TestDeleteFile_EmptyURL_Returns400(t *testing.T) {
storageSvc := newTestStorageService("/var/uploads")
handler := NewUploadHandler(storageSvc)
handler := NewUploadHandler(storageSvc, nil)
e := testutil.SetupTestRouter()
e.DELETE("/api/uploads/", handler.DeleteFile)
// Register route with mock auth middleware
testUser := &models.User{FirstName: "Test", Email: "test@test.com"}
testUser.ID = 1
authGroup := e.Group("/api")
authGroup.Use(testutil.MockAuthMiddleware(testUser))
authGroup.DELETE("/uploads/", handler.DeleteFile)
// Send request with empty url field
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{"url": ""}, "test-token")

View File

@@ -111,6 +111,11 @@
"error.purchase_token_required": "purchase_token is required for Android",
"error.no_file_provided": "No file provided",
"error.url_required": "File URL is required",
"error.file_access_denied": "You don't have access to this file",
"error.days_out_of_range": "Days parameter must be between 1 and 3650",
"error.platform_required": "Platform is required (ios or android)",
"error.registration_id_required": "Registration ID is required",
"error.failed_to_fetch_residence_types": "Failed to fetch residence types",
"error.failed_to_fetch_task_categories": "Failed to fetch task categories",

View File

@@ -25,19 +25,24 @@ const (
TokenCacheTTL = 5 * time.Minute
// TokenCachePrefix is the prefix for token cache keys
TokenCachePrefix = "auth_token_"
// UserCacheTTL is how long full user records are cached in memory to
// avoid hitting the database on every authenticated request.
UserCacheTTL = 30 * time.Second
)
// AuthMiddleware provides token authentication middleware
type AuthMiddleware struct {
db *gorm.DB
cache *services.CacheService
db *gorm.DB
cache *services.CacheService
userCache *UserCache
}
// NewAuthMiddleware creates a new auth middleware instance
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
return &AuthMiddleware{
db: db,
cache: cache,
db: db,
cache: cache,
userCache: NewUserCache(UserCacheTTL),
}
}
@@ -138,7 +143,8 @@ func extractToken(c echo.Context) (string, error) {
return token, nil
}
// getUserFromCache tries to get user from Redis cache
// getUserFromCache tries to get user from Redis cache, then from the
// in-memory user cache, before falling back to the database.
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
if m.cache == nil {
return nil, fmt.Errorf("cache not available")
@@ -152,10 +158,20 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m
return nil, err
}
// Get user from database by ID
// Try in-memory user cache first to avoid a DB round-trip
if cached := m.userCache.Get(userID); cached != nil {
if !cached.IsActive {
m.cache.InvalidateAuthToken(ctx, token)
m.userCache.Invalidate(userID)
return nil, fmt.Errorf("user is inactive")
}
return cached, nil
}
// In-memory cache miss — fetch from database
var user models.User
if err := m.db.First(&user, userID).Error; err != nil {
// User was deleted - invalidate cache
// User was deleted - invalidate caches
m.cache.InvalidateAuthToken(ctx, token)
return nil, err
}
@@ -166,10 +182,13 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m
return nil, fmt.Errorf("user is inactive")
}
// Store in in-memory cache for subsequent requests
m.userCache.Set(&user)
return &user, nil
}
// getUserFromDatabase looks up the token in the database
// getUserFromDatabase looks up the token in the database and caches the
// resulting user record in memory.
func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) {
var authToken models.AuthToken
if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil {
@@ -181,6 +200,8 @@ func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error)
return nil, fmt.Errorf("user is inactive")
}
// Store in in-memory cache for subsequent requests
m.userCache.Set(&authToken.User)
return &authToken.User, nil
}
@@ -220,7 +241,11 @@ func GetAuthToken(c echo.Context) string {
if token == nil {
return ""
}
return token.(string)
tokenStr, ok := token.(string)
if !ok {
return ""
}
return tokenStr
}
// MustGetAuthUser retrieves the authenticated user or returns error with 401

View File

@@ -0,0 +1,40 @@
package middleware
import (
	"net/http"
	"strings"

	"github.com/labstack/echo/v4"

	"github.com/treytartt/honeydue-api/internal/dto/responses"
)
// HostCheck returns middleware that validates the request Host header against
// a set of allowed hosts. This prevents SSRF attacks where an attacker crafts
// a request with an arbitrary Host header to reach internal services via the
// reverse proxy.
//
// Host comparison is case-insensitive: DNS hostnames are case-insensitive
// (RFC 4343), so without normalization a request for "EXAMPLE.com" would
// bypass an allowlist containing "example.com".
//
// If allowedHosts is empty the middleware is a no-op (all hosts pass).
func HostCheck(allowedHosts []string) echo.MiddlewareFunc {
	allowed := make(map[string]struct{}, len(allowedHosts))
	for _, h := range allowedHosts {
		allowed[strings.ToLower(h)] = struct{}{}
	}
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			// If no allowed hosts configured, skip the check
			if len(allowed) == 0 {
				return next(c)
			}
			host := strings.ToLower(c.Request().Host)
			if _, ok := allowed[host]; !ok {
				return c.JSON(http.StatusForbidden, responses.ErrorResponse{
					Error: "Forbidden",
				})
			}
			return next(c)
		}
	}
}

View File

@@ -0,0 +1,68 @@
package middleware
import (
"net/http"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"golang.org/x/time/rate"
"github.com/treytartt/honeydue-api/internal/dto/responses"
)
// AuthRateLimiter returns rate-limiting middleware tuned for authentication
// endpoints. It uses Echo's built-in in-memory rate limiter keyed by client
// IP address.
//
// Parameters:
//   - ratePerSecond: sustained request rate (e.g., 10/60.0 for ~10 per minute)
//   - burst: maximum burst size above the sustained rate
func AuthRateLimiter(ratePerSecond rate.Limit, burst int) echo.MiddlewareFunc {
	// In-memory token-bucket store; idle entries expire after 5 minutes.
	memStore := middleware.NewRateLimiterMemoryStoreWithConfig(middleware.RateLimiterMemoryStoreConfig{
		Rate:      ratePerSecond,
		Burst:     burst,
		ExpiresIn: 5 * time.Minute,
	})

	cfg := middleware.RateLimiterConfig{
		Skipper: middleware.DefaultSkipper,
		Store:   memStore,
		// Key requests by the client's real IP address.
		IdentifierExtractor: func(c echo.Context) (string, error) {
			return c.RealIP(), nil
		},
		// Response sent when a client exceeds the configured rate.
		DenyHandler: func(c echo.Context, _ string, _ error) error {
			return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{
				Error: "Too many requests. Please try again later.",
			})
		},
		// Response sent when the identifier extractor fails.
		ErrorHandler: func(c echo.Context, err error) error {
			return c.JSON(http.StatusForbidden, responses.ErrorResponse{
				Error: "Unable to process request.",
			})
		},
	}
	return middleware.RateLimiterWithConfig(cfg)
}
// LoginRateLimiter returns rate-limiting middleware for login endpoints.
// Allows 10 requests per minute with a burst of 5.
func LoginRateLimiter() echo.MiddlewareFunc {
	// 10 requests sustained per minute, burst of 5.
	const perMinute = 10.0
	return AuthRateLimiter(rate.Limit(perMinute/60.0), 5)
}
// RegistrationRateLimiter returns rate-limiting middleware for registration
// endpoints. Allows 5 requests per minute with a burst of 3.
func RegistrationRateLimiter() echo.MiddlewareFunc {
	// 5 requests sustained per minute, burst of 3.
	const perMinute = 5.0
	return AuthRateLimiter(rate.Limit(perMinute/60.0), 3)
}
// PasswordResetRateLimiter returns rate-limiting middleware for password
// reset endpoints. Allows 3 requests per minute with a burst of 2.
func PasswordResetRateLimiter() echo.MiddlewareFunc {
	// 3 requests sustained per minute, burst of 2.
	const perMinute = 3.0
	return AuthRateLimiter(rate.Limit(perMinute/60.0), 2)
}

View File

@@ -7,14 +7,25 @@ import (
)
const (
// TimezoneKey is the key used to store the user's timezone in the context
// TimezoneKey is the key used to store the user's timezone *time.Location in the context
TimezoneKey = "user_timezone"
// TimezoneNameKey stores the raw IANA timezone string from the request header
TimezoneNameKey = "user_timezone_name"
// TimezoneChangedKey is a bool context key indicating whether the timezone
// differs from the previously cached value for this user. Handlers should
// only persist the timezone to DB when this is true.
TimezoneChangedKey = "timezone_changed"
// UserNowKey is the key used to store the timezone-aware "now" time in the context
UserNowKey = "user_now"
// TimezoneHeader is the HTTP header name for the user's timezone
TimezoneHeader = "X-Timezone"
)
// package-level timezone cache shared across requests. It is safe for
// concurrent use and has no TTL — entries are only updated when a new
// timezone value is observed for a given user.
var tzCache = NewTimezoneCache()
// TimezoneMiddleware extracts the user's timezone from the request header
// and stores a timezone-aware "now" time in the context.
//
@@ -22,14 +33,31 @@ const (
// or a UTC offset (e.g., "-08:00", "+05:30").
//
// If no timezone is provided or it's invalid, UTC is used as the default.
//
// The middleware also compares the incoming timezone with a cached value per
// user and sets TimezoneChangedKey in the context so downstream handlers
// know whether a DB write is needed.
func TimezoneMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
tzName := c.Request().Header.Get(TimezoneHeader)
loc := parseTimezone(tzName)
// Store the location and the current time in that timezone
// Store the location and the raw name in the context
c.Set(TimezoneKey, loc)
c.Set(TimezoneNameKey, tzName)
// Determine whether the timezone changed for this user so handlers
// can skip unnecessary DB writes.
changed := false
if tzName != "" {
if user := GetAuthUser(c); user != nil {
if !tzCache.GetAndCompare(user.ID, tzName) {
changed = true
}
}
}
c.Set(TimezoneChangedKey, changed)
// Calculate "now" in the user's timezone, then get start of day
// For date comparisons, we want to compare against the START of the user's current day
@@ -42,6 +70,20 @@ func TimezoneMiddleware() echo.MiddlewareFunc {
}
}
// IsTimezoneChanged reports whether the user's timezone header differs from
// the previously observed value for that user. Handlers should only persist
// the timezone to DB when this returns true.
func IsTimezoneChanged(c echo.Context) bool {
	// A missing or mistyped context value counts as "not changed".
	if changed, ok := c.Get(TimezoneChangedKey).(bool); ok {
		return changed
	}
	return false
}
// GetTimezoneName returns the raw timezone string taken from the request
// header, or the empty string when none was provided.
func GetTimezoneName(c echo.Context) string {
	if name, ok := c.Get(TimezoneNameKey).(string); ok {
		return name
	}
	return ""
}
// parseTimezone parses a timezone string and returns a *time.Location.
// Supports IANA timezone names (e.g., "America/Los_Angeles") and
// UTC offsets (e.g., "-08:00", "+05:30").

View File

@@ -0,0 +1,113 @@
package middleware
import (
"sync"
"time"
"github.com/treytartt/honeydue-api/internal/models"
)
// userCacheEntry holds a cached user record together with its absolute
// expiration time. Entries whose expiresAt has passed are treated as absent
// by Get and removed lazily (in Get and in maybeGC).
type userCacheEntry struct {
	user      *models.User
	expiresAt time.Time
}
// UserCache is a concurrency-safe in-memory cache for User records, keyed by
// user ID. Entries expire after a configurable TTL. The cache uses a sync.Map
// for lock-free reads on the hot path, with periodic lazy eviction of stale
// entries during Set operations (see maybeGC).
type UserCache struct {
	store   sync.Map      // uint user ID -> *userCacheEntry
	ttl     time.Duration // lifetime of each entry, counted from Set
	lastGC  time.Time     // time of the last eviction sweep; guarded by gcMu
	gcMu    sync.Mutex    // protects lastGC only; store synchronizes itself
	gcEvery time.Duration // minimum interval between eviction sweeps
}
// NewUserCache creates a UserCache whose entries expire ttl after insertion.
// Lazy GC sweeps are rate-limited to at most one every two minutes.
func NewUserCache(ttl time.Duration) *UserCache {
	cache := &UserCache{ttl: ttl}
	cache.lastGC = time.Now()
	cache.gcEvery = 2 * time.Minute
	return cache
}
// Get returns a cached user by ID, or nil when the entry is missing or has
// expired. Expired entries are deleted on sight.
func (c *UserCache) Get(userID uint) *models.User {
	raw, found := c.store.Load(userID)
	if !found {
		return nil
	}
	entry := raw.(*userCacheEntry)
	if time.Now().After(entry.expiresAt) {
		// Stale — evict eagerly and report a miss.
		c.store.Delete(userID)
		return nil
	}
	// Hand back a shallow copy so callers cannot mutate the cached value.
	snapshot := *entry.user
	return &snapshot
}
// Set stores a user in the cache under its ID, stamped with the configured
// TTL. It also triggers a lazy garbage-collection sweep when enough time has
// elapsed since the last one.
func (c *UserCache) Set(user *models.User) {
	// Cache a copy so later mutation of the caller's struct cannot leak in.
	snapshot := *user
	entry := &userCacheEntry{
		user:      &snapshot,
		expiresAt: time.Now().Add(c.ttl),
	}
	c.store.Store(user.ID, entry)
	c.maybeGC()
}
// Invalidate removes a user from the cache by ID. It is a no-op when the
// user is not cached. Call it after any write that changes the user record
// so subsequent Gets observe fresh data.
func (c *UserCache) Invalidate(userID uint) {
	c.store.Delete(userID)
}
// maybeGC sweeps expired entries, at most once per gcEvery interval. The
// elapsed-time check (and lastGC update) happens under gcMu; the sweep
// itself runs unlocked over the sync.Map, which tolerates concurrent deletes.
func (c *UserCache) maybeGC() {
	c.gcMu.Lock()
	due := time.Since(c.lastGC) >= c.gcEvery
	if due {
		c.lastGC = time.Now()
	}
	c.gcMu.Unlock()
	if !due {
		return
	}
	cutoff := time.Now()
	c.store.Range(func(key, value any) bool {
		if cutoff.After(value.(*userCacheEntry).expiresAt) {
			c.store.Delete(key)
		}
		return true
	})
}
// TimezoneCache remembers the most recent timezone string observed per user
// ID so the timezone middleware only writes to the database when the value
// actually changes.
type TimezoneCache struct {
	store sync.Map // uint user ID -> string timezone
}

// NewTimezoneCache returns an empty TimezoneCache ready for concurrent use.
func NewTimezoneCache() *TimezoneCache {
	return new(TimezoneCache)
}

// GetAndCompare reports whether tz matches the cached value for userID.
// On a mismatch — or when the user has never been seen — the cache is
// updated with tz and false is returned, signaling that a DB write is needed.
func (tc *TimezoneCache) GetAndCompare(userID uint, tz string) (unchanged bool) {
	if prev, ok := tc.store.Load(userID); ok {
		if cached, isString := prev.(string); isString && cached == tz {
			return true
		}
	}
	tc.store.Store(userID, tz)
	return false
}

View File

@@ -13,12 +13,6 @@ type BaseModel struct {
UpdatedAt time.Time `json:"updated_at"`
}
// SoftDeleteModel extends BaseModel with soft delete support
type SoftDeleteModel struct {
BaseModel
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// BeforeCreate sets timestamps before creating a record
func (b *BaseModel) BeforeCreate(tx *gorm.DB) error {
now := time.Now().UTC()

View File

@@ -29,10 +29,10 @@ type NotificationPreference struct {
// Custom notification times (nullable, stored as UTC hour 0-23)
// When nil, system defaults from config are used
TaskDueSoonHour *int `gorm:"column:task_due_soon_hour" json:"task_due_soon_hour"`
TaskOverdueHour *int `gorm:"column:task_overdue_hour" json:"task_overdue_hour"`
WarrantyExpiringHour *int `gorm:"column:warranty_expiring_hour" json:"warranty_expiring_hour"`
DailyDigestHour *int `gorm:"column:daily_digest_hour" json:"daily_digest_hour"`
TaskDueSoonHour *int `gorm:"column:task_due_soon_hour" json:"task_due_soon_hour" validate:"omitempty,min=0,max=23"`
TaskOverdueHour *int `gorm:"column:task_overdue_hour" json:"task_overdue_hour" validate:"omitempty,min=0,max=23"`
WarrantyExpiringHour *int `gorm:"column:warranty_expiring_hour" json:"warranty_expiring_hour" validate:"omitempty,min=0,max=23"`
DailyDigestHour *int `gorm:"column:daily_digest_hour" json:"daily_digest_hour" validate:"omitempty,min=0,max=23"`
// User timezone for background job calculations (IANA name, e.g., "America/Los_Angeles")
// Auto-captured from X-Timezone header on API calls

View File

@@ -181,88 +181,6 @@ func (t *Task) IsDueSoon(days int) bool {
return !effectiveDate.Before(now) && effectiveDate.Before(threshold)
}
// GetKanbanColumn returns the kanban column name for this task using the
// Chain of Responsibility pattern from the categorization package.
// Categorization is evaluated against the current UTC time.
//
// For timezone-aware categorization, use GetKanbanColumnWithTimezone.
func (t *Task) GetKanbanColumn(daysThreshold int) string {
	// Delegate to the timezone-aware variant, pinning "now" to UTC.
	// (Importing categorization here would create a circular dependency, so
	// the logic lives in GetKanbanColumnWithTimezone.)
	now := time.Now().UTC()
	return t.GetKanbanColumnWithTimezone(daysThreshold, now)
}
// GetKanbanColumnWithTimezone returns the kanban column name using a specific
// time (in the user's timezone). The time is used to determine "today" for
// overdue/due-soon calculations.
//
// Example: For a user in Tokyo, pass time.Now().In(tokyoLocation) to get
// accurate categorization relative to their local date.
func (t *Task) GetKanbanColumnWithTimezone(daysThreshold int, now time.Time) string {
	// Note: We can't import categorization directly due to circular dependency.
	// Instead, this method implements the categorization logic inline.
	// The logic MUST match categorization.Chain exactly — do not reorder the
	// priority checks below.
	// A non-positive threshold falls back to the 30-day default.
	if daysThreshold <= 0 {
		daysThreshold = 30
	}
	// Normalize "now" to the start of the user's current day; the due-soon
	// window extends daysThreshold days from that midnight.
	startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
	threshold := startOfDay.AddDate(0, 0, daysThreshold)
	// Priority 1: Cancelled
	if t.IsCancelled {
		return "cancelled_tasks"
	}
	// Priority 2: Archived (goes to cancelled column - both are "inactive" states)
	if t.IsArchived {
		return "cancelled_tasks"
	}
	// Priority 3: Completed — no next occurrence scheduled and at least one
	// completion recorded (loaded Completions rows or the CompletionCount field).
	hasCompletions := len(t.Completions) > 0 || t.CompletionCount > 0
	if t.NextDueDate == nil && hasCompletions {
		return "completed_tasks"
	}
	// Priority 4: In Progress
	if t.InProgress {
		return "in_progress_tasks"
	}
	// Effective date for scheduling: NextDueDate when set, else DueDate.
	var effectiveDate *time.Time
	if t.NextDueDate != nil {
		effectiveDate = t.NextDueDate
	} else {
		effectiveDate = t.DueDate
	}
	if effectiveDate != nil {
		// Re-anchor the effective date to midnight in the user's location so
		// the comparison is by calendar day, not instant. Task dates are
		// stored as UTC but represent calendar dates (YYYY-MM-DD).
		normalizedEffective := time.Date(
			effectiveDate.Year(), effectiveDate.Month(), effectiveDate.Day(),
			0, 0, 0, 0, now.Location(),
		)
		// Priority 5: Overdue — strictly before the start of today.
		if normalizedEffective.Before(startOfDay) {
			return "overdue_tasks"
		}
		// Priority 6: Due Soon — strictly before the threshold; the boundary
		// day itself counts as upcoming.
		if normalizedEffective.Before(threshold) {
			return "due_soon_tasks"
		}
	}
	// Priority 7: Upcoming (default)
	return "upcoming_tasks"
}
// TaskCompletion represents the task_taskcompletion table
type TaskCompletion struct {
BaseModel

View File

@@ -248,180 +248,10 @@ func TestDocument_JSONSerialization(t *testing.T) {
assert.Equal(t, "5000", result["purchase_price"]) // Decimal serializes as string
}
// ============================================================================
// TASK KANBAN COLUMN TESTS
// These tests verify GetKanbanColumn and GetKanbanColumnWithTimezone methods
// ============================================================================
func timePtr(t time.Time) *time.Time {
return &t
}
// TestTask_GetKanbanColumn_PriorityOrder verifies that task state flags are
// evaluated in the documented priority order (cancelled > completed >
// in-progress > overdue > due-soon > upcoming), using a fixed reference
// time so results are deterministic.
func TestTask_GetKanbanColumn_PriorityOrder(t *testing.T) {
	// Fixed clock: Dec 16 2025, noon UTC.
	now := time.Date(2025, 12, 16, 12, 0, 0, 0, time.UTC)
	yesterday := time.Date(2025, 12, 15, 0, 0, 0, 0, time.UTC)
	in5Days := time.Date(2025, 12, 21, 0, 0, 0, 0, time.UTC)
	in60Days := time.Date(2026, 2, 14, 0, 0, 0, 0, time.UTC)
	tests := []struct {
		name     string
		task     *Task
		expected string
	}{
		// Cancelled wins over every other signal, including in-progress.
		{
			name: "cancelled takes highest priority",
			task: &Task{
				IsCancelled: true,
				NextDueDate: timePtr(yesterday),
				InProgress:  true,
			},
			expected: "cancelled_tasks",
		},
		// Completed: nil NextDueDate combined with recorded completions.
		{
			name: "completed: NextDueDate nil with completions",
			task: &Task{
				IsCancelled: false,
				NextDueDate: nil,
				DueDate:     timePtr(yesterday),
				Completions: []TaskCompletion{{BaseModel: BaseModel{ID: 1}}},
			},
			expected: "completed_tasks",
		},
		// In-progress outranks an overdue effective date.
		{
			name: "in progress takes priority over overdue",
			task: &Task{
				IsCancelled: false,
				NextDueDate: timePtr(yesterday),
				InProgress:  true,
			},
			expected: "in_progress_tasks",
		},
		// Overdue: effective date strictly before the start of "today".
		{
			name: "overdue: effective date in past",
			task: &Task{
				IsCancelled: false,
				NextDueDate: timePtr(yesterday),
			},
			expected: "overdue_tasks",
		},
		// Due soon: inside the 30-day threshold window.
		{
			name: "due soon: within 30-day threshold",
			task: &Task{
				IsCancelled: false,
				NextDueDate: timePtr(in5Days),
			},
			expected: "due_soon_tasks",
		},
		// Upcoming: beyond the threshold, or no date at all.
		{
			name: "upcoming: beyond threshold",
			task: &Task{
				IsCancelled: false,
				NextDueDate: timePtr(in60Days),
			},
			expected: "upcoming_tasks",
		},
		{
			name: "upcoming: no due date",
			task: &Task{
				IsCancelled: false,
				NextDueDate: nil,
				DueDate:     nil,
			},
			expected: "upcoming_tasks",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := tt.task.GetKanbanColumnWithTimezone(30, now)
			assert.Equal(t, tt.expected, result)
		})
	}
}
// TestTask_GetKanbanColumnWithTimezone_TimezoneAware verifies that one task
// is categorized differently depending on the caller's local "now": the
// task's calendar date is compared against the date in now's Location.
func TestTask_GetKanbanColumnWithTimezone_TimezoneAware(t *testing.T) {
	// Task due Dec 17, 2025 (calendar date, stored as UTC midnight).
	taskDueDate := time.Date(2025, 12, 17, 0, 0, 0, 0, time.UTC)
	task := &Task{
		NextDueDate: timePtr(taskDueDate),
		IsCancelled: false,
	}
	// UTC user at 11 PM on Dec 16 — the task is tomorrow: due_soon.
	utcDec16Evening := time.Date(2025, 12, 16, 23, 0, 0, 0, time.UTC)
	result := task.GetKanbanColumnWithTimezone(30, utcDec16Evening)
	assert.Equal(t, "due_soon_tasks", result, "UTC Dec 16 evening")
	// UTC user at 8 AM on Dec 17 — the task is today: still due_soon, not overdue.
	utcDec17Morning := time.Date(2025, 12, 17, 8, 0, 0, 0, time.UTC)
	result = task.GetKanbanColumnWithTimezone(30, utcDec17Morning)
	assert.Equal(t, "due_soon_tasks", result, "UTC Dec 17 morning")
	// UTC user at 8 AM on Dec 18 — the task was yesterday: overdue.
	utcDec18Morning := time.Date(2025, 12, 18, 8, 0, 0, 0, time.UTC)
	result = task.GetKanbanColumnWithTimezone(30, utcDec18Morning)
	assert.Equal(t, "overdue_tasks", result, "UTC Dec 18 morning")
	// Tokyo user at the same instant as 11 PM UTC Dec 16 (= 8 AM Dec 17 JST):
	// the task is due TODAY in Tokyo, so it stays due_soon.
	tokyo, _ := time.LoadLocation("Asia/Tokyo")
	tokyoDec17Morning := utcDec16Evening.In(tokyo)
	result = task.GetKanbanColumnWithTimezone(30, tokyoDec17Morning)
	assert.Equal(t, "due_soon_tasks", result, "Tokyo Dec 17 morning")
	// Tokyo at 8 AM Dec 18 UTC (= 5 PM Dec 18 JST): the Dec 17 due date was
	// yesterday in Tokyo, so the task is overdue there.
	tokyoDec18 := utcDec18Morning.In(tokyo)
	result = task.GetKanbanColumnWithTimezone(30, tokyoDec18)
	assert.Equal(t, "overdue_tasks", result, "Tokyo Dec 18")
}
// TestTask_GetKanbanColumnWithTimezone_DueSoonThreshold pins the boundary
// semantics of the days threshold: strictly less than the threshold date is
// due_soon; exactly at or beyond the threshold is upcoming.
func TestTask_GetKanbanColumnWithTimezone_DueSoonThreshold(t *testing.T) {
	now := time.Date(2025, 12, 16, 12, 0, 0, 0, time.UTC)
	// 29 days out — inside the 30-day window.
	due29Days := time.Date(2026, 1, 14, 0, 0, 0, 0, time.UTC)
	task29 := &Task{NextDueDate: timePtr(due29Days)}
	result := task29.GetKanbanColumnWithTimezone(30, now)
	assert.Equal(t, "due_soon_tasks", result, "29 days should be due_soon")
	// Exactly 30 days out — the boundary is exclusive, so this is upcoming.
	due30Days := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC)
	task30 := &Task{NextDueDate: timePtr(due30Days)}
	result = task30.GetKanbanColumnWithTimezone(30, now)
	assert.Equal(t, "upcoming_tasks", result, "30 days should be upcoming (at boundary)")
	// 31 days out — clearly beyond the window.
	due31Days := time.Date(2026, 1, 16, 0, 0, 0, 0, time.UTC)
	task31 := &Task{NextDueDate: timePtr(due31Days)}
	result = task31.GetKanbanColumnWithTimezone(30, now)
	assert.Equal(t, "upcoming_tasks", result, "31 days should be upcoming")
}
// TestTask_GetKanbanColumn_CompletionCount verifies that the CompletionCount
// field alone (with an empty Completions slice) is sufficient for a task to
// be classified as completed.
func TestTask_GetKanbanColumn_CompletionCount(t *testing.T) {
	task := &Task{
		NextDueDate:     nil,
		CompletionCount: 1, // Using CompletionCount instead of the Completions slice
		Completions:     []TaskCompletion{},
	}
	result := task.GetKanbanColumn(30)
	assert.Equal(t, "completed_tasks", result)
}
func TestTask_IsOverdueAt_DayBased(t *testing.T) {
// Test that IsOverdueAt uses day-based comparison
now := time.Date(2025, 12, 16, 15, 0, 0, 0, time.UTC) // 3 PM UTC

View File

@@ -18,7 +18,7 @@ type User struct {
Username string `gorm:"column:username;uniqueIndex;size:150;not null" json:"username"`
FirstName string `gorm:"column:first_name;size:150" json:"first_name"`
LastName string `gorm:"column:last_name;size:150" json:"last_name"`
Email string `gorm:"column:email;size:254" json:"email"`
Email string `gorm:"column:email;uniqueIndex;size:254" json:"email"`
IsStaff bool `gorm:"column:is_staff;default:false" json:"is_staff"`
IsActive bool `gorm:"column:is_active;default:true" json:"is_active"`
DateJoined time.Time `gorm:"column:date_joined;autoCreateTime" json:"date_joined"`
@@ -142,7 +142,7 @@ func (UserProfile) TableName() string {
type ConfirmationCode struct {
BaseModel
UserID uint `gorm:"column:user_id;index;not null" json:"user_id"`
Code string `gorm:"column:code;size:6;not null" json:"code"`
Code string `gorm:"column:code;size:6;not null" json:"-"`
ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"`
IsUsed bool `gorm:"column:is_used;default:false" json:"is_used"`

View File

@@ -2,6 +2,7 @@ package monitoring
import (
"runtime"
"sync"
"time"
"github.com/hibiken/asynq"
@@ -20,6 +21,7 @@ type Collector struct {
httpCollector *HTTPStatsCollector // nil for worker
asynqClient *asynq.Inspector // nil for api
stopChan chan struct{}
stopOnce sync.Once
}
// NewCollector creates a new stats collector
@@ -193,7 +195,9 @@ func (c *Collector) publishStats() {
}
}
// Stop stops the stats publishing
// Stop stops the stats publishing. It is safe to call multiple times.
func (c *Collector) Stop() {
close(c.stopChan)
c.stopOnce.Do(func() {
close(c.stopChan)
})
}

View File

@@ -150,7 +150,10 @@ func (h *Handler) WebSocket(c echo.Context) error {
defer statsTicker.Stop()
// Send initial stats
h.sendStats(conn, &wsMu)
if err := h.sendStats(conn, &wsMu); err != nil {
cancel()
return nil
}
for {
select {
@@ -173,11 +176,16 @@ func (h *Handler) WebSocket(c echo.Context) error {
if err != nil {
log.Debug().Err(err).Msg("WebSocket write error")
cancel()
return nil
}
case <-statsTicker.C:
// Send periodic stats update
h.sendStats(conn, &wsMu)
if err := h.sendStats(conn, &wsMu); err != nil {
cancel()
return nil
}
case <-ctx.Done():
return nil
@@ -185,9 +193,11 @@ func (h *Handler) WebSocket(c echo.Context) error {
}
}
func (h *Handler) sendStats(conn *websocket.Conn, mu *sync.Mutex) {
func (h *Handler) sendStats(conn *websocket.Conn, mu *sync.Mutex) error {
allStats, err := h.statsStore.GetAllStats()
if err != nil {
log.Error().Err(err).Msg("failed to send stats")
return err
}
wsMsg := WSMessage{
@@ -196,6 +206,12 @@ func (h *Handler) sendStats(conn *websocket.Conn, mu *sync.Mutex) {
}
mu.Lock()
conn.WriteJSON(wsMsg)
err = conn.WriteJSON(wsMsg)
mu.Unlock()
if err != nil {
log.Debug().Err(err).Msg("WebSocket write error sending stats")
}
return err
}

View File

@@ -10,39 +10,33 @@ import (
// HTTPStatsCollector collects HTTP request metrics
type HTTPStatsCollector struct {
mu sync.RWMutex
requests map[string]int64 // endpoint -> count
totalLatency map[string]time.Duration // endpoint -> total latency
errors map[string]int64 // endpoint -> error count
byStatus map[int]int64 // status code -> count
latencies []latencySample // recent latency samples for P95
startTime time.Time
lastReset time.Time
}
type latencySample struct {
endpoint string
latency time.Duration
timestamp time.Time
mu sync.RWMutex
requests map[string]int64 // endpoint -> count
totalLatency map[string]time.Duration // endpoint -> total latency
errors map[string]int64 // endpoint -> error count
byStatus map[int]int64 // status code -> count
endpointLatencies map[string][]time.Duration // per-endpoint sorted latency buffers for P95
startTime time.Time
lastReset time.Time
}
const (
maxLatencySamples = 1000
maxEndpoints = 200 // Cap unique endpoints tracked
statsResetPeriod = 1 * time.Hour // Reset stats periodically to prevent unbounded growth
maxLatencySamplesPerEndpoint = 200 // Max latency samples kept per endpoint
maxEndpoints = 200 // Cap unique endpoints tracked
statsResetPeriod = 1 * time.Hour // Reset stats periodically to prevent unbounded growth
)
// NewHTTPStatsCollector creates a new HTTP stats collector
func NewHTTPStatsCollector() *HTTPStatsCollector {
now := time.Now()
return &HTTPStatsCollector{
requests: make(map[string]int64),
totalLatency: make(map[string]time.Duration),
errors: make(map[string]int64),
byStatus: make(map[int]int64),
latencies: make([]latencySample, 0, maxLatencySamples),
startTime: now,
lastReset: now,
requests: make(map[string]int64),
totalLatency: make(map[string]time.Duration),
errors: make(map[string]int64),
byStatus: make(map[int]int64),
endpointLatencies: make(map[string][]time.Duration),
startTime: now,
lastReset: now,
}
}
@@ -70,17 +64,22 @@ func (c *HTTPStatsCollector) Record(endpoint string, latency time.Duration, stat
c.errors[endpoint]++
}
// Store latency sample
c.latencies = append(c.latencies, latencySample{
endpoint: endpoint,
latency: latency,
timestamp: time.Now(),
// Insert latency into per-endpoint sorted buffer using binary search
buf := c.endpointLatencies[endpoint]
idx := sort.Search(len(buf), func(i int) bool {
return buf[i] >= latency
})
buf = append(buf, 0)
copy(buf[idx+1:], buf[idx:])
buf[idx] = latency
// Keep only recent samples
if len(c.latencies) > maxLatencySamples {
c.latencies = c.latencies[len(c.latencies)-maxLatencySamples:]
// Trim to max samples per endpoint by removing the median element
// to preserve distribution tails (important for P95 accuracy)
if len(buf) > maxLatencySamplesPerEndpoint {
mid := len(buf) / 2
buf = append(buf[:mid], buf[mid+1:]...)
}
c.endpointLatencies[endpoint] = buf
}
// resetLocked resets stats while holding the lock
@@ -89,7 +88,7 @@ func (c *HTTPStatsCollector) resetLocked() {
c.totalLatency = make(map[string]time.Duration)
c.errors = make(map[string]int64)
c.byStatus = make(map[int]int64)
c.latencies = make([]latencySample, 0, maxLatencySamples)
c.endpointLatencies = make(map[string][]time.Duration)
c.lastReset = time.Now()
// Keep startTime for uptime calculation
}
@@ -147,33 +146,23 @@ func (c *HTTPStatsCollector) GetStats() HTTPStats {
return stats
}
// calculateP95 calculates the 95th percentile latency for an endpoint
// Must be called with read lock held
// calculateP95 calculates the 95th percentile latency for an endpoint.
// The per-endpoint buffer is maintained in sorted order during insertion,
// so this is an O(1) index lookup.
// Must be called with read lock held.
func (c *HTTPStatsCollector) calculateP95(endpoint string) float64 {
var endpointLatencies []time.Duration
for _, sample := range c.latencies {
if sample.endpoint == endpoint {
endpointLatencies = append(endpointLatencies, sample.latency)
}
}
if len(endpointLatencies) == 0 {
buf := c.endpointLatencies[endpoint]
if len(buf) == 0 {
return 0
}
// Sort latencies
sort.Slice(endpointLatencies, func(i, j int) bool {
return endpointLatencies[i] < endpointLatencies[j]
})
// Calculate P95 index
p95Index := int(float64(len(endpointLatencies)) * 0.95)
if p95Index >= len(endpointLatencies) {
p95Index = len(endpointLatencies) - 1
// Buffer is already sorted; direct index lookup
p95Index := int(float64(len(buf)) * 0.95)
if p95Index >= len(buf) {
p95Index = len(buf) - 1
}
return float64(endpointLatencies[p95Index].Milliseconds())
return float64(buf[p95Index].Milliseconds())
}
// Reset clears all collected stats
@@ -185,7 +174,7 @@ func (c *HTTPStatsCollector) Reset() {
c.totalLatency = make(map[string]time.Duration)
c.errors = make(map[string]int64)
c.byStatus = make(map[int]int64)
c.latencies = make([]latencySample, 0, maxLatencySamples)
c.endpointLatencies = make(map[string][]time.Duration)
c.startTime = time.Now()
}

View File

@@ -2,6 +2,7 @@ package monitoring
import (
"io"
"sync"
"time"
"github.com/hibiken/asynq"
@@ -31,6 +32,8 @@ type Service struct {
logWriter *RedisLogWriter
db *gorm.DB
settingsStopCh chan struct{}
stopOnce sync.Once
statsInterval time.Duration
}
// Config holds configuration for the monitoring service
@@ -71,6 +74,7 @@ func NewService(cfg Config) *Service {
logWriter: logWriter,
db: cfg.DB,
settingsStopCh: make(chan struct{}),
statsInterval: cfg.StatsInterval,
}
// Check initial setting from database
@@ -90,11 +94,11 @@ func (s *Service) SetAsynqInspector(inspector *asynq.Inspector) {
func (s *Service) Start() {
log.Info().
Str("process", s.process).
Dur("interval", DefaultStatsInterval).
Dur("interval", s.statsInterval).
Bool("enabled", s.logWriter.IsEnabled()).
Msg("Starting monitoring service")
s.collector.StartPublishing(DefaultStatsInterval)
s.collector.StartPublishing(s.statsInterval)
// Start settings sync if database is available
if s.db != nil {
@@ -102,17 +106,19 @@ func (s *Service) Start() {
}
}
// Stop stops the monitoring service
// Stop stops the monitoring service. It is safe to call multiple times.
func (s *Service) Stop() {
// Stop settings sync
close(s.settingsStopCh)
s.stopOnce.Do(func() {
// Stop settings sync
close(s.settingsStopCh)
s.collector.Stop()
s.collector.Stop()
// Flush and close the log writer's background goroutine
s.logWriter.Close()
// Flush and close the log writer's background goroutine
s.logWriter.Close()
log.Info().Str("process", s.process).Msg("Monitoring service stopped")
log.Info().Str("process", s.process).Msg("Monitoring service stopped")
})
}
// syncSettingsFromDB checks the database for the enable_monitoring setting

View File

@@ -2,6 +2,7 @@ package monitoring
import (
"encoding/json"
"sync"
"sync/atomic"
"time"
@@ -18,11 +19,12 @@ const (
// It uses a single background goroutine with a buffered channel instead of
// spawning a new goroutine per log line, preventing unbounded goroutine growth.
type RedisLogWriter struct {
buffer *LogBuffer
process string
enabled atomic.Bool
ch chan LogEntry
done chan struct{}
buffer *LogBuffer
process string
enabled atomic.Bool
ch chan LogEntry
done chan struct{}
closeOnce sync.Once
}
// NewRedisLogWriter creates a new writer that captures logs to Redis.
@@ -53,9 +55,12 @@ func (w *RedisLogWriter) drainLoop() {
// Close shuts down the background goroutine. It should be called during
// graceful shutdown to ensure all buffered entries are flushed.
// It is safe to call multiple times.
func (w *RedisLogWriter) Close() {
close(w.ch)
<-w.done // Wait for drain to finish
w.closeOnce.Do(func() {
close(w.ch)
<-w.done // Wait for drain to finish
})
}
// SetEnabled enables or disables log capture to Redis

View File

@@ -38,16 +38,15 @@ func NewAPNsClient(cfg *config.PushConfig) (*APNsClient, error) {
TeamID: cfg.APNSTeamID,
}
// Create client - production or sandbox
// Use APNSProduction if set, otherwise fall back to inverse of APNSSandbox
// Create client - sandbox if APNSSandbox is true, production otherwise.
// APNSSandbox is the single source of truth (defaults to true for safety).
var client *apns2.Client
useProduction := cfg.APNSProduction || !cfg.APNSSandbox
if useProduction {
client = apns2.NewTokenClient(authToken).Production()
log.Info().Msg("APNs client configured for PRODUCTION")
} else {
if cfg.APNSSandbox {
client = apns2.NewTokenClient(authToken).Development()
log.Info().Msg("APNs client configured for DEVELOPMENT/SANDBOX")
} else {
client = apns2.NewTokenClient(authToken).Production()
log.Info().Msg("APNs client configured for PRODUCTION")
}
return &APNsClient{

View File

@@ -38,17 +38,17 @@ func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) {
log.Warn().Msg("APNs not configured - iOS push disabled")
}
// Initialize FCM client (Android)
if cfg.FCMServerKey != "" {
// Initialize FCM client (Android) - requires project ID + service account credentials
if cfg.FCMProjectID != "" && (cfg.FCMServiceAccountPath != "" || cfg.FCMServiceAccountJSON != "") {
fcmClient, err := NewFCMClient(cfg)
if err != nil {
log.Warn().Err(err).Msg("Failed to initialize FCM client - Android push disabled")
log.Warn().Err(err).Msg("Failed to initialize FCM v1 client - Android push disabled")
} else {
client.fcm = fcmClient
log.Info().Msg("FCM client initialized successfully")
log.Info().Msg("FCM v1 client initialized successfully")
}
} else {
log.Warn().Msg("FCM not configured - Android push disabled")
log.Warn().Msg("FCM not configured (need FCM_PROJECT_ID + FCM_SERVICE_ACCOUNT_PATH or FCM_SERVICE_ACCOUNT_JSON) - Android push disabled")
}
return client, nil

View File

@@ -5,138 +5,304 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2/google"
"github.com/treytartt/honeydue-api/internal/config"
)
const fcmEndpoint = "https://fcm.googleapis.com/fcm/send"
const (
	// fcmV1EndpointFmt is the FCM HTTP v1 API endpoint template; the Firebase
	// project ID is interpolated into the %s placeholder.
	fcmV1EndpointFmt = "https://fcm.googleapis.com/v1/projects/%s/messages:send"

	// fcmScope is the OAuth 2.0 scope required for FCM HTTP v1 API.
	fcmScope = "https://www.googleapis.com/auth/firebase.messaging"
)
// FCMClient handles communication with Firebase Cloud Messaging
// using the HTTP v1 API and OAuth 2.0 service account authentication.
type FCMClient struct {
serverKey string
projectID string
endpoint string
httpClient *http.Client
}
// FCMMessage represents an FCM message payload
type FCMMessage struct {
To string `json:"to,omitempty"`
RegistrationIDs []string `json:"registration_ids,omitempty"`
Notification *FCMNotification `json:"notification,omitempty"`
Data map[string]string `json:"data,omitempty"`
Priority string `json:"priority,omitempty"`
ContentAvailable bool `json:"content_available,omitempty"`
// --- Request types (FCM v1 API) ---
// fcmV1Request is the top-level request body for the FCM v1 API.
type fcmV1Request struct {
Message *fcmV1Message `json:"message"`
}
// FCMNotification represents the notification payload
// fcmV1Message represents a single FCM v1 message.
type fcmV1Message struct {
Token string `json:"token"`
Notification *FCMNotification `json:"notification,omitempty"`
Data map[string]string `json:"data,omitempty"`
Android *fcmAndroidConfig `json:"android,omitempty"`
}
// FCMNotification represents the notification payload.
type FCMNotification struct {
Title string `json:"title,omitempty"`
Body string `json:"body,omitempty"`
Sound string `json:"sound,omitempty"`
Badge string `json:"badge,omitempty"`
Icon string `json:"icon,omitempty"`
}
// FCMResponse represents the FCM API response
type FCMResponse struct {
MulticastID int64 `json:"multicast_id"`
Success int `json:"success"`
Failure int `json:"failure"`
CanonicalIDs int `json:"canonical_ids"`
Results []FCMResult `json:"results"`
// fcmAndroidConfig provides Android-specific message configuration.
type fcmAndroidConfig struct {
Priority string `json:"priority,omitempty"`
}
// FCMResult represents a single result in the FCM response
type FCMResult struct {
MessageID string `json:"message_id,omitempty"`
RegistrationID string `json:"registration_id,omitempty"`
Error string `json:"error,omitempty"`
// --- Response types (FCM v1 API) ---
// fcmV1Response is the successful response from the FCM v1 API.
type fcmV1Response struct {
Name string `json:"name"` // e.g. "projects/myproject/messages/0:1234567890"
}
// NewFCMClient creates a new FCM client
// fcmV1ErrorResponse is the error response from the FCM v1 API.
type fcmV1ErrorResponse struct {
Error fcmV1Error `json:"error"`
}
// fcmV1Error contains the structured error details.
type fcmV1Error struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
Details json.RawMessage `json:"details,omitempty"`
}
// --- Error types ---

// FCMErrorCode represents well-known FCM v1 error codes for programmatic
// handling. Values mirror the "status" string of the FCM v1 error response.
type FCMErrorCode string

const (
	// FCMErrUnregistered indicates the device token is no longer valid;
	// callers detect this via FCMSendError.IsUnregistered.
	FCMErrUnregistered     FCMErrorCode = "UNREGISTERED"
	FCMErrQuotaExceeded    FCMErrorCode = "QUOTA_EXCEEDED"
	FCMErrUnavailable      FCMErrorCode = "UNAVAILABLE"
	FCMErrInternal         FCMErrorCode = "INTERNAL"
	FCMErrInvalidArgument  FCMErrorCode = "INVALID_ARGUMENT"
	FCMErrSenderIDMismatch FCMErrorCode = "SENDER_ID_MISMATCH"
	FCMErrThirdPartyAuth   FCMErrorCode = "THIRD_PARTY_AUTH_ERROR"
)

// FCMSendError is a structured error returned when an individual FCM send fails.
type FCMSendError struct {
	Token      string       // target device token (truncated when formatted)
	StatusCode int          // HTTP status code from the FCM response
	ErrorCode  FCMErrorCode // structured FCM error status, e.g. UNREGISTERED
	Message    string       // human-readable message from the API
}

// Error implements the error interface. The device token is passed through
// truncateToken before formatting — presumably to keep full tokens out of
// logs; confirm against truncateToken's implementation.
func (e *FCMSendError) Error() string {
	return fmt.Sprintf("FCM send failed for token %s: %s (status %d, code %s)",
		truncateToken(e.Token), e.Message, e.StatusCode, e.ErrorCode)
}

// IsUnregistered returns true if the device token is no longer valid and should be removed.
func (e *FCMSendError) IsUnregistered() bool {
	return e.ErrorCode == FCMErrUnregistered
}
// --- OAuth 2.0 transport ---
// oauth2BearerTransport is an http.RoundTripper that attaches an OAuth 2.0 Bearer
// token to every outgoing request. The token source handles refresh automatically.
type oauth2BearerTransport struct {
base http.RoundTripper
getToken func() (string, error)
}
func (t *oauth2BearerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
accessToken, err := t.getToken()
if err != nil {
return nil, fmt.Errorf("failed to obtain OAuth 2.0 token for FCM: %w", err)
}
r := req.Clone(req.Context())
r.Header.Set("Authorization", "Bearer "+accessToken)
return t.base.RoundTrip(r)
}
// --- Client construction ---
// NewFCMClient creates a new FCM client using the HTTP v1 API with OAuth 2.0
// service account authentication. It accepts either a path to a service account
// JSON file or the raw JSON content directly via config.
func NewFCMClient(cfg *config.PushConfig) (*FCMClient, error) {
if cfg.FCMServerKey == "" {
return nil, fmt.Errorf("FCM server key not configured")
if cfg.FCMProjectID == "" {
return nil, fmt.Errorf("FCM project ID not configured (set FCM_PROJECT_ID)")
}
return &FCMClient{
serverKey: cfg.FCMServerKey,
httpClient: &http.Client{
Timeout: 30 * time.Second,
credJSON, err := resolveServiceAccountJSON(cfg)
if err != nil {
return nil, err
}
// Create OAuth 2.0 credentials with the FCM messaging scope.
// The google library handles automatic token refresh.
creds, err := google.CredentialsFromJSON(context.Background(), credJSON, fcmScope)
if err != nil {
return nil, fmt.Errorf("failed to parse FCM service account credentials: %w", err)
}
// Build an HTTP client that automatically attaches and refreshes OAuth tokens.
transport := &oauth2BearerTransport{
base: http.DefaultTransport,
getToken: func() (string, error) {
tok, err := creds.TokenSource.Token()
if err != nil {
return "", err
}
return tok.AccessToken, nil
},
}
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: transport,
}
endpoint := fmt.Sprintf(fcmV1EndpointFmt, cfg.FCMProjectID)
log.Info().
Str("project_id", cfg.FCMProjectID).
Str("endpoint", endpoint).
Msg("FCM v1 client initialized with OAuth 2.0")
return &FCMClient{
projectID: cfg.FCMProjectID,
endpoint: endpoint,
httpClient: httpClient,
}, nil
}
// resolveServiceAccountJSON returns the service account JSON bytes from config.
// Inline JSON (FCMServiceAccountJSON) takes precedence over a file path
// (FCMServiceAccountPath); if neither is configured an explanatory error
// naming both environment variables is returned.
func resolveServiceAccountJSON(cfg *config.PushConfig) ([]byte, error) {
	if cfg.FCMServiceAccountJSON != "" {
		return []byte(cfg.FCMServiceAccountJSON), nil
	}
	if cfg.FCMServiceAccountPath != "" {
		data, err := os.ReadFile(cfg.FCMServiceAccountPath)
		if err != nil {
			return nil, fmt.Errorf("failed to read FCM service account file %s: %w", cfg.FCMServiceAccountPath, err)
		}
		return data, nil
	}
	return nil, fmt.Errorf("FCM service account not configured (set FCM_SERVICE_ACCOUNT_PATH or FCM_SERVICE_ACCOUNT_JSON)")
}
// --- Sending ---
// Send sends a push notification to Android devices via the FCM HTTP v1 API.
// The v1 API requires one request per device token, so this iterates over all tokens.
// The method signature is kept identical to the previous legacy implementation
// so callers do not need to change.
func (c *FCMClient) Send(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
if len(tokens) == 0 {
return nil
}
msg := FCMMessage{
RegistrationIDs: tokens,
Notification: &FCMNotification{
Title: title,
Body: message,
Sound: "default",
},
Data: data,
Priority: "high",
}
var sendErrors []error
successCount := 0
body, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal FCM message: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", fcmEndpoint, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("failed to create FCM request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "key="+c.serverKey)
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to send FCM request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("FCM returned status %d", resp.StatusCode)
}
var fcmResp FCMResponse
if err := json.NewDecoder(resp.Body).Decode(&fcmResp); err != nil {
return fmt.Errorf("failed to decode FCM response: %w", err)
}
// Log individual results
for i, result := range fcmResp.Results {
if i >= len(tokens) {
break
}
if result.Error != "" {
for _, token := range tokens {
err := c.sendOne(ctx, token, title, message, data)
if err != nil {
log.Error().
Str("token", truncateToken(tokens[i])).
Str("error", result.Error).
Msg("FCM notification failed")
Err(err).
Str("token", truncateToken(token)).
Msg("FCM v1 notification failed")
sendErrors = append(sendErrors, err)
continue
}
successCount++
log.Debug().
Str("token", truncateToken(token)).
Msg("FCM v1 notification sent successfully")
}
log.Info().
Int("total", len(tokens)).
Int("success", fcmResp.Success).
Int("failure", fcmResp.Failure).
Msg("FCM batch send complete")
Int("success", successCount).
Int("failed", len(sendErrors)).
Msg("FCM v1 batch send complete")
if fcmResp.Success == 0 && fcmResp.Failure > 0 {
return fmt.Errorf("all FCM notifications failed")
if len(sendErrors) > 0 && successCount == 0 {
return fmt.Errorf("all FCM notifications failed: first error: %w", sendErrors[0])
}
return nil
}
// sendOne sends a single FCM v1 message to one device token.
// A nil return means the API accepted the message (HTTP 200); any other
// status is converted into a structured *FCMSendError via parseFCMV1Error.
func (c *FCMClient) sendOne(ctx context.Context, token, title, message string, data map[string]string) error {
	payload := fcmV1Request{
		Message: &fcmV1Message{
			Token: token,
			Notification: &FCMNotification{
				Title: title,
				Body:  message,
			},
			Data: data,
			Android: &fcmAndroidConfig{
				Priority: "HIGH",
			},
		},
	}
	raw, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("failed to marshal FCM v1 message: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(raw))
	if err != nil {
		return fmt.Errorf("failed to create FCM v1 request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to send FCM v1 request: %w", err)
	}
	defer resp.Body.Close()
	respBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read FCM v1 response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		// Parse the error response for structured error information.
		return parseFCMV1Error(token, resp.StatusCode, respBytes)
	}
	return nil
}
// parseFCMV1Error extracts a structured FCMSendError from the v1 API error response.
// When the body is not valid JSON the raw bytes are embedded in the message so
// diagnostic information is never silently dropped.
func parseFCMV1Error(token string, statusCode int, body []byte) *FCMSendError {
	var errResp fcmV1ErrorResponse
	if json.Unmarshal(body, &errResp) == nil {
		return &FCMSendError{
			Token:      token,
			StatusCode: statusCode,
			ErrorCode:  FCMErrorCode(errResp.Error.Status),
			Message:    errResp.Error.Message,
		}
	}
	return &FCMSendError{
		Token:      token,
		StatusCode: statusCode,
		Message:    fmt.Sprintf("unparseable error response: %s", string(body)),
	}
}

View File

@@ -5,182 +5,266 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestFCMClient creates an FCMClient whose endpoint points at the given test
// server URL. The HTTP client uses no OAuth transport so tests can run without
// real Google credentials.
func newTestFCMClient(serverURL string) *FCMClient {
	return &FCMClient{
		projectID:  "test-project",
		endpoint:   serverURL,
		httpClient: &http.Client{},
	}
}
// serveFCMResponse creates an httptest.Server that returns the given FCMResponse as JSON.
func serveFCMResponse(t *testing.T, resp FCMResponse) *httptest.Server {
// serveFCMV1Success creates a test server that returns a successful v1 response
// for every request.
func serveFCMV1Success(t *testing.T) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify it looks like a v1 request.
var req fcmV1Request
err := json.NewDecoder(r.Body).Decode(&req)
require.NoError(t, err)
require.NotNil(t, req.Message)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(resp)
require.NoError(t, err)
resp := fcmV1Response{Name: "projects/test-project/messages/0:12345"}
_ = json.NewEncoder(w).Encode(resp)
}))
}
// sendWithEndpoint is a helper that sends an FCM notification using a custom endpoint
// (the test server) instead of the real FCM endpoint. This avoids modifying the
// production code to be testable and instead temporarily overrides the client's HTTP
// transport to redirect requests to our test server.
func sendWithEndpoint(client *FCMClient, server *httptest.Server, ctx context.Context, tokens []string, title, message string, data map[string]string) error {
// Override the HTTP client to redirect all requests to the test server
client.httpClient = server.Client()
// We need to intercept the request and redirect it to our test server.
// Use a custom RoundTripper that rewrites the URL.
originalTransport := server.Client().Transport
client.httpClient.Transport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
// Rewrite the URL to point to the test server
req.URL.Scheme = "http"
req.URL.Host = server.Listener.Addr().String()
if originalTransport != nil {
return originalTransport.RoundTrip(req)
// serveFCMV1Error creates a test server that returns the given status code and
// a structured v1 error response for every request.
func serveFCMV1Error(t *testing.T, statusCode int, errStatus string, errMessage string) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
resp := fcmV1ErrorResponse{
Error: fcmV1Error{
Code: statusCode,
Message: errMessage,
Status: errStatus,
},
}
return http.DefaultTransport.RoundTrip(req)
})
return client.Send(ctx, tokens, title, message, data)
_ = json.NewEncoder(w).Encode(resp)
}))
}
// roundTripFunc is a function that implements http.RoundTripper.
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestFCMSend_MoreResultsThanTokens_NoPanic(t *testing.T) {
// FCM returns 5 results but we only sent 2 tokens.
// Before the bounds check fix, this would panic with index out of range.
fcmResp := FCMResponse{
MulticastID: 12345,
Success: 2,
Failure: 3,
Results: []FCMResult{
{MessageID: "msg1"},
{MessageID: "msg2"},
{Error: "InvalidRegistration"},
{Error: "NotRegistered"},
{Error: "InvalidRegistration"},
},
}
server := serveFCMResponse(t, fcmResp)
func TestFCMV1Send_Success_SingleToken(t *testing.T) {
server := serveFCMV1Success(t)
defer server.Close()
client := newTestFCMClient(server.URL)
tokens := []string{"token-aaa-111", "token-bbb-222"}
// This must not panic
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
err := client.Send(context.Background(), []string{"token-aaa-111"}, "Title", "Body", nil)
assert.NoError(t, err)
}
func TestFCMSend_FewerResultsThanTokens_NoPanic(t *testing.T) {
// FCM returns fewer results than tokens we sent.
// This is also a malformed response but should not panic.
fcmResp := FCMResponse{
MulticastID: 12345,
Success: 1,
Failure: 0,
Results: []FCMResult{
{MessageID: "msg1"},
},
}
func TestFCMV1Send_Success_MultipleTokens(t *testing.T) {
var mu sync.Mutex
receivedTokens := make([]string, 0)
server := serveFCMResponse(t, fcmResp)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req fcmV1Request
_ = json.NewDecoder(r.Body).Decode(&req)
mu.Lock()
receivedTokens = append(receivedTokens, req.Message.Token)
mu.Unlock()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
resp := fcmV1Response{Name: "projects/test-project/messages/0:12345"}
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
client := newTestFCMClient(server.URL)
tokens := []string{"token-aaa-111", "token-bbb-222", "token-ccc-333"}
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
err := client.Send(context.Background(), tokens, "Title", "Body", map[string]string{"key": "value"})
assert.NoError(t, err)
assert.ElementsMatch(t, tokens, receivedTokens)
}
func TestFCMSend_EmptyResponse_NoPanic(t *testing.T) {
// FCM returns an empty Results slice.
fcmResp := FCMResponse{
MulticastID: 12345,
Success: 0,
Failure: 0,
Results: []FCMResult{},
}
server := serveFCMResponse(t, fcmResp)
defer server.Close()
client := newTestFCMClient(server.URL)
tokens := []string{"token-aaa-111"}
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
// No panic expected. The function returns nil because fcmResp.Success == 0
// and fcmResp.Failure == 0 (the "all failed" check requires Failure > 0).
assert.NoError(t, err)
}
func TestFCMSend_NilResultsSlice_NoPanic(t *testing.T) {
// FCM returns a response with nil Results (e.g., malformed JSON).
fcmResp := FCMResponse{
MulticastID: 12345,
Success: 0,
Failure: 1,
}
server := serveFCMResponse(t, fcmResp)
defer server.Close()
client := newTestFCMClient(server.URL)
tokens := []string{"token-aaa-111"}
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
// Should return error because Success == 0 and Failure > 0
assert.Error(t, err)
assert.Contains(t, err.Error(), "all FCM notifications failed")
}
func TestFCMSend_EmptyTokens_ReturnsNil(t *testing.T) {
// Verify the early return for empty tokens.
func TestFCMV1Send_EmptyTokens_ReturnsNil(t *testing.T) {
client := &FCMClient{
serverKey: "test-key",
httpClient: http.DefaultClient,
projectID: "test-project",
endpoint: "http://unused",
httpClient: &http.Client{},
}
err := client.Send(context.Background(), []string{}, "Test", "Body", nil)
err := client.Send(context.Background(), []string{}, "Title", "Body", nil)
assert.NoError(t, err)
}
func TestFCMSend_ResultsWithErrorsMatchTokens(t *testing.T) {
// Normal case: results count matches tokens count, all with errors.
fcmResp := FCMResponse{
MulticastID: 12345,
Success: 0,
Failure: 2,
Results: []FCMResult{
{Error: "InvalidRegistration"},
{Error: "NotRegistered"},
},
}
server := serveFCMResponse(t, fcmResp)
func TestFCMV1Send_AllFail_ReturnsError(t *testing.T) {
server := serveFCMV1Error(t, http.StatusNotFound, "UNREGISTERED", "The registration token is not registered")
defer server.Close()
client := newTestFCMClient(server.URL)
tokens := []string{"token-aaa-111", "token-bbb-222"}
tokens := []string{"bad-token-1", "bad-token-2"}
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
err := client.Send(context.Background(), tokens, "Title", "Body", nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "all FCM notifications failed")
}
func TestFCMV1Send_PartialFailure_ReturnsNil(t *testing.T) {
var mu sync.Mutex
callCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
callCount++
n := callCount
mu.Unlock()
w.Header().Set("Content-Type", "application/json")
if n == 1 {
// First token succeeds
w.WriteHeader(http.StatusOK)
resp := fcmV1Response{Name: "projects/test-project/messages/0:12345"}
_ = json.NewEncoder(w).Encode(resp)
} else {
// Second token fails
w.WriteHeader(http.StatusNotFound)
resp := fcmV1ErrorResponse{
Error: fcmV1Error{Code: 404, Message: "not registered", Status: "UNREGISTERED"},
}
_ = json.NewEncoder(w).Encode(resp)
}
}))
defer server.Close()
client := newTestFCMClient(server.URL)
tokens := []string{"good-token", "bad-token"}
// Partial failure: at least one succeeded, so no error returned.
err := client.Send(context.Background(), tokens, "Title", "Body", nil)
assert.NoError(t, err)
}
// TestFCMV1Send_UnregisteredError verifies an UNREGISTERED response surfaces
// as a typed *FCMSendError with IsUnregistered() reporting true.
func TestFCMV1Send_UnregisteredError(t *testing.T) {
	server := serveFCMV1Error(t, http.StatusNotFound, "UNREGISTERED", "The registration token is not registered")
	defer server.Close()
	client := newTestFCMClient(server.URL)
	err := client.sendOne(context.Background(), "stale-token", "Title", "Body", nil)
	require.Error(t, err)
	var fcmErr *FCMSendError
	require.ErrorAs(t, err, &fcmErr)
	assert.Equal(t, FCMErrUnregistered, fcmErr.ErrorCode)
	assert.Equal(t, http.StatusNotFound, fcmErr.StatusCode)
	assert.True(t, fcmErr.IsUnregistered())
}
// TestFCMV1Send_QuotaExceededError verifies QUOTA_EXCEEDED responses map to a
// typed *FCMSendError carrying the 429 status.
func TestFCMV1Send_QuotaExceededError(t *testing.T) {
	server := serveFCMV1Error(t, http.StatusTooManyRequests, "QUOTA_EXCEEDED", "Sending quota exceeded")
	defer server.Close()
	client := newTestFCMClient(server.URL)
	err := client.sendOne(context.Background(), "some-token", "Title", "Body", nil)
	require.Error(t, err)
	var fcmErr *FCMSendError
	require.ErrorAs(t, err, &fcmErr)
	assert.Equal(t, http.StatusTooManyRequests, fcmErr.StatusCode)
	assert.Equal(t, FCMErrQuotaExceeded, fcmErr.ErrorCode)
}
// TestFCMV1Send_UnparseableErrorResponse verifies a non-JSON error body is
// still wrapped in a *FCMSendError with the raw payload noted in the message.
func TestFCMV1Send_UnparseableErrorResponse(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte("not json at all"))
	}))
	defer server.Close()
	client := newTestFCMClient(server.URL)
	err := client.sendOne(context.Background(), "some-token", "Title", "Body", nil)
	require.Error(t, err)
	var fcmErr *FCMSendError
	require.ErrorAs(t, err, &fcmErr)
	assert.Contains(t, fcmErr.Message, "unparseable error response")
	assert.Equal(t, http.StatusInternalServerError, fcmErr.StatusCode)
}
// TestFCMV1Send_RequestPayloadFormat checks the exact v1 request shape:
// method, Content-Type, token, notification fields, data map, and Android
// priority.
func TestFCMV1Send_RequestPayloadFormat(t *testing.T) {
	var got fcmV1Request
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Verify Content-Type header
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
		require.NoError(t, json.NewDecoder(r.Body).Decode(&got))
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_ = json.NewEncoder(w).Encode(fcmV1Response{Name: "projects/test-project/messages/0:12345"})
	}))
	defer server.Close()
	client := newTestFCMClient(server.URL)
	data := map[string]string{"task_id": "42", "action": "complete"}
	err := client.Send(context.Background(), []string{"device-token-xyz"}, "Task Due", "Your task is due today", data)
	require.NoError(t, err)
	// Verify the v1 message structure
	require.NotNil(t, got.Message)
	assert.Equal(t, "device-token-xyz", got.Message.Token)
	assert.Equal(t, "Task Due", got.Message.Notification.Title)
	assert.Equal(t, "Your task is due today", got.Message.Notification.Body)
	assert.Equal(t, "42", got.Message.Data["task_id"])
	assert.Equal(t, "complete", got.Message.Data["action"])
	assert.Equal(t, "HIGH", got.Message.Android.Priority)
}
// TestFCMSendError_Error verifies the formatted error string contains the
// truncated token, message, HTTP status, and error code.
func TestFCMSendError_Error(t *testing.T) {
	e := &FCMSendError{
		Token:      "abcdef1234567890",
		StatusCode: 404,
		ErrorCode:  FCMErrUnregistered,
		Message:    "token not registered",
	}
	got := e.Error()
	for _, want := range []string{"abcdef12...", "token not registered", "404", "UNREGISTERED"} {
		assert.Contains(t, got, want)
	}
}
// TestFCMSendError_IsUnregistered covers the error-code classification table:
// only UNREGISTERED indicates a stale token.
func TestFCMSendError_IsUnregistered(t *testing.T) {
	cases := map[string]struct {
		code FCMErrorCode
		want bool
	}{
		"unregistered":   {FCMErrUnregistered, true},
		"quota_exceeded": {FCMErrQuotaExceeded, false},
		"internal":       {FCMErrInternal, false},
		"empty":          {"", false},
	}
	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			e := &FCMSendError{ErrorCode: tc.code}
			assert.Equal(t, tc.want, e.IsUnregistered())
		})
	}
}

View File

@@ -124,29 +124,33 @@ func (r *ContractorRepository) GetTasksForContractor(contractorID uint) ([]model
return tasks, err
}
// SetSpecialties sets the specialties for a contractor.
// Wrapped in a transaction so that clearing existing specialties and
// appending new ones are atomic -- a failure in either step rolls back both.
// Every statement runs on the tx handle; mixing in r.db would escape the
// transaction and defeat the rollback guarantee.
func (r *ContractorRepository) SetSpecialties(contractorID uint, specialtyIDs []uint) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		var contractor models.Contractor
		if err := tx.First(&contractor, contractorID).Error; err != nil {
			return err
		}
		// Clear existing specialties
		if err := tx.Model(&contractor).Association("Specialties").Clear(); err != nil {
			return err
		}
		// An empty ID list means "remove all": the clear above is the whole job.
		if len(specialtyIDs) == 0 {
			return nil
		}
		// Add new specialties
		var specialties []models.ContractorSpecialty
		if err := tx.Where("id IN ?", specialtyIDs).Find(&specialties).Error; err != nil {
			return err
		}
		return tx.Model(&contractor).Association("Specialties").Append(specialties)
	})
}
// CountByResidence counts contractors in a residence

View File

@@ -98,7 +98,7 @@ func (r *DocumentRepository) FindByUserFiltered(residenceIDs []uint, filter *Doc
}
var documents []models.Document
err := query.Order("created_at DESC").Find(&documents).Error
err := query.Order("created_at DESC").Limit(500).Find(&documents).Error
return documents, err
}

View File

@@ -88,10 +88,33 @@ func (r *ReminderRepository) HasSentReminderBatch(keys []ReminderKey) (map[int]b
userIDs = append(userIDs, id)
}
// Query all matching reminder logs in one query
// Collect unique stages and due dates for tighter SQL filtering
stageSet := make(map[models.ReminderStage]bool)
dueDateSet := make(map[string]bool)
var minDueDate, maxDueDate time.Time
for _, k := range keys {
stageSet[k.Stage] = true
dueDateOnly := time.Date(k.DueDate.Year(), k.DueDate.Month(), k.DueDate.Day(), 0, 0, 0, 0, time.UTC)
dueDateSet[dueDateOnly.Format("2006-01-02")] = true
if minDueDate.IsZero() || dueDateOnly.Before(minDueDate) {
minDueDate = dueDateOnly
}
if maxDueDate.IsZero() || dueDateOnly.After(maxDueDate) {
maxDueDate = dueDateOnly
}
}
stages := make([]models.ReminderStage, 0, len(stageSet))
for s := range stageSet {
stages = append(stages, s)
}
// Query matching reminder logs with tighter filters to reduce result set.
// Filter on reminder_stage and due_date range in addition to task_id/user_id.
var logs []models.TaskReminderLog
err := r.db.Where("task_id IN ? AND user_id IN ?", taskIDs, userIDs).
Find(&logs).Error
err := r.db.Where(
"task_id IN ? AND user_id IN ? AND reminder_stage IN ? AND due_date >= ? AND due_date <= ?",
taskIDs, userIDs, stages, minDueDate, maxDueDate,
).Find(&logs).Error
if err != nil {
return nil, err
}

View File

@@ -196,6 +196,20 @@ func (r *ResidenceRepository) CountByOwner(userID uint) (int64, error) {
return count, err
}
// FindResidenceIDsByOwner returns just the IDs of residences a user owns.
// This is a lightweight alternative to FindOwnedByUser() when only IDs are needed
// for batch queries against related tables (tasks, contractors, documents).
// Only active residences (is_active = true) are included.
func (r *ResidenceRepository) FindResidenceIDsByOwner(userID uint) ([]uint, error) {
	var ids []uint
	query := r.db.Model(&models.Residence{}).
		Where("owner_id = ? AND is_active = ?", userID, true)
	if err := query.Pluck("id", &ids).Error; err != nil {
		return nil, err
	}
	return ids, nil
}
// === Share Code Operations ===
// CreateShareCode creates a new share code for a residence

View File

@@ -129,12 +129,21 @@ func (r *SubscriptionRepository) UpdatePurchaseToken(userID uint, token string)
Update("google_purchase_token", token).Error
}
// FindByAppleReceiptContains finds a subscription by Apple transaction ID
// Used by webhooks to find the user associated with a transaction
// FindByAppleReceiptContains finds a subscription by Apple transaction ID.
// Used by webhooks to find the user associated with a transaction.
//
// PERFORMANCE NOTE: This uses a LIKE '%...%' scan on apple_receipt_data which
// cannot use a B-tree index and results in a full table scan. For better
// performance at scale, add a dedicated indexed column:
//
// AppleTransactionID *string `gorm:"column:apple_transaction_id;size:255;index"`
//
// Then look up by exact match: WHERE apple_transaction_id = ?
func (r *SubscriptionRepository) FindByAppleReceiptContains(transactionID string) (*models.UserSubscription, error) {
var sub models.UserSubscription
// Search for transaction ID in the stored receipt data
err := r.db.Where("apple_receipt_data LIKE ?", "%"+transactionID+"%").First(&sub).Error
// Escape LIKE wildcards in the transaction ID to prevent wildcard injection
escaped := escapeLikeWildcards(transactionID)
err := r.db.Where("apple_receipt_data LIKE ?", "%"+escaped+"%").First(&sub).Error
if err != nil {
return nil, err
}

View File

@@ -38,29 +38,40 @@ func (r *TaskRepository) CreateCompletionTx(tx *gorm.DB, completion *models.Task
return tx.Create(completion).Error
}
// taskUpdateFields returns the canonical field map used by both Update and UpdateTx.
// Centralised here so the two methods never drift out of sync.
// The "version" entry uses a SQL expression (version + 1) rather than a Go
// value so the increment happens atomically in the database, supporting the
// optimistic-locking WHERE clause used by the callers.
func taskUpdateFields(t *models.Task) map[string]interface{} {
	return map[string]interface{}{
		"title":                t.Title,
		"description":          t.Description,
		"category_id":          t.CategoryID,
		"priority_id":          t.PriorityID,
		"frequency_id":         t.FrequencyID,
		"custom_interval_days": t.CustomIntervalDays,
		"in_progress":          t.InProgress,
		"assigned_to_id":       t.AssignedToID,
		"due_date":             t.DueDate,
		"next_due_date":        t.NextDueDate,
		"estimated_cost":       t.EstimatedCost,
		"actual_cost":          t.ActualCost,
		"contractor_id":        t.ContractorID,
		"is_cancelled":         t.IsCancelled,
		"is_archived":          t.IsArchived,
		"version":              gorm.Expr("version + 1"),
	}
}
// taskUpdateOmitAssociations lists the association fields to omit during task
// updates, so GORM updates only the task's own columns and does not touch the
// related records referenced by these associations.
var taskUpdateOmitAssociations = []string{
	"Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions",
}
// UpdateTx updates a task with optimistic locking within an existing transaction.
func (r *TaskRepository) UpdateTx(tx *gorm.DB, task *models.Task) error {
result := tx.Model(task).
Where("id = ? AND version = ?", task.ID, task.Version).
Omit("Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions").
Updates(map[string]interface{}{
"title": task.Title,
"description": task.Description,
"category_id": task.CategoryID,
"priority_id": task.PriorityID,
"frequency_id": task.FrequencyID,
"custom_interval_days": task.CustomIntervalDays,
"in_progress": task.InProgress,
"assigned_to_id": task.AssignedToID,
"due_date": task.DueDate,
"next_due_date": task.NextDueDate,
"estimated_cost": task.EstimatedCost,
"actual_cost": task.ActualCost,
"contractor_id": task.ContractorID,
"is_cancelled": task.IsCancelled,
"is_archived": task.IsArchived,
"version": gorm.Expr("version + 1"),
})
Omit(taskUpdateOmitAssociations...).
Updates(taskUpdateFields(task))
if result.Error != nil {
return result.Error
}
@@ -350,25 +361,8 @@ func (r *TaskRepository) Create(task *models.Task) error {
func (r *TaskRepository) Update(task *models.Task) error {
result := r.db.Model(task).
Where("id = ? AND version = ?", task.ID, task.Version).
Omit("Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions").
Updates(map[string]interface{}{
"title": task.Title,
"description": task.Description,
"category_id": task.CategoryID,
"priority_id": task.PriorityID,
"frequency_id": task.FrequencyID,
"custom_interval_days": task.CustomIntervalDays,
"in_progress": task.InProgress,
"assigned_to_id": task.AssignedToID,
"due_date": task.DueDate,
"next_due_date": task.NextDueDate,
"estimated_cost": task.EstimatedCost,
"actual_cost": task.ActualCost,
"contractor_id": task.ContractorID,
"is_cancelled": task.IsCancelled,
"is_archived": task.IsArchived,
"version": gorm.Expr("version + 1"),
})
Omit(taskUpdateOmitAssociations...).
Updates(taskUpdateFields(task))
if result.Error != nil {
return result.Error
}
@@ -728,13 +722,18 @@ func (r *TaskRepository) UpdateCompletion(completion *models.TaskCompletion) err
return r.db.Omit("Task", "CompletedBy", "Images").Save(completion).Error
}
// DeleteCompletion deletes a task completion and its associated images atomically.
// Wrapped in a transaction so that if the completion delete fails, image
// deletions are rolled back as well. An image-delete failure aborts the whole
// operation instead of being logged and ignored.
func (r *TaskRepository) DeleteCompletion(id uint) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// Delete images first so no orphaned image rows survive the commit.
		if err := tx.Where("completion_id = ?", id).Delete(&models.TaskCompletionImage{}).Error; err != nil {
			log.Error().Err(err).Uint("completion_id", id).Msg("Failed to delete completion images")
			return err
		}
		return tx.Delete(&models.TaskCompletion{}, id).Error
	})
}
// CreateCompletionImage creates a new completion image
@@ -912,3 +911,128 @@ func (r *TaskRepository) GetCompletionSummary(residenceID uint, now time.Time, m
Months: months,
}, nil
}
// GetBatchCompletionSummaries returns completion summaries for multiple residences
// in two queries total (one for all-time counts, one for monthly breakdowns),
// instead of 2*N queries when calling GetCompletionSummary per residence.
//
// The monthly window covers the 12 calendar months starting one year before
// `now` (first day of that month, in now's location). Missing months are
// pre-initialised so every residence always gets exactly 12 month entries.
func (r *TaskRepository) GetBatchCompletionSummaries(residenceIDs []uint, now time.Time, maxPerMonth int) (map[uint]*responses.CompletionSummary, error) {
	result := make(map[uint]*responses.CompletionSummary, len(residenceIDs))
	if len(residenceIDs) == 0 {
		return result, nil
	}
	// 1. Total all-time completions per residence (single query)
	type allTimeRow struct {
		ResidenceID uint
		Count       int64
	}
	var allTimeRows []allTimeRow
	err := r.db.Model(&models.TaskCompletion{}).
		Select("task_task.residence_id, COUNT(*) as count").
		Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
		Where("task_task.residence_id IN ?", residenceIDs).
		Group("task_task.residence_id").
		Scan(&allTimeRows).Error
	if err != nil {
		return nil, err
	}
	allTimeMap := make(map[uint]int64, len(allTimeRows))
	for _, row := range allTimeRows {
		allTimeMap[row.ResidenceID] = row.Count
	}
	// 2. Monthly breakdown for last 12 months across all residences (single query)
	startDate := time.Date(now.Year()-1, now.Month(), 1, 0, 0, 0, 0, now.Location())
	// Month-bucketing SQL is dialect-specific: TO_CHAR by default (Postgres),
	// strftime when running on SQLite (e.g. in tests).
	dateExpr := "TO_CHAR(task_taskcompletion.completed_at, 'YYYY-MM')"
	if r.db.Dialector.Name() == "sqlite" {
		dateExpr = "strftime('%Y-%m', task_taskcompletion.completed_at)"
	}
	var rows []completionAggRow
	err = r.db.Model(&models.TaskCompletion{}).
		Select(fmt.Sprintf("task_task.residence_id, task_taskcompletion.completed_from_column, %s as completed_month, COUNT(*) as count", dateExpr)).
		Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
		Where("task_task.residence_id IN ? AND task_taskcompletion.completed_at >= ?", residenceIDs, startDate).
		Group(fmt.Sprintf("task_task.residence_id, task_taskcompletion.completed_from_column, %s", dateExpr)).
		Order("completed_month ASC").
		Scan(&rows).Error
	if err != nil {
		return nil, err
	}
	// 3. Build per-residence summaries
	type monthData struct {
		columns map[string]int // per-kanban-column counts for the month
		total   int            // sum across columns for the month
	}
	// Initialize all residences with empty month maps
	residenceMonths := make(map[uint]map[string]*monthData, len(residenceIDs))
	for _, rid := range residenceIDs {
		mm := make(map[string]*monthData, 12)
		for i := 0; i < 12; i++ {
			m := startDate.AddDate(0, i, 0)
			key := m.Format("2006-01")
			mm[key] = &monthData{columns: make(map[string]int)}
		}
		residenceMonths[rid] = mm
	}
	// Populate from query results
	residenceLast12 := make(map[uint]int, len(residenceIDs))
	for _, row := range rows {
		mm, ok := residenceMonths[row.ResidenceID]
		if !ok {
			continue
		}
		md, ok := mm[row.CompletedMonth]
		if !ok {
			continue
		}
		md.columns[row.CompletedFromColumn] = int(row.Count)
		md.total += int(row.Count)
		residenceLast12[row.ResidenceID] += int(row.Count)
	}
	// Convert to response DTOs per residence
	for _, rid := range residenceIDs {
		mm := residenceMonths[rid]
		months := make([]responses.MonthlyCompletionSummary, 0, 12)
		for i := 0; i < 12; i++ {
			m := startDate.AddDate(0, i, 0)
			key := m.Format("2006-01")
			md := mm[key]
			completions := make([]responses.ColumnCompletionCount, 0)
			for col, count := range md.columns {
				completions = append(completions, responses.ColumnCompletionCount{
					Column: col,
					Color:  KanbanColumnColor(col),
					Count:  count,
				})
			}
			// Overflow records how far the month's total exceeds the display
			// cap; the full per-column counts are still returned as-is.
			overflow := 0
			if md.total > maxPerMonth {
				overflow = md.total - maxPerMonth
			}
			months = append(months, responses.MonthlyCompletionSummary{
				Month:       key,
				Completions: completions,
				Total:       md.total,
				Overflow:    overflow,
			})
		}
		result[rid] = &responses.CompletionSummary{
			TotalAllTime:      int(allTimeMap[rid]),
			TotalLast12Months: residenceLast12[rid],
			Months:            months,
		}
	}
	return result, nil
}

View File

@@ -34,6 +34,16 @@ func NewUserRepository(db *gorm.DB) *UserRepository {
return &UserRepository{db: db}
}
// Transaction executes fn within a single database transaction. The callback
// is handed a UserRepository bound to the transactional *gorm.DB, so every
// repository call made inside fn participates in the same transaction and is
// committed or rolled back as a unit.
func (r *UserRepository) Transaction(fn func(txRepo *UserRepository) error) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		return fn(&UserRepository{db: tx})
	})
}
// FindByID finds a user by ID
func (r *UserRepository) FindByID(id uint) (*models.User, error) {
var user models.User
@@ -130,18 +140,28 @@ func (r *UserRepository) ExistsByEmail(email string) (bool, error) {
// --- Auth Token Methods ---
// GetOrCreateToken gets or creates an auth token for a user
// GetOrCreateToken gets or creates an auth token for a user.
// Wrapped in a transaction to prevent race conditions where two
// concurrent requests could create duplicate tokens for the same user.
func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error) {
var token models.AuthToken
result := r.db.Where("user_id = ?", userID).First(&token)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
token = models.AuthToken{UserID: userID}
if err := r.db.Create(&token).Error; err != nil {
return nil, err
err := r.db.Transaction(func(tx *gorm.DB) error {
result := tx.Where("user_id = ?", userID).First(&token)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
token = models.AuthToken{UserID: userID}
if err := tx.Create(&token).Error; err != nil {
return err
}
} else if result.Error != nil {
return result.Error
}
} else if result.Error != nil {
return nil, result.Error
return nil
})
if err != nil {
return nil, err
}
return &token, nil
@@ -341,7 +361,7 @@ func (r *UserRepository) SearchUsers(query string, limit, offset int) ([]models.
var users []models.User
var total int64
searchQuery := "%" + strings.ToLower(query) + "%"
searchQuery := "%" + escapeLikeWildcards(strings.ToLower(query)) + "%"
baseQuery := r.db.Model(&models.User{}).
Where("LOWER(username) LIKE ? OR LOWER(email) LIKE ? OR LOWER(first_name) LIKE ? OR LOWER(last_name) LIKE ?",
@@ -384,7 +404,7 @@ func (r *UserRepository) FindUsersInSharedResidences(userID uint) ([]models.User
// 2. Members of residences owned by current user
// 3. Members of residences where current user is also a member
err := r.db.Raw(`
SELECT DISTINCT u.* FROM user_customuser u
SELECT DISTINCT u.* FROM auth_user u
WHERE u.id != ? AND u.is_active = true AND (
-- Users who own residences where current user is a shared user
u.id IN (
@@ -417,7 +437,7 @@ func (r *UserRepository) FindUserIfSharedResidence(targetUserID, requestingUserI
var user models.User
err := r.db.Raw(`
SELECT u.* FROM user_customuser u
SELECT u.* FROM auth_user u
WHERE u.id = ? AND u.is_active = true AND (
u.id = ? OR
-- Target owns a residence where requester is a member
@@ -460,7 +480,7 @@ func (r *UserRepository) FindProfilesInSharedResidences(userID uint) ([]models.U
err := r.db.Raw(`
SELECT p.* FROM user_userprofile p
INNER JOIN user_customuser u ON p.user_id = u.id
INNER JOIN auth_user u ON p.user_id = u.id
WHERE u.is_active = true AND (
u.id = ? OR
-- Users who own residences where current user is a shared user

View File

@@ -59,6 +59,15 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
e.Use(custommiddleware.RequestIDMiddleware())
e.Use(utils.EchoRecovery())
e.Use(custommiddleware.StructuredLogger())
// Security headers (X-Frame-Options, X-Content-Type-Options, X-XSS-Protection, etc.)
e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
XSSProtection: "1; mode=block",
ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN",
HSTSMaxAge: 31536000, // 1 year in seconds
ReferrerPolicy: "strict-origin-when-cross-origin",
}))
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
Limit: "1M", // 1MB default for JSON payloads
Skipper: func(c echo.Context) bool {
@@ -187,7 +196,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
var uploadHandler *handlers.UploadHandler
var mediaHandler *handlers.MediaHandler
if deps.StorageService != nil {
uploadHandler = handlers.NewUploadHandler(deps.StorageService)
uploadHandler = handlers.NewUploadHandler(deps.StorageService, services.NewFileOwnershipService(deps.DB))
mediaHandler = handlers.NewMediaHandler(documentRepo, taskRepo, residenceRepo, deps.StorageService)
}
@@ -247,13 +256,22 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
}
// corsMiddleware configures CORS with restricted origins in production.
// In debug mode, all origins are allowed for development convenience.
// In debug mode, explicit localhost origins are allowed for development.
// In production, origins are read from the CORS_ALLOWED_ORIGINS environment variable
// (comma-separated), falling back to a restrictive default set.
func corsMiddleware(cfg *config.Config) echo.MiddlewareFunc {
var origins []string
if cfg.Server.Debug {
origins = []string{"*"}
origins = []string{
"http://localhost:3000",
"http://localhost:3001",
"http://localhost:8080",
"http://localhost:8000",
"http://127.0.0.1:3000",
"http://127.0.0.1:3001",
"http://127.0.0.1:8080",
"http://127.0.0.1:8000",
}
} else {
origins = cfg.Server.CorsAllowedOrigins
if len(origins) == 0 {
@@ -286,17 +304,24 @@ func healthCheck(c echo.Context) error {
})
}
// setupPublicAuthRoutes configures public authentication routes with
// per-endpoint rate limiters to mitigate brute-force and credential-stuffing.
// Sign-in style endpoints share the login limiter; code/reset endpoints share
// the stricter password limiter.
func setupPublicAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler) {
	auth := api.Group("/auth")

	// Rate limiters — created once here so state is shared across requests.
	loginRL := custommiddleware.LoginRateLimiter()            // 10 req/min
	registerRL := custommiddleware.RegistrationRateLimiter()  // 5 req/min
	passwordRL := custommiddleware.PasswordResetRateLimiter() // 3 req/min

	auth.POST("/login/", authHandler.Login, loginRL)
	auth.POST("/register/", authHandler.Register, registerRL)
	auth.POST("/forgot-password/", authHandler.ForgotPassword, passwordRL)
	auth.POST("/verify-reset-code/", authHandler.VerifyResetCode, passwordRL)
	auth.POST("/reset-password/", authHandler.ResetPassword, passwordRL)
	auth.POST("/apple-sign-in/", authHandler.AppleSignIn, loginRL)
	auth.POST("/google-sign-in/", authHandler.GoogleSignIn, loginRL)
}

View File

@@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
"github.com/treytartt/honeydue-api/internal/apperrors"
@@ -90,7 +91,7 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
// Update last login
if err := s.userRepo.UpdateLastLogin(user.ID); err != nil {
// Log error but don't fail the login
fmt.Printf("Failed to update last login: %v\n", err)
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login")
}
return &responses.LoginResponse{
@@ -99,7 +100,9 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
}, nil
}
// Register creates a new user account
// Register creates a new user account.
// F-10: User creation, profile creation, notification preferences, and confirmation code
// are wrapped in a transaction for atomicity.
func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) {
// Check if username exists
exists, err := s.userRepo.ExistsByUsername(req.Username)
@@ -133,43 +136,49 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
return nil, "", apperrors.Internal(err)
}
// Save user
if err := s.userRepo.Create(user); err != nil {
return nil, "", apperrors.Internal(err)
}
// Create user profile
if _, err := s.userRepo.GetOrCreateProfile(user.ID); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create user profile: %v\n", err)
}
// Create notification preferences with all options enabled
if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create notification preferences: %v\n", err)
}
}
// Create auth token
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, "", apperrors.Internal(err)
}
// Generate confirmation code - use fixed code in debug mode for easier local testing
// Generate confirmation code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
var code string
if s.cfg.Server.Debug {
if s.cfg.Server.DebugFixedCodes {
code = "123456"
} else {
code = generateSixDigitCode()
}
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
if _, err := s.userRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create confirmation code: %v\n", err)
// Wrap user creation + profile + notification preferences + confirmation code in a transaction
txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error {
// Save user
if err := txRepo.Create(user); err != nil {
return err
}
// Create user profile
if _, err := txRepo.GetOrCreateProfile(user.ID); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create user profile during registration")
}
// Create notification preferences with all options enabled
if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration")
}
}
// Create confirmation code
if _, err := txRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil {
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create confirmation code during registration")
}
return nil
})
if txErr != nil {
return nil, "", apperrors.Internal(txErr)
}
// Create auth token (outside transaction since token generation is idempotent)
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, "", apperrors.Internal(err)
}
return &responses.RegisterResponse{
@@ -248,8 +257,8 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
return apperrors.BadRequest("error.email_already_verified")
}
// Check for test code in debug mode
if s.cfg.Server.Debug && code == "123456" {
// Check for test code when DEBUG_FIXED_CODES is enabled
if s.cfg.Server.DebugFixedCodes && code == "123456" {
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
return apperrors.Internal(err)
}
@@ -294,9 +303,9 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
return "", apperrors.BadRequest("error.email_already_verified")
}
// Generate new code - use fixed code in debug mode for easier local testing
// Generate new code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
var code string
if s.cfg.Server.Debug {
if s.cfg.Server.DebugFixedCodes {
code = "123456"
} else {
code = generateSixDigitCode()
@@ -331,9 +340,9 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
return "", nil, apperrors.TooManyRequests("error.rate_limit_exceeded")
}
// Generate code and reset token - use fixed code in debug mode for easier local testing
// Generate code and reset token - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
var code string
if s.cfg.Server.Debug {
if s.cfg.Server.DebugFixedCodes {
code = "123456"
} else {
code = generateSixDigitCode()
@@ -365,8 +374,8 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
return "", apperrors.Internal(err)
}
// Check for test code in debug mode
if s.cfg.Server.Debug && code == "123456" {
// Check for test code when DEBUG_FIXED_CODES is enabled
if s.cfg.Server.DebugFixedCodes && code == "123456" {
return resetCode.ResetToken, nil
}
@@ -422,13 +431,13 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
// Mark reset code as used
if err := s.userRepo.MarkPasswordResetCodeUsed(resetCode.ID); err != nil {
// Log error but don't fail
fmt.Printf("Failed to mark reset code as used: %v\n", err)
log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used")
}
// Invalidate all existing tokens for this user (security measure)
if err := s.userRepo.DeleteTokenByUserID(user.ID); err != nil {
// Log error but don't fail
fmt.Printf("Failed to delete user tokens: %v\n", err)
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset")
}
return nil
@@ -482,6 +491,13 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
if email != "" {
existingUser, err := s.userRepo.FindByEmail(email)
if err == nil && existingUser != nil {
// S-06: Log auto-linking of social account to existing user
log.Warn().
Str("email", email).
Str("provider", "apple").
Uint("user_id", existingUser.ID).
Msg("Auto-linking social account to existing user by email match")
// Link Apple ID to existing account
appleAuthRecord := &models.AppleSocialAuth{
UserID: existingUser.ID,
@@ -505,8 +521,11 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
// Update last login
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
// Reload user with profile
existingUser, _ = s.userRepo.FindByIDWithProfile(existingUser.ID)
// B-08: Check error from FindByIDWithProfile
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
return &responses.AppleSignInResponse{
Token: token.Key,
@@ -544,8 +563,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
// Create notification preferences with all options enabled
if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create notification preferences: %v\n", err)
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user")
}
}
@@ -566,8 +584,11 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
return nil, apperrors.Internal(err)
}
// Reload user with profile
user, _ = s.userRepo.FindByIDWithProfile(user.ID)
// B-08: Check error from FindByIDWithProfile
user, err = s.userRepo.FindByIDWithProfile(user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
return &responses.AppleSignInResponse{
Token: token.Key,
@@ -623,6 +644,13 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
if email != "" {
existingUser, err := s.userRepo.FindByEmail(email)
if err == nil && existingUser != nil {
// S-06: Log auto-linking of social account to existing user
log.Warn().
Str("email", email).
Str("provider", "google").
Uint("user_id", existingUser.ID).
Msg("Auto-linking social account to existing user by email match")
// Link Google ID to existing account
googleAuthRecord := &models.GoogleSocialAuth{
UserID: existingUser.ID,
@@ -649,8 +677,11 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
// Update last login
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
// Reload user with profile
existingUser, _ = s.userRepo.FindByIDWithProfile(existingUser.ID)
// B-08: Check error from FindByIDWithProfile
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
return &responses.GoogleSignInResponse{
Token: token.Key,
@@ -688,8 +719,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
// Create notification preferences with all options enabled
if s.notificationRepo != nil {
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create notification preferences: %v\n", err)
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user")
}
}
@@ -711,8 +741,11 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
return nil, apperrors.Internal(err)
}
// Reload user with profile
user, _ = s.userRepo.FindByIDWithProfile(user.ID)
// B-08: Check error from FindByIDWithProfile
user, err = s.userRepo.FindByIDWithProfile(user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
return &responses.GoogleSignInResponse{
Token: token.Key,

View File

@@ -2,9 +2,10 @@ package services
import (
"context"
"crypto/md5"
"encoding/json"
"fmt"
"hash/fnv"
"sync"
"time"
"github.com/redis/go-redis/v9"
@@ -18,38 +19,55 @@ type CacheService struct {
client *redis.Client
}
var cacheInstance *CacheService
var (
cacheInstance *CacheService
cacheOnce sync.Once
)
// NewCacheService creates a new cache service
// NewCacheService creates a new cache service (thread-safe via sync.Once)
func NewCacheService(cfg *config.RedisConfig) (*CacheService, error) {
opt, err := redis.ParseURL(cfg.URL)
if err != nil {
return nil, fmt.Errorf("failed to parse Redis URL: %w", err)
var initErr error
cacheOnce.Do(func() {
opt, err := redis.ParseURL(cfg.URL)
if err != nil {
initErr = fmt.Errorf("failed to parse Redis URL: %w", err)
return
}
if cfg.Password != "" {
opt.Password = cfg.Password
}
if cfg.DB != 0 {
opt.DB = cfg.DB
}
client := redis.NewClient(opt)
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
initErr = fmt.Errorf("failed to connect to Redis: %w", err)
// Reset Once so a retry is possible after transient failures
cacheOnce = sync.Once{}
return
}
// S-14: Mask credentials in Redis URL before logging
log.Info().
Str("url", config.MaskURLCredentials(cfg.URL)).
Int("db", opt.DB).
Msg("Connected to Redis")
cacheInstance = &CacheService{client: client}
})
if initErr != nil {
return nil, initErr
}
if cfg.Password != "" {
opt.Password = cfg.Password
}
if cfg.DB != 0 {
opt.DB = cfg.DB
}
client := redis.NewClient(opt)
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
log.Info().
Str("url", cfg.URL).
Int("db", opt.DB).
Msg("Connected to Redis")
cacheInstance = &CacheService{client: client}
return cacheInstance, nil
}
@@ -311,9 +329,10 @@ func (c *CacheService) CacheSeededData(ctx context.Context, data interface{}) (s
return "", fmt.Errorf("failed to marshal seeded data: %w", err)
}
// Generate MD5 ETag from the JSON data
hash := md5.Sum(jsonData)
etag := fmt.Sprintf("\"%x\"", hash)
// Generate FNV-64a ETag from the JSON data (faster than MD5, non-cryptographic)
h := fnv.New64a()
h.Write(jsonData)
etag := fmt.Sprintf("\"%x\"", h.Sum64())
// Store both the data and the ETag
if err := c.client.Set(ctx, SeededDataKey, jsonData, SeededDataTTL).Err(); err != nil {

View File

@@ -4,11 +4,11 @@ import (
"bytes"
"fmt"
"html/template"
"io"
"time"
mail "github.com/wneessen/go-mail"
"github.com/rs/zerolog/log"
"gopkg.in/gomail.v2"
"github.com/treytartt/honeydue-api/internal/config"
)
@@ -16,17 +16,31 @@ import (
// EmailService handles sending emails
type EmailService struct {
cfg *config.EmailConfig
dialer *gomail.Dialer
client *mail.Client
enabled bool
}
// NewEmailService creates a new email service backed by go-mail. If the SMTP
// client cannot be constructed, a disabled service (client == nil) is
// returned so callers can proceed without email delivery instead of failing
// startup.
func NewEmailService(cfg *config.EmailConfig, enabled bool) *EmailService {
	client, err := mail.NewClient(cfg.Host,
		mail.WithPort(cfg.Port),
		mail.WithSMTPAuth(mail.SMTPAuthPlain),
		mail.WithUsername(cfg.User),
		mail.WithPassword(cfg.Password),
		mail.WithTLSPortPolicy(mail.TLSOpportunistic),
	)
	if err != nil {
		log.Error().Err(err).Msg("Failed to create mail client - emails will not be sent")
		return &EmailService{
			cfg:     cfg,
			client:  nil,
			enabled: false,
		}
	}
	return &EmailService{
		cfg:     cfg,
		client:  client,
		enabled: enabled,
	}
}
@@ -37,14 +51,18 @@ func (s *EmailService) SendEmail(to, subject, htmlBody, textBody string) error {
log.Debug().Msg("Email sending disabled by feature flag")
return nil
}
m := gomail.NewMessage()
m.SetHeader("From", s.cfg.From)
m.SetHeader("To", to)
m.SetHeader("Subject", subject)
m.SetBody("text/plain", textBody)
m.AddAlternative("text/html", htmlBody)
m := mail.NewMsg()
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
return fmt.Errorf("failed to set from address: %w", err)
}
if err := m.AddTo(to); err != nil {
return fmt.Errorf("failed to set to address: %w", err)
}
m.Subject(subject)
m.SetBodyString(mail.TypeTextPlain, textBody)
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
if err := s.dialer.DialAndSend(m); err != nil {
if err := s.client.DialAndSend(m); err != nil {
log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email")
return fmt.Errorf("failed to send email: %w", err)
}
@@ -74,26 +92,25 @@ func (s *EmailService) SendEmailWithAttachment(to, subject, htmlBody, textBody s
log.Debug().Msg("Email sending disabled by feature flag")
return nil
}
m := gomail.NewMessage()
m.SetHeader("From", s.cfg.From)
m.SetHeader("To", to)
m.SetHeader("Subject", subject)
m.SetBody("text/plain", textBody)
m.AddAlternative("text/html", htmlBody)
m := mail.NewMsg()
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
return fmt.Errorf("failed to set from address: %w", err)
}
if err := m.AddTo(to); err != nil {
return fmt.Errorf("failed to set to address: %w", err)
}
m.Subject(subject)
m.SetBodyString(mail.TypeTextPlain, textBody)
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
if attachment != nil {
m.Attach(attachment.Filename,
gomail.SetCopyFunc(func(w io.Writer) error {
_, err := w.Write(attachment.Data)
return err
}),
gomail.SetHeader(map[string][]string{
"Content-Type": {attachment.ContentType},
}),
m.AttachReader(attachment.Filename,
bytes.NewReader(attachment.Data),
mail.WithFileContentType(mail.ContentType(attachment.ContentType)),
)
}
if err := s.dialer.DialAndSend(m); err != nil {
if err := s.client.DialAndSend(m); err != nil {
log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email with attachment")
return fmt.Errorf("failed to send email: %w", err)
}
@@ -108,29 +125,28 @@ func (s *EmailService) SendEmailWithEmbeddedImages(to, subject, htmlBody, textBo
log.Debug().Msg("Email sending disabled by feature flag")
return nil
}
m := gomail.NewMessage()
m.SetHeader("From", s.cfg.From)
m.SetHeader("To", to)
m.SetHeader("Subject", subject)
m.SetBody("text/plain", textBody)
m.AddAlternative("text/html", htmlBody)
m := mail.NewMsg()
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
return fmt.Errorf("failed to set from address: %w", err)
}
if err := m.AddTo(to); err != nil {
return fmt.Errorf("failed to set to address: %w", err)
}
m.Subject(subject)
m.SetBodyString(mail.TypeTextPlain, textBody)
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
// Embed each image with Content-ID for inline display
for _, img := range images {
m.Embed(img.Filename,
gomail.SetCopyFunc(func(w io.Writer) error {
_, err := w.Write(img.Data)
return err
}),
gomail.SetHeader(map[string][]string{
"Content-Type": {img.ContentType},
"Content-ID": {"<" + img.ContentID + ">"},
"Content-Disposition": {"inline; filename=\"" + img.Filename + "\""},
}),
img := img // capture range variable for closure
m.EmbedReader(img.Filename,
bytes.NewReader(img.Data),
mail.WithFileContentType(mail.ContentType(img.ContentType)),
mail.WithFileContentID(img.ContentID),
)
}
if err := s.dialer.DialAndSend(m); err != nil {
if err := s.client.DialAndSend(m); err != nil {
log.Error().Err(err).Str("to", to).Str("subject", subject).Int("images", len(images)).Msg("Failed to send email with embedded images")
return fmt.Errorf("failed to send email: %w", err)
}

View File

@@ -0,0 +1,66 @@
package services
import (
"github.com/treytartt/honeydue-api/internal/models"
"gorm.io/gorm"
)
// FileOwnershipService checks whether a user owns a file referenced by URL.
// It queries task completion images, document files, and document images
// to determine ownership through residence access.
type FileOwnershipService struct {
// db is the shared GORM handle used for all ownership lookups.
db *gorm.DB
}
// NewFileOwnershipService creates a new FileOwnershipService wrapping db.
func NewFileOwnershipService(db *gorm.DB) *FileOwnershipService {
return &FileOwnershipService{db: db}
}
// IsFileOwnedByUser reports whether the given file URL belongs to a record
// that the user can access via residence membership. It checks, in order:
// task completion images, document files, and document images, returning
// true on the first matching row.
func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (bool, error) {
	// Each entry describes one ownership path from a file-URL column to the
	// residence_residence_users membership table. Consolidated here to avoid
	// three copy-pasted query blocks.
	checks := []struct {
		model interface{}
		joins []string
		where string
	}{
		{
			// Task completion images: image_url -> completion -> task -> residence.
			model: &models.TaskCompletionImage{},
			joins: []string{
				"JOIN task_taskcompletion ON task_taskcompletion.id = task_taskcompletionimage.completion_id",
				"JOIN task_task ON task_task.id = task_taskcompletion.task_id",
				"JOIN residence_residence_users ON residence_residence_users.residence_id = task_task.residence_id",
			},
			where: "task_taskcompletionimage.image_url = ? AND residence_residence_users.user_id = ?",
		},
		{
			// Document files: file_url -> document -> residence.
			model: &models.Document{},
			joins: []string{
				"JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id",
			},
			where: "task_document.file_url = ? AND residence_residence_users.user_id = ?",
		},
		{
			// Document images: image_url -> document image -> document -> residence.
			model: &models.DocumentImage{},
			joins: []string{
				"JOIN task_document ON task_document.id = task_documentimage.document_id",
				"JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id",
			},
			where: "task_documentimage.image_url = ? AND residence_residence_users.user_id = ?",
		},
	}
	for _, c := range checks {
		q := s.db.Model(c.model)
		for _, j := range c.joins {
			q = q.Joins(j)
		}
		var count int64
		if err := q.Where(c.where, fileURL, userID).Count(&count).Error; err != nil {
			return false, err
		}
		if count > 0 {
			return true, nil
		}
	}
	return false, nil
}

View File

@@ -36,11 +36,12 @@ var (
// AppleIAPClient handles Apple App Store Server API validation
type AppleIAPClient struct {
keyID string
issuerID string
bundleID string
keyID string
issuerID string
bundleID string
privateKey *ecdsa.PrivateKey
sandbox bool
sandbox bool
httpClient *http.Client // P-07: Reused across requests
}
// GoogleIAPClient handles Google Play Developer API validation
@@ -122,6 +123,7 @@ func NewAppleIAPClient(cfg config.AppleIAPConfig) (*AppleIAPClient, error) {
bundleID: cfg.BundleID,
privateKey: ecdsaKey,
sandbox: cfg.Sandbox,
httpClient: &http.Client{Timeout: 30 * time.Second}, // P-07: Single client reused across requests
}, nil
}
@@ -168,8 +170,8 @@ func (c *AppleIAPClient) ValidateTransaction(ctx context.Context, transactionID
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
// P-07: Reuse the single http.Client instead of creating one per request
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to call Apple API: %w", err)
}
@@ -276,8 +278,8 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
// P-07: Reuse the single http.Client
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to call Apple API: %w", err)
}
@@ -357,8 +359,8 @@ func (c *AppleIAPClient) validateLegacyReceiptWithSandbox(ctx context.Context, r
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
// P-07: Reuse the single http.Client
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to call Apple verifyReceipt: %w", err)
}

View File

@@ -4,9 +4,11 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/apperrors"
@@ -184,8 +186,33 @@ func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferen
return NewNotificationPreferencesResponse(prefs), nil
}
// validateHourField checks that an optional hour value is in the valid range 0-23.
// A nil pointer means the field was omitted and is accepted as-is.
func validateHourField(val *int, fieldName string) error {
	if val == nil {
		return nil
	}
	if hour := *val; hour < 0 || hour > 23 {
		return apperrors.BadRequest("error.invalid_hour").
			WithMessage(fmt.Sprintf("%s must be between 0 and 23", fieldName))
	}
	return nil
}
// UpdatePreferences updates notification preferences
func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
// B-12: Validate hour fields are in range 0-23
hourFields := []struct {
value *int
name string
}{
{req.TaskDueSoonHour, "task_due_soon_hour"},
{req.TaskOverdueHour, "task_overdue_hour"},
{req.WarrantyExpiringHour, "warranty_expiring_hour"},
{req.DailyDigestHour, "daily_digest_hour"},
}
for _, hf := range hourFields {
if err := validateHourField(hf.value, hf.name); err != nil {
return nil, err
}
}
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
if err != nil {
return nil, apperrors.Internal(err)
@@ -256,7 +283,10 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
// Only update if timezone changed (avoid unnecessary DB writes)
if prefs.Timezone == nil || *prefs.Timezone != timezone {
prefs.Timezone = &timezone
_ = s.notificationRepo.UpdatePreferences(prefs)
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil {
log.Error().Err(err).Uint("user_id", userID).Str("timezone", timezone).
Msg("Failed to update user timezone in notification preferences")
}
}
}
@@ -430,6 +460,7 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string,
// === Response/Request Types ===
// TODO(hardening): Move to internal/dto/responses/notification.go
// NotificationResponse represents a notification in API response
type NotificationResponse struct {
ID uint `json:"id"`
@@ -473,6 +504,7 @@ func NewNotificationResponse(n *models.Notification) NotificationResponse {
return resp
}
// TODO(hardening): Move to internal/dto/responses/notification.go
// NotificationPreferencesResponse represents notification preferences
type NotificationPreferencesResponse struct {
TaskDueSoon bool `json:"task_due_soon"`
@@ -511,6 +543,7 @@ func NewNotificationPreferencesResponse(p *models.NotificationPreference) *Notif
}
}
// TODO(hardening): Move to internal/dto/requests/notification.go
// UpdatePreferencesRequest represents preferences update request
type UpdatePreferencesRequest struct {
TaskDueSoon *bool `json:"task_due_soon"`
@@ -532,6 +565,7 @@ type UpdatePreferencesRequest struct {
DailyDigestHour *int `json:"daily_digest_hour"`
}
// TODO(hardening): Move to internal/dto/responses/notification.go
// DeviceResponse represents a device in API response
type DeviceResponse struct {
ID uint `json:"id"`
@@ -569,6 +603,7 @@ func NewGCMDeviceResponse(d *models.GCMDevice) *DeviceResponse {
}
}
// TODO(hardening): Move to internal/dto/requests/notification.go
// RegisterDeviceRequest represents device registration request
type RegisterDeviceRequest struct {
Name string `json:"name"`

View File

@@ -39,6 +39,7 @@ func generateTrackingID() string {
// HasSentEmail checks if a specific email type has already been sent to a user
func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.OnboardingEmailType) bool {
// TODO(hardening): Replace with OnboardingEmailRepository.HasSentEmail()
var count int64
if err := s.db.Model(&models.OnboardingEmail{}).
Where("user_id = ? AND email_type = ?", userID, emailType).
@@ -51,6 +52,7 @@ func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.Onbo
// RecordEmailSent records that an email was sent to a user
func (s *OnboardingEmailService) RecordEmailSent(userID uint, emailType models.OnboardingEmailType, trackingID string) error {
// TODO(hardening): Replace with OnboardingEmailRepository.Create()
email := &models.OnboardingEmail{
UserID: userID,
EmailType: emailType,
@@ -66,6 +68,7 @@ func (s *OnboardingEmailService) RecordEmailSent(userID uint, emailType models.O
// RecordEmailOpened records that an email was opened based on tracking ID
func (s *OnboardingEmailService) RecordEmailOpened(trackingID string) error {
// TODO(hardening): Replace with OnboardingEmailRepository.MarkOpened()
now := time.Now().UTC()
result := s.db.Model(&models.OnboardingEmail{}).
Where("tracking_id = ? AND opened_at IS NULL", trackingID).
@@ -84,6 +87,7 @@ func (s *OnboardingEmailService) RecordEmailOpened(trackingID string) error {
// GetEmailHistory gets all onboarding emails for a specific user
func (s *OnboardingEmailService) GetEmailHistory(userID uint) ([]models.OnboardingEmail, error) {
// TODO(hardening): Replace with OnboardingEmailRepository.FindByUserID()
var emails []models.OnboardingEmail
if err := s.db.Where("user_id = ?", userID).Order("sent_at DESC").Find(&emails).Error; err != nil {
return nil, err
@@ -105,11 +109,13 @@ func (s *OnboardingEmailService) GetAllEmailHistory(page, pageSize int) ([]model
offset := (page - 1) * pageSize
// TODO(hardening): Replace with OnboardingEmailRepository.CountAll()
// Count total
if err := s.db.Model(&models.OnboardingEmail{}).Count(&total).Error; err != nil {
return nil, 0, err
}
// TODO(hardening): Replace with OnboardingEmailRepository.FindAllPaginated()
// Get paginated results with user info
if err := s.db.Preload("User").
Order("sent_at DESC").
@@ -126,6 +132,7 @@ func (s *OnboardingEmailService) GetAllEmailHistory(page, pageSize int) ([]model
func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error) {
stats := &OnboardingEmailStats{}
// TODO(hardening): Replace with OnboardingEmailRepository.GetStats()
// No residence email stats
var noResTotal, noResOpened int64
if err := s.db.Model(&models.OnboardingEmail{}).
@@ -159,6 +166,7 @@ func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error)
return stats, nil
}
// TODO(hardening): Move to internal/dto/responses/onboarding_email.go
// OnboardingEmailStats represents statistics about onboarding emails
type OnboardingEmailStats struct {
NoResidenceTotal int64 `json:"no_residence_total"`
@@ -173,6 +181,7 @@ func (s *OnboardingEmailService) UsersNeedingNoResidenceEmail() ([]models.User,
twoDaysAgo := time.Now().UTC().AddDate(0, 0, -2)
// TODO(hardening): Replace with OnboardingEmailRepository.FindUsersWithoutResidence()
// Find users who:
// 1. Are verified
// 2. Registered 2+ days ago
@@ -201,6 +210,7 @@ func (s *OnboardingEmailService) UsersNeedingNoTasksEmail() ([]models.User, erro
fiveDaysAgo := time.Now().UTC().AddDate(0, 0, -5)
// TODO(hardening): Replace with OnboardingEmailRepository.FindUsersWithoutTasks()
// Find users who:
// 1. Are verified
// 2. Have at least one residence
@@ -325,6 +335,7 @@ func (s *OnboardingEmailService) sendNoTasksEmail(user models.User) error {
// SendOnboardingEmailToUser manually sends an onboarding email to a specific user
// This is used by admin to force-send emails regardless of eligibility criteria
func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailType models.OnboardingEmailType) error {
// TODO(hardening): Replace with UserRepository.FindByID() (inject UserRepository)
// Load the user
var user models.User
if err := s.db.First(&user, userID).Error; err != nil {
@@ -362,6 +373,7 @@ func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailTyp
// If already sent before, delete the old record first to allow re-recording
// This allows admins to "resend" emails while still tracking them
if alreadySent {
// TODO(hardening): Replace with OnboardingEmailRepository.DeleteByUserAndType()
if err := s.db.Where("user_id = ? AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{}).Error; err != nil {
log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to delete old onboarding email record before resend")
}

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"time"
"github.com/jung-kurt/gofpdf"
"github.com/go-pdf/fpdf"
)
// PDFService handles PDF generation
@@ -18,7 +18,7 @@ func NewPDFService() *PDFService {
// GenerateTasksReportPDF generates a PDF report from task report data
func (s *PDFService) GenerateTasksReportPDF(report *TasksReportResponse) ([]byte, error) {
pdf := gofpdf.New("P", "mm", "A4", "")
pdf := fpdf.New("P", "mm", "A4", "")
pdf.SetMargins(15, 15, 15)
pdf.AddPage()

View File

@@ -133,14 +133,16 @@ func (s *ResidenceService) GetMyResidences(userID uint, now time.Time) (*respons
}
}
// Attach completion summaries (honeycomb grid data)
for i := range residenceResponses {
summary, err := s.taskRepo.GetCompletionSummary(residenceResponses[i].ID, now, 10)
if err != nil {
log.Warn().Err(err).Uint("residence_id", residenceResponses[i].ID).Msg("Failed to fetch completion summary")
continue
// P-01: Batch fetch completion summaries in 2 queries total instead of 2*N
summaries, err := s.taskRepo.GetBatchCompletionSummaries(residenceIDs, now, 10)
if err != nil {
log.Warn().Err(err).Msg("Failed to fetch batch completion summaries")
} else {
for i := range residenceResponses {
if summary, ok := summaries[residenceResponses[i].ID]; ok {
residenceResponses[i].CompletionSummary = summary
}
}
residenceResponses[i].CompletionSummary = summary
}
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
@@ -17,7 +18,8 @@ import (
// StorageService handles file uploads to local filesystem
type StorageService struct {
cfg *config.StorageConfig
cfg *config.StorageConfig
allowedTypes map[string]struct{} // P-12: Parsed once at init for O(1) lookups
}
// UploadResult contains information about an uploaded file
@@ -44,9 +46,18 @@ func NewStorageService(cfg *config.StorageConfig) (*StorageService, error) {
}
}
log.Info().Str("upload_dir", cfg.UploadDir).Msg("Storage service initialized")
// P-12: Parse AllowedTypes once at initialization for O(1) lookups
allowedTypes := make(map[string]struct{})
for _, t := range strings.Split(cfg.AllowedTypes, ",") {
trimmed := strings.TrimSpace(t)
if trimmed != "" {
allowedTypes[trimmed] = struct{}{}
}
}
return &StorageService{cfg: cfg}, nil
log.Info().Str("upload_dir", cfg.UploadDir).Int("allowed_types", len(allowedTypes)).Msg("Storage service initialized")
return &StorageService{cfg: cfg, allowedTypes: allowedTypes}, nil
}
// Upload saves a file to the local filesystem
@@ -56,17 +67,47 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
return nil, fmt.Errorf("file size %d exceeds maximum allowed %d bytes", file.Size, s.cfg.MaxFileSize)
}
// Get MIME type
mimeType := file.Header.Get("Content-Type")
if mimeType == "" {
mimeType = "application/octet-stream"
// Get claimed MIME type from header
claimedMimeType := file.Header.Get("Content-Type")
if claimedMimeType == "" {
claimedMimeType = "application/octet-stream"
}
// Validate MIME type
// S-09: Detect actual content type from file bytes to prevent disguised uploads
src, err := file.Open()
if err != nil {
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
}
defer src.Close()
// Read the first 512 bytes for content type detection
sniffBuf := make([]byte, 512)
n, err := src.Read(sniffBuf)
if err != nil && n == 0 {
return nil, fmt.Errorf("failed to read file for content type detection: %w", err)
}
detectedMimeType := http.DetectContentType(sniffBuf[:n])
// Validate that the detected type matches the claimed type (at the category level)
// Allow application/octet-stream from detection since DetectContentType may not
// recognize all valid types, but the claimed type must still be in our allowed list
if detectedMimeType != "application/octet-stream" && !s.mimeTypesCompatible(claimedMimeType, detectedMimeType) {
return nil, fmt.Errorf("file content type mismatch: claimed %s but detected %s", claimedMimeType, detectedMimeType)
}
// Use the claimed MIME type (which is more specific) if it's allowed
mimeType := claimedMimeType
// Validate MIME type against allowed list
if !s.isAllowedType(mimeType) {
return nil, fmt.Errorf("file type %s is not allowed", mimeType)
}
// Seek back to beginning after sniffing
if _, err := src.Seek(0, io.SeekStart); err != nil {
return nil, fmt.Errorf("failed to seek file: %w", err)
}
// Generate unique filename
ext := filepath.Ext(file.Filename)
if ext == "" {
@@ -83,15 +124,11 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
subdir = "completions"
}
// Full path
destPath := filepath.Join(s.cfg.UploadDir, subdir, newFilename)
// Open source file
src, err := file.Open()
// S-18: Sanitize path to prevent traversal attacks
destPath, err := SafeResolvePath(s.cfg.UploadDir, filepath.Join(subdir, newFilename))
if err != nil {
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
return nil, fmt.Errorf("invalid upload path: %w", err)
}
defer src.Close()
// Create destination file
dst, err := os.Create(destPath)
@@ -131,19 +168,11 @@ func (s *StorageService) Delete(fileURL string) error {
// Convert URL to file path
relativePath := strings.TrimPrefix(fileURL, s.cfg.BaseURL)
relativePath = strings.TrimPrefix(relativePath, "/")
fullPath := filepath.Join(s.cfg.UploadDir, relativePath)
// Security check: ensure path is within upload directory
absUploadDir, err := filepath.Abs(s.cfg.UploadDir)
// S-18: Use SafeResolvePath to prevent path traversal
fullPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath)
if err != nil {
return fmt.Errorf("failed to resolve upload directory: %w", err)
}
absFilePath, err := filepath.Abs(fullPath)
if err != nil {
return fmt.Errorf("failed to resolve file path: %w", err)
}
if !strings.HasPrefix(absFilePath, absUploadDir+string(filepath.Separator)) && absFilePath != absUploadDir {
return fmt.Errorf("invalid file path")
return fmt.Errorf("invalid file path: %w", err)
}
if err := os.Remove(fullPath); err != nil {
@@ -157,15 +186,23 @@ func (s *StorageService) Delete(fileURL string) error {
return nil
}
// isAllowedType checks if the MIME type is in the allowed list
// isAllowedType checks if the MIME type is in the allowed list.
// P-12: Uses the pre-parsed allowedTypes map for O(1) lookups instead of
// splitting the config string on every call.
func (s *StorageService) isAllowedType(mimeType string) bool {
allowed := strings.Split(s.cfg.AllowedTypes, ",")
for _, t := range allowed {
if strings.TrimSpace(t) == mimeType {
return true
}
_, ok := s.allowedTypes[mimeType]
return ok
}
// mimeTypesCompatible checks if the claimed and detected MIME types are compatible.
// Two MIME types are compatible if they share the same primary type (e.g., both "image/*").
func (s *StorageService) mimeTypesCompatible(claimed, detected string) bool {
claimedParts := strings.SplitN(claimed, "/", 2)
detectedParts := strings.SplitN(detected, "/", 2)
if len(claimedParts) < 1 || len(detectedParts) < 1 {
return false
}
return false
return claimedParts[0] == detectedParts[0]
}
// getExtensionFromMimeType returns a file extension for common MIME types
@@ -191,5 +228,12 @@ func (s *StorageService) GetUploadDir() string {
// NewStorageServiceForTest creates a StorageService without creating directories.
// This is intended only for unit tests that need a StorageService with a known config.
func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService {
return &StorageService{cfg: cfg}
allowedTypes := make(map[string]struct{})
for _, t := range strings.Split(cfg.AllowedTypes, ",") {
trimmed := strings.TrimSpace(t)
if trimmed != "" {
allowedTypes[trimmed] = struct{}{}
}
}
return &StorageService{cfg: cfg, allowedTypes: allowedTypes}
}

View File

@@ -3,10 +3,10 @@ package services
import (
"encoding/json"
"fmt"
"os"
"time"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"github.com/stripe/stripe-go/v81"
portalsession "github.com/stripe/stripe-go/v81/billingportal/session"
checkoutsession "github.com/stripe/stripe-go/v81/checkout/session"
@@ -34,7 +34,8 @@ func NewStripeService(
subscriptionRepo *repositories.SubscriptionRepository,
userRepo *repositories.UserRepository,
) *StripeService {
key := os.Getenv("STRIPE_SECRET_KEY")
// S-21: Use Viper config instead of os.Getenv for consistent configuration management
key := viper.GetString("STRIPE_SECRET_KEY")
if key == "" {
log.Warn().Msg("STRIPE_SECRET_KEY not set, Stripe integration will not work")
} else {
@@ -42,7 +43,7 @@ func NewStripeService(
log.Info().Msg("Stripe API key configured")
}
webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET")
webhookSecret := viper.GetString("STRIPE_WEBHOOK_SECRET")
if webhookSecret == "" {
log.Warn().Msg("STRIPE_WEBHOOK_SECRET not set, webhook verification will fail")
}

View File

@@ -202,18 +202,19 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
}
// getUserUsage calculates current usage for a user.
// P-10: Uses CountByOwner for properties count instead of loading all owned residences.
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)).
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) {
residences, err := s.residenceRepo.FindOwnedByUser(userID)
// P-10: Use CountByOwner for an efficient COUNT query instead of loading all records
propertiesCount, err := s.residenceRepo.CountByOwner(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
propertiesCount := int64(len(residences))
// Collect residence IDs for batch queries
residenceIDs := make([]uint, len(residences))
for i, r := range residences {
residenceIDs[i] = r.ID
// Still need residence IDs for batch counting tasks/contractors/documents
residenceIDs, err := s.residenceRepo.FindResidenceIDsByOwner(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
// Count tasks, contractors, and documents across all residences with single queries each

View File

@@ -130,7 +130,7 @@ func (s *TaskService) ListTasks(userID uint, daysThreshold int, now time.Time) (
return nil, apperrors.Internal(err)
}
resp := responses.NewKanbanBoardResponseForAll(board)
resp := responses.NewKanbanBoardResponseForAll(board, now)
// NOTE: Summary statistics are calculated client-side from kanban data
return &resp, nil
}
@@ -157,7 +157,7 @@ func (s *TaskService) GetTasksByResidence(residenceID, userID uint, daysThreshol
return nil, apperrors.Internal(err)
}
resp := responses.NewKanbanBoardResponse(board, residenceID)
resp := responses.NewKanbanBoardResponse(board, residenceID, now)
// NOTE: Summary statistics are calculated client-side from kanban data
return &resp, nil
}
@@ -601,8 +601,8 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
task.InProgress = false
}
// P1-5: Wrap completion creation and task update in a transaction.
// If either operation fails, both are rolled back to prevent orphaned completions.
// P1-5 + B-07: Wrap completion creation, task update, and image creation
// in a single transaction for atomicity. If any operation fails, all are rolled back.
txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error {
if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil {
return err
@@ -610,6 +610,18 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
if err := s.taskRepo.UpdateTx(tx, task); err != nil {
return err
}
// B-07: Create images inside the same transaction as completion
for _, imageURL := range req.ImageURLs {
if imageURL != "" {
img := &models.TaskCompletionImage{
CompletionID: completion.ID,
ImageURL: imageURL,
}
if err := tx.Create(img).Error; err != nil {
return fmt.Errorf("failed to create completion image: %w", err)
}
}
}
return nil
})
if txErr != nil {
@@ -621,19 +633,6 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
return nil, apperrors.Internal(txErr)
}
// Create images if provided
for _, imageURL := range req.ImageURLs {
if imageURL != "" {
img := &models.TaskCompletionImage{
CompletionID: completion.ID,
ImageURL: imageURL,
}
if err := s.taskRepo.CreateCompletionImage(img); err != nil {
log.Error().Err(err).Uint("completion_id", completion.ID).Msg("Failed to create completion image")
}
}
}
// Reload completion with user info and images
completion, err = s.taskRepo.FindCompletionByID(completion.ID)
if err != nil {
@@ -663,8 +662,10 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
}, nil
}
// QuickComplete creates a minimal task completion (for widget use)
// Returns only success/error, no response body
// QuickComplete creates a minimal task completion (for widget use).
// LE-01: The entire operation (completion creation + task update) is wrapped in a
// transaction for atomicity.
// Returns only success/error, no response body.
func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
// Get the task
task, err := s.taskRepo.FindByID(taskID)
@@ -697,10 +698,6 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
CompletedFromColumn: completedFromColumn,
}
if err := s.taskRepo.CreateCompletion(completion); err != nil {
return apperrors.Internal(err)
}
// Update next_due_date and in_progress based on frequency
// Determine interval days: Custom frequency uses task.CustomIntervalDays, otherwise use frequency.Days
// Note: Frequency is no longer preloaded for performance, so we load it separately if needed
@@ -729,7 +726,6 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
} else {
// Recurring task - calculate next due date from completion date + interval
nextDue := completedAt.AddDate(0, 0, *quickIntervalDays)
// frequencyName was already set when loading frequency above
log.Info().
Uint("task_id", task.ID).
Str("frequency_name", frequencyName).
@@ -742,12 +738,23 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
// Reset in_progress to false
task.InProgress = false
}
if err := s.taskRepo.Update(task); err != nil {
if errors.Is(err, repositories.ErrVersionConflict) {
// LE-01: Wrap completion creation and task update in a transaction for atomicity
txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error {
if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil {
return err
}
if err := s.taskRepo.UpdateTx(tx, task); err != nil {
return err
}
return nil
})
if txErr != nil {
if errors.Is(txErr, repositories.ErrVersionConflict) {
return apperrors.Conflict("error.version_conflict")
}
log.Error().Err(err).Uint("task_id", task.ID).Msg("Failed to update task after quick completion")
return apperrors.Internal(err) // Return error so caller knows the update failed
log.Error().Err(txErr).Uint("task_id", task.ID).Msg("Failed to create completion and update task in QuickComplete")
return apperrors.Internal(txErr)
}
log.Info().Uint("task_id", task.ID).Msg("QuickComplete: Task updated successfully")
@@ -813,8 +820,16 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio
// Send email notification (to everyone INCLUDING the person who completed it)
// Check user's email notification preferences first
if s.emailService != nil && user.Email != "" && s.notificationService != nil {
prefs, err := s.notificationService.GetPreferences(user.ID)
if err != nil || (prefs != nil && prefs.EmailTaskCompleted) {
prefs, prefsErr := s.notificationService.GetPreferences(user.ID)
// LE-06: Log fail-open behavior when preferences cannot be loaded
if prefsErr != nil {
log.Warn().
Err(prefsErr).
Uint("user_id", user.ID).
Uint("task_id", task.ID).
Msg("Failed to load notification preferences, falling back to sending email (fail-open)")
}
if prefsErr != nil || (prefs != nil && prefs.EmailTaskCompleted) {
// Send email if we couldn't get prefs (fail-open) or if email notifications are enabled
if err := s.emailService.SendTaskCompletedEmail(
user.Email,

View File

@@ -32,7 +32,8 @@ func (s *UserService) ListUsersInSharedResidences(userID uint) ([]responses.User
return nil, apperrors.Internal(err)
}
var result []responses.UserSummary
// F-23: Initialize as empty slice so JSON serialization produces [] instead of null
result := make([]responses.UserSummary, 0, len(users))
for _, u := range users {
result = append(result, responses.UserSummary{
ID: u.ID,
@@ -72,7 +73,8 @@ func (s *UserService) ListProfilesInSharedResidences(userID uint) ([]responses.U
return nil, apperrors.Internal(err)
}
var result []responses.UserProfileSummary
// F-23: Initialize as empty slice so JSON serialization produces [] instead of null
result := make([]responses.UserProfileSummary, 0, len(profiles))
for _, p := range profiles {
result = append(result, responses.UserProfileSummary{
ID: p.ID,

View File

@@ -27,7 +27,10 @@ var testDB *gorm.DB
// testUserID is a user ID that exists in the database for foreign key constraints
var testUserID uint = 1
// TestMain sets up the database connection for all tests in this package
// TestMain sets up the database connection for all tests in this package.
// If the database is not available, testDB remains nil and individual tests
// will call t.Skip() instead of using os.Exit(0), which preserves proper
// test reporting and coverage output.
func TestMain(m *testing.M) {
dsn := os.Getenv("TEST_DATABASE_URL")
if dsn == "" {
@@ -39,15 +42,23 @@ func TestMain(m *testing.M) {
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
println("Skipping consistency integration tests: database not available")
// Explicitly nil out testDB; individual tests will t.Skip("Database not available")
testDB = nil
println("Consistency integration tests will be skipped: database not available")
println("Set TEST_DATABASE_URL to run these tests")
os.Exit(0)
os.Exit(m.Run())
}
sqlDB, err := testDB.DB()
if err != nil || sqlDB.Ping() != nil {
println("Failed to connect to database")
os.Exit(0)
if err != nil {
println("Failed to get underlying DB:", err.Error())
testDB = nil
os.Exit(m.Run())
}
if pingErr := sqlDB.Ping(); pingErr != nil {
println("Failed to ping database:", pingErr.Error())
testDB = nil
os.Exit(m.Run())
}
println("Database connected, running consistency tests...")

View File

@@ -17,7 +17,10 @@ import (
// testDB holds the database connection for integration tests
var testDB *gorm.DB
// TestMain sets up the database connection for all tests in this package
// TestMain sets up the database connection for all tests in this package.
// If the database is not available, testDB remains nil and individual tests
// will call t.Skip() instead of using os.Exit(0), which preserves proper
// test reporting and coverage output.
func TestMain(m *testing.M) {
// Get database URL from environment or use default
dsn := os.Getenv("TEST_DATABASE_URL")
@@ -30,22 +33,25 @@ func TestMain(m *testing.M) {
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
// Print message and skip tests if database is not available
println("Skipping scope integration tests: database not available")
// Explicitly nil out testDB; individual tests will t.Skip("Database not available")
testDB = nil
println("Scope integration tests will be skipped: database not available")
println("Set TEST_DATABASE_URL to run these tests")
println("Error:", err.Error())
os.Exit(0)
os.Exit(m.Run())
}
// Verify connection works
sqlDB, err := testDB.DB()
if err != nil {
println("Failed to get underlying DB:", err.Error())
os.Exit(0)
testDB = nil
os.Exit(m.Run())
}
if err := sqlDB.Ping(); err != nil {
println("Failed to ping database:", err.Error())
os.Exit(0)
testDB = nil
os.Exit(m.Run())
}
println("Database connected successfully, running integration tests...")
@@ -57,7 +63,9 @@ func TestMain(m *testing.M) {
&models.Residence{},
)
if err != nil {
os.Exit(1)
println("Failed to run migrations:", err.Error())
testDB = nil
os.Exit(m.Run())
}
// Run tests

View File

@@ -3,6 +3,7 @@ package testutil
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync"
@@ -71,7 +72,12 @@ func SetupTestDB(t *testing.T) *gorm.DB {
return db
}
// SetupTestRouter creates a test Echo router with the custom error handler
// SetupTestRouter creates a test Echo router with the custom error handler.
// Uses apperrors.HTTPErrorHandler which is the base error handler shared with
// production (router.customHTTPErrorHandler). Both handle AppError, ValidationErrors,
// and echo.HTTPError identically. Production additionally maps legacy service sentinel
// errors (e.g., services.ErrTaskNotFound) which are being migrated to AppError types.
// Tests exercise handlers that return AppError, so this handler covers all test scenarios.
func SetupTestRouter() *echo.Echo {
e := echo.New()
e.Validator = validator.NewCustomValidator()
@@ -79,17 +85,52 @@ func SetupTestRouter() *echo.Echo {
return e
}
// MakeRequest makes a test HTTP request and returns the response
// MakeRequest makes a test HTTP request and returns the response.
// Errors from JSON marshaling and HTTP request construction are checked and
// will panic if they occur, since these indicate programming errors in tests.
// Prefer MakeRequestT for new tests, which uses t.Fatal for better reporting.
func MakeRequest(router *echo.Echo, method, path string, body interface{}, token string) *httptest.ResponseRecorder {
var reqBody *bytes.Buffer
if body != nil {
jsonBody, _ := json.Marshal(body)
jsonBody, err := json.Marshal(body)
if err != nil {
panic(fmt.Sprintf("testutil.MakeRequest: failed to marshal request body: %v", err))
}
reqBody = bytes.NewBuffer(jsonBody)
} else {
reqBody = bytes.NewBuffer(nil)
}
req, _ := http.NewRequest(method, path, reqBody)
req, err := http.NewRequest(method, path, reqBody)
if err != nil {
panic(fmt.Sprintf("testutil.MakeRequest: failed to create HTTP request: %v", err))
}
req.Header.Set("Content-Type", "application/json")
if token != "" {
req.Header.Set("Authorization", "Token "+token)
}
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
return rec
}
// MakeRequestT is like MakeRequest but accepts a *testing.T for proper test
// failure reporting. Prefer this over MakeRequest in new tests.
func MakeRequestT(t *testing.T, router *echo.Echo, method, path string, body interface{}, token string) *httptest.ResponseRecorder {
t.Helper()
var reqBody *bytes.Buffer
if body != nil {
jsonBody, err := json.Marshal(body)
require.NoError(t, err, "failed to marshal request body")
reqBody = bytes.NewBuffer(jsonBody)
} else {
reqBody = bytes.NewBuffer(nil)
}
req, err := http.NewRequest(method, path, reqBody)
require.NoError(t, err, "failed to create HTTP request")
req.Header.Set("Content-Type", "application/json")
if token != "" {
req.Header.Set("Authorization", "Token "+token)
@@ -215,8 +256,12 @@ func CreateTestTask(t *testing.T, db *gorm.DB, residenceID, createdByID uint, ti
return task
}
// SeedLookupData seeds all lookup tables with test data
// SeedLookupData seeds all lookup tables with test data.
// All GORM create operations are checked for errors to prevent silent failures
// that could cause misleading test results.
func SeedLookupData(t *testing.T, db *gorm.DB) {
t.Helper()
// Residence types
residenceTypes := []models.ResidenceType{
{Name: "House"},
@@ -224,8 +269,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) {
{Name: "Condo"},
{Name: "Townhouse"},
}
for _, rt := range residenceTypes {
db.Create(&rt)
for i := range residenceTypes {
err := db.Create(&residenceTypes[i]).Error
require.NoError(t, err, "failed to seed residence type: %s", residenceTypes[i].Name)
}
// Task categories
@@ -235,8 +281,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) {
{Name: "HVAC", DisplayOrder: 3},
{Name: "General", DisplayOrder: 99},
}
for _, c := range categories {
db.Create(&c)
for i := range categories {
err := db.Create(&categories[i]).Error
require.NoError(t, err, "failed to seed task category: %s", categories[i].Name)
}
// Task priorities
@@ -246,8 +293,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) {
{Name: "High", Level: 3, DisplayOrder: 3},
{Name: "Urgent", Level: 4, DisplayOrder: 4},
}
for _, p := range priorities {
db.Create(&p)
for i := range priorities {
err := db.Create(&priorities[i]).Error
require.NoError(t, err, "failed to seed task priority: %s", priorities[i].Name)
}
// Task frequencies
@@ -258,8 +306,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) {
{Name: "Weekly", Days: &days7, DisplayOrder: 2},
{Name: "Monthly", Days: &days30, DisplayOrder: 3},
}
for _, f := range frequencies {
db.Create(&f)
for i := range frequencies {
err := db.Create(&frequencies[i]).Error
require.NoError(t, err, "failed to seed task frequency: %s", frequencies[i].Name)
}
// Contractor specialties
@@ -269,8 +318,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) {
{Name: "HVAC Technician"},
{Name: "Handyman"},
}
for _, s := range specialties {
db.Create(&s)
for i := range specialties {
err := db.Create(&specialties[i]).Error
require.NoError(t, err, "failed to seed contractor specialty: %s", specialties[i].Name)
}
}

View File

@@ -20,8 +20,6 @@ import (
// Task types
const (
TypeTaskReminder = "notification:task_reminder"
TypeOverdueReminder = "notification:overdue_reminder"
TypeSmartReminder = "notification:smart_reminder" // Frequency-aware reminders
TypeDailyDigest = "notification:daily_digest"
TypeSendEmail = "email:send"
@@ -36,6 +34,7 @@ type Handler struct {
taskRepo *repositories.TaskRepository
residenceRepo *repositories.ResidenceRepository
reminderRepo *repositories.ReminderRepository
notificationRepo *repositories.NotificationRepository
pushClient *push.Client
emailService *services.EmailService
notificationService *services.NotificationService
@@ -56,6 +55,7 @@ func NewHandler(db *gorm.DB, pushClient *push.Client, emailService *services.Ema
taskRepo: repositories.NewTaskRepository(db),
residenceRepo: repositories.NewResidenceRepository(db),
reminderRepo: repositories.NewReminderRepository(db),
notificationRepo: repositories.NewNotificationRepository(db),
pushClient: pushClient,
emailService: emailService,
notificationService: notificationService,
@@ -64,218 +64,6 @@ func NewHandler(db *gorm.DB, pushClient *push.Client, emailService *services.Ema
}
}
// TaskReminderData represents a task due soon for reminder notifications
type TaskReminderData struct {
TaskID uint
TaskTitle string
DueDate time.Time
UserID uint
UserEmail string
UserName string
ResidenceName string
}
// HandleTaskReminder processes task reminder notifications for tasks due today
// or tomorrow with actionable buttons.
//
// Flow: (1) select users whose notification hour matches the current UTC hour,
// (2) load their due-soon tasks, (3) send one notification per task.
// Send failures are logged per task and never abort the run.
func (h *Handler) HandleTaskReminder(ctx context.Context, task *asynq.Task) error {
	log.Info().Msg("Processing task reminder notifications...")

	now := time.Now().UTC()
	hour := now.Hour()
	defaultHour := h.config.Worker.TaskReminderHour
	log.Info().Int("current_hour", hour).Int("system_default_hour", defaultHour).Msg("Task reminder check")

	// Step 1: find the users who should be notified during THIS hour.
	// Each user is notified exactly once per day, at exactly one hour:
	// a custom hour when set, otherwise the system default. A user with a
	// custom hour is never also notified at the default hour, so there are
	// no duplicate sends.
	q := h.db.Model(&models.NotificationPreference{}).
		Select("user_id").
		Where("task_due_soon = true")
	if hour == defaultHour {
		// Default hour covers users with no custom hour (NULL) plus users
		// whose custom hour happens to equal the default.
		q = q.Where("task_due_soon_hour IS NULL OR task_due_soon_hour = ?", hour)
	} else {
		// Non-default hour: only users who explicitly chose this hour.
		// NULL users are excluded — they are handled at the default hour.
		q = q.Where("task_due_soon_hour = ?", hour)
	}

	var userIDs []uint
	if err := q.Pluck("user_id", &userIDs).Error; err != nil {
		log.Error().Err(err).Msg("Failed to query eligible users for task reminders")
		return err
	}

	// Nothing scheduled this hour — bail out early.
	if len(userIDs) == 0 {
		log.Debug().Int("hour", hour).Msg("No users scheduled for task reminder notifications this hour")
		return nil
	}
	log.Info().Int("eligible_users", len(userIDs)).Msg("Found users eligible for task reminders this hour")

	// Step 2: load tasks due today or tomorrow via the single-purpose
	// repository function. It shares the kanban scopes for consistency;
	// IncludeInProgress keeps started-but-due-soon tasks in scope.
	opts := repositories.TaskFilterOptions{
		UserIDs:            userIDs,
		IncludeInProgress:  true, // notifications should include in-progress tasks
		PreloadResidence:   true,
		PreloadCompletions: true,
	}
	// "Due soon" means due within 2 days (today and tomorrow).
	dueSoon, err := h.taskRepo.GetDueSoonTasks(now, 2, opts)
	if err != nil {
		log.Error().Err(err).Msg("Failed to query tasks due soon")
		return err
	}
	log.Info().Int("count", len(dueSoon)).Msg("Found tasks due today/tomorrow for eligible users")

	// Membership set for O(1) eligibility checks instead of a linear scan.
	eligible := make(map[uint]bool, len(userIDs))
	for _, id := range userIDs {
		eligible[id] = true
	}

	// Group tasks by recipient: the assignee when set, otherwise the
	// residence owner; tasks with neither are skipped.
	tasksByUser := make(map[uint][]models.Task)
	for _, item := range dueSoon {
		var recipient uint
		switch {
		case item.AssignedToID != nil:
			recipient = *item.AssignedToID
		case item.Residence.ID != 0:
			recipient = item.Residence.OwnerID
		default:
			continue
		}
		if eligible[recipient] {
			tasksByUser[recipient] = append(tasksByUser[recipient], item)
		}
	}

	// Step 3: send one task-specific notification per task (no limit).
	// Preferences were already filtered above, so no re-check is needed.
	for uid, list := range tasksByUser {
		for _, item := range list {
			if err := h.notificationService.CreateAndSendTaskNotification(ctx, uid, models.NotificationTaskDueSoon, &item); err != nil {
				log.Error().Err(err).Uint("user_id", uid).Uint("task_id", item.ID).Msg("Failed to send task reminder notification")
			}
		}
	}

	log.Info().Int("users_notified", len(tasksByUser)).Msg("Task reminder notifications completed")
	return nil
}
// HandleOverdueReminder processes overdue task notifications with actionable
// buttons.
//
// Flow: (1) select users whose overdue-notification hour matches the current
// UTC hour, (2) load their overdue tasks, (3) send one notification per task.
// Send failures are logged per task and never abort the run.
func (h *Handler) HandleOverdueReminder(ctx context.Context, task *asynq.Task) error {
	log.Info().Msg("Processing overdue task notifications...")

	now := time.Now().UTC()
	hour := now.Hour()
	defaultHour := h.config.Worker.OverdueReminderHour
	log.Info().Int("current_hour", hour).Int("system_default_hour", defaultHour).Msg("Overdue reminder check")

	// Step 1: find the users who should be notified during THIS hour.
	// Each user is notified exactly once per day, at exactly one hour:
	// a custom hour when set, otherwise the system default. A user with a
	// custom hour is never also notified at the default hour, so there are
	// no duplicate sends.
	q := h.db.Model(&models.NotificationPreference{}).
		Select("user_id").
		Where("task_overdue = true")
	if hour == defaultHour {
		// Default hour covers users with no custom hour (NULL) plus users
		// whose custom hour happens to equal the default.
		q = q.Where("task_overdue_hour IS NULL OR task_overdue_hour = ?", hour)
	} else {
		// Non-default hour: only users who explicitly chose this hour.
		// NULL users are excluded — they are handled at the default hour.
		q = q.Where("task_overdue_hour = ?", hour)
	}

	var userIDs []uint
	if err := q.Pluck("user_id", &userIDs).Error; err != nil {
		log.Error().Err(err).Msg("Failed to query eligible users for overdue reminders")
		return err
	}

	// Nothing scheduled this hour — bail out early.
	if len(userIDs) == 0 {
		log.Debug().Int("hour", hour).Msg("No users scheduled for overdue notifications this hour")
		return nil
	}
	log.Info().Int("eligible_users", len(userIDs)).Msg("Found users eligible for overdue reminders this hour")

	// Step 2: load overdue tasks via the single-purpose repository function.
	// It shares the kanban scopes for consistency; IncludeInProgress keeps
	// started-but-overdue tasks in scope.
	opts := repositories.TaskFilterOptions{
		UserIDs:            userIDs,
		IncludeInProgress:  true, // notifications should include in-progress tasks
		PreloadResidence:   true,
		PreloadCompletions: true,
	}
	overdue, err := h.taskRepo.GetOverdueTasks(now, opts)
	if err != nil {
		log.Error().Err(err).Msg("Failed to query overdue tasks")
		return err
	}
	log.Info().Int("count", len(overdue)).Msg("Found overdue tasks for eligible users")

	// Membership set for O(1) eligibility checks instead of a linear scan.
	eligible := make(map[uint]bool, len(userIDs))
	for _, id := range userIDs {
		eligible[id] = true
	}

	// Group tasks by recipient: the assignee when set, otherwise the
	// residence owner; tasks with neither are skipped.
	tasksByUser := make(map[uint][]models.Task)
	for _, item := range overdue {
		var recipient uint
		switch {
		case item.AssignedToID != nil:
			recipient = *item.AssignedToID
		case item.Residence.ID != 0:
			recipient = item.Residence.OwnerID
		default:
			continue
		}
		if eligible[recipient] {
			tasksByUser[recipient] = append(tasksByUser[recipient], item)
		}
	}

	// Step 3: send one task-specific notification per task (no limit).
	// Preferences were already filtered above, so no re-check is needed.
	for uid, list := range tasksByUser {
		for _, item := range list {
			if err := h.notificationService.CreateAndSendTaskNotification(ctx, uid, models.NotificationTaskOverdue, &item); err != nil {
				log.Error().Err(err).Uint("user_id", uid).Uint("task_id", item.ID).Msg("Failed to send overdue notification")
			}
		}
	}

	log.Info().Int("users_notified", len(tasksByUser)).Msg("Overdue task notifications completed")
	return nil
}
// HandleDailyDigest processes daily digest notifications with task statistics
func (h *Handler) HandleDailyDigest(ctx context.Context, task *asynq.Task) error {
log.Info().Msg("Processing daily digest notifications...")
@@ -328,8 +116,7 @@ func (h *Handler) HandleDailyDigest(ctx context.Context, task *asynq.Task) error
// Get user's timezone from notification preferences for accurate overdue calculation
// This ensures the daily digest matches what the user sees in the kanban UI
var userNow time.Time
var prefs models.NotificationPreference
if err := h.db.Where("user_id = ?", userID).First(&prefs).Error; err == nil && prefs.Timezone != nil {
if prefs, err := h.notificationRepo.FindPreferencesByUser(userID); err == nil && prefs.Timezone != nil {
if loc, err := time.LoadLocation(*prefs.Timezone); err == nil {
// Use start of day in user's timezone (matches kanban behavior)
userNowInTz := time.Now().In(loc)
@@ -481,22 +268,11 @@ func (h *Handler) sendPushToUser(ctx context.Context, userID uint, title, messag
return nil
}
// Get iOS device tokens
var iosTokens []string
err := h.db.Model(&models.APNSDevice{}).
Where("user_id = ? AND active = ?", userID, true).
Pluck("registration_id", &iosTokens).Error
// Get active device tokens via repository
iosTokens, androidTokens, err := h.notificationRepo.GetActiveTokensForUser(userID)
if err != nil {
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to get iOS tokens")
}
// Get Android device tokens
var androidTokens []string
err = h.db.Model(&models.GCMDevice{}).
Where("user_id = ? AND active = ?", userID, true).
Pluck("registration_id", &androidTokens).Error
if err != nil {
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to get Android tokens")
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to get device tokens")
return err
}
if len(iosTokens) == 0 && len(androidTokens) == 0 {
@@ -561,18 +337,17 @@ func (h *Handler) HandleOnboardingEmails(ctx context.Context, task *asynq.Task)
}
// Send no-residence emails (users without any residences after 2 days)
noResCount, err := h.onboardingService.CheckAndSendNoResidenceEmails()
if err != nil {
log.Error().Err(err).Msg("Failed to process no-residence onboarding emails")
// Continue to next type, don't return error
noResCount, noResErr := h.onboardingService.CheckAndSendNoResidenceEmails()
if noResErr != nil {
log.Error().Err(noResErr).Msg("Failed to process no-residence onboarding emails")
} else {
log.Info().Int("count", noResCount).Msg("Sent no-residence onboarding emails")
}
// Send no-tasks emails (users with residence but no tasks after 5 days)
noTasksCount, err := h.onboardingService.CheckAndSendNoTasksEmails()
if err != nil {
log.Error().Err(err).Msg("Failed to process no-tasks onboarding emails")
noTasksCount, noTasksErr := h.onboardingService.CheckAndSendNoTasksEmails()
if noTasksErr != nil {
log.Error().Err(noTasksErr).Msg("Failed to process no-tasks onboarding emails")
} else {
log.Info().Int("count", noTasksCount).Msg("Sent no-tasks onboarding emails")
}
@@ -582,6 +357,11 @@ func (h *Handler) HandleOnboardingEmails(ctx context.Context, task *asynq.Task)
Int("no_tasks_sent", noTasksCount).
Msg("Onboarding email processing completed")
// If all sub-tasks failed, return an error so Asynq retries
if noResErr != nil && noTasksErr != nil {
return fmt.Errorf("all onboarding email sub-tasks failed: no-residence: %w, no-tasks: %v", noResErr, noTasksErr)
}
return nil
}
@@ -603,7 +383,6 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err
log.Info().Msg("Processing smart task reminders...")
now := time.Now().UTC()
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
currentHour := now.Hour()
dueSoonDefault := h.config.Worker.TaskReminderHour
@@ -673,6 +452,22 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err
Int("want_overdue", len(userWantsOverdue)).
Msg("Found users eligible for reminders")
// Build per-user "today" using their timezone preference
// This ensures reminder stage calculations (overdue, due-soon) match the user's local date
userToday := make(map[uint]time.Time, len(allUserIDs))
utcToday := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
for _, uid := range allUserIDs {
prefs, err := h.notificationRepo.FindPreferencesByUser(uid)
if err == nil && prefs.Timezone != nil {
if loc, locErr := time.LoadLocation(*prefs.Timezone); locErr == nil {
userNowInTz := time.Now().In(loc)
userToday[uid] = time.Date(userNowInTz.Year(), userNowInTz.Month(), userNowInTz.Day(), 0, 0, 0, 0, loc)
continue
}
}
userToday[uid] = utcToday // Fallback to UTC
}
// Step 2: Single query to get ALL active tasks (both due-soon and overdue) for these users
opts := repositories.TaskFilterOptions{
UserIDs: allUserIDs,
@@ -734,7 +529,8 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err
frequencyDays = &days
}
// Determine which reminder stage applies today
// Determine which reminder stage applies today using the user's local date
today := userToday[userID]
stage := notifications.GetReminderStageForToday(effectiveDate, frequencyDays, today)
if stage == "" {
continue
@@ -800,6 +596,11 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err
notificationType = models.NotificationTaskDueSoon
}
// Log the reminder BEFORE sending so we have a record even if the send crashes
if _, err := h.reminderRepo.LogReminder(t.ID, c.userID, c.effectiveDate, c.reminderStage, nil); err != nil {
log.Error().Err(err).Uint("task_id", t.ID).Str("stage", c.stage).Msg("Failed to log reminder")
}
// Send notification
if err := h.notificationService.CreateAndSendTaskNotification(ctx, c.userID, notificationType, &t); err != nil {
log.Error().Err(err).
@@ -810,11 +611,6 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err
continue
}
// Log the reminder
if _, err := h.reminderRepo.LogReminder(t.ID, c.userID, c.effectiveDate, c.reminderStage, nil); err != nil {
log.Error().Err(err).Uint("task_id", t.ID).Str("stage", c.stage).Msg("Failed to log reminder")
}
if c.isOverdue {
overdueSent++
} else {

View File

@@ -1,27 +1,18 @@
package worker
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/hibiken/asynq"
"github.com/rs/zerolog/log"
)
// Task types
// Task types for email jobs
const (
TypeWelcomeEmail = "email:welcome"
TypeVerificationEmail = "email:verification"
TypePasswordResetEmail = "email:password_reset"
TypePasswordChangedEmail = "email:password_changed"
TypeTaskCompletionEmail = "email:task_completion"
TypeGeneratePDFReport = "pdf:generate_report"
TypeUpdateContractorRating = "contractor:update_rating"
TypeDailyNotifications = "notifications:daily"
TypeTaskReminders = "notifications:task_reminders"
TypeOverdueReminders = "notifications:overdue_reminders"
TypeWelcomeEmail = "email:welcome"
TypeVerificationEmail = "email:verification"
TypePasswordResetEmail = "email:password_reset"
TypePasswordChangedEmail = "email:password_changed"
)
// EmailPayload is the base payload for email tasks
@@ -146,94 +137,3 @@ func (c *TaskClient) EnqueuePasswordChangedEmail(to, firstName string) error {
log.Debug().Str("to", to).Msg("Password changed email task enqueued")
return nil
}
// WorkerServer manages the asynq worker server together with its
// periodic-task scheduler; both are started and shut down as a unit.
type WorkerServer struct {
	server    *asynq.Server    // consumes and processes queued tasks
	scheduler *asynq.Scheduler // enqueues periodic tasks on cron schedules
}
// NewWorkerServer creates a new worker server backed by the Redis instance at
// redisAddr, processing up to concurrency tasks at once.
//
// Queue weights (critical:6, default:3, low:1) give higher-priority queues
// proportionally more processing slots. Failed tasks are logged with their
// type and payload via the configured ErrorHandler. The scheduler for
// periodic tasks runs on UTC time.
func NewWorkerServer(redisAddr string, concurrency int) *WorkerServer {
	srv := asynq.NewServer(
		asynq.RedisClientOpt{Addr: redisAddr},
		asynq.Config{
			Concurrency: concurrency,
			Queues: map[string]int{
				"critical": 6,
				"default":  3,
				"low":      1,
			},
			ErrorHandler: asynq.ErrorHandlerFunc(func(ctx context.Context, task *asynq.Task, err error) {
				log.Error().
					Err(err).
					Str("type", task.Type()).
					Bytes("payload", task.Payload()).
					Msg("Task failed")
			}),
		},
	)

	// Use time.UTC directly instead of time.LoadLocation("UTC") with a
	// discarded error — it cannot fail and avoids the ignored-error pattern.
	scheduler := asynq.NewScheduler(
		asynq.RedisClientOpt{Addr: redisAddr},
		&asynq.SchedulerOpts{
			Location: time.UTC,
		},
	)

	return &WorkerServer{
		server:    srv,
		scheduler: scheduler,
	}
}
// RegisterHandlers registers task handlers on the given mux.
//
// Intentionally a no-op: handlers will be registered by the main worker
// process, which owns the handler dependencies.
func (w *WorkerServer) RegisterHandlers(mux *asynq.ServeMux) {
	// Handlers will be registered by the main worker process
}
// RegisterScheduledTasks registers periodic tasks with the scheduler.
//
// Each entry pairs a cron spec (interpreted in the scheduler's UTC location)
// with a task type; registration stops at the first failure.
func (w *WorkerServer) RegisterScheduledTasks() error {
	jobs := []struct {
		spec     string // cron expression, UTC
		taskType string // asynq task type to enqueue
		label    string // human-readable name used in error messages
	}{
		{"0 20 * * *", TypeTaskReminders, "task reminders"},           // 8:00 PM UTC daily
		{"0 9 * * *", TypeOverdueReminders, "overdue reminders"},      // 9:00 AM UTC daily
		{"0 11 * * *", TypeDailyNotifications, "daily notifications"}, // 11:00 AM UTC daily
	}

	for _, job := range jobs {
		if _, err := w.scheduler.Register(job.spec, asynq.NewTask(job.taskType, nil)); err != nil {
			return fmt.Errorf("failed to register %s: %w", job.label, err)
		}
	}
	return nil
}
// Start starts the worker server and scheduler.
//
// The scheduler is started first so periodic tasks are enqueued as soon as
// the server begins consuming. If the server fails to start, the scheduler
// is shut down again so a half-started WorkerServer does not keep enqueuing
// tasks with nothing to process them.
func (w *WorkerServer) Start(mux *asynq.ServeMux) error {
	// Start scheduler
	if err := w.scheduler.Start(); err != nil {
		return fmt.Errorf("failed to start scheduler: %w", err)
	}

	// Start server
	if err := w.server.Start(mux); err != nil {
		// Roll back: don't leak a running scheduler when the server never came up.
		w.scheduler.Shutdown()
		return fmt.Errorf("failed to start worker server: %w", err)
	}
	return nil
}
// Shutdown gracefully shuts down the worker server.
//
// The scheduler is stopped before the server so no new periodic tasks are
// enqueued while the server drains its in-flight work.
func (w *WorkerServer) Shutdown() {
	w.scheduler.Shutdown()
	w.server.Shutdown()
}

View File

@@ -2,12 +2,16 @@ package utils
import (
"io"
"net/http"
"os"
"runtime/debug"
"time"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/treytartt/honeydue-api/internal/dto/responses"
)
// InitLogger initializes the zerolog logger
@@ -113,14 +117,17 @@ func EchoRecovery() echo.MiddlewareFunc {
return func(c echo.Context) error {
defer func() {
if err := recover(); err != nil {
// F-14: Include full stack trace for debugging
log.Error().
Interface("error", err).
Str("path", c.Request().URL.Path).
Str("method", c.Request().Method).
Str("stack", string(debug.Stack())).
Msg("Panic recovered")
c.JSON(500, map[string]interface{}{
"error": "Internal server error",
// F-15: Use the project's standard ErrorResponse struct
c.JSON(http.StatusInternalServerError, responses.ErrorResponse{
Error: "Internal server error",
})
}
}()