Fix 113 hardening issues across entire Go backend
Security: - Replace all binding: tags with validate: + c.Validate() in admin handlers - Add rate limiting to auth endpoints (login, register, password reset) - Add security headers (HSTS, XSS protection, nosniff, frame options) - Wire Google Pub/Sub token verification into webhook handler - Replace ParseUnverified with proper OIDC/JWKS key verification - Verify inner Apple JWS signatures in webhook handler - Add io.LimitReader (1MB) to all webhook body reads - Add ownership verification to file deletion - Move hardcoded admin credentials to env vars - Add uniqueIndex to User.Email - Hide ConfirmationCode from JSON serialization - Mask confirmation codes in admin responses - Use http.DetectContentType for upload validation - Fix path traversal in storage service - Replace os.Getenv with Viper in stripe service - Sanitize Redis URLs before logging - Separate DEBUG_FIXED_CODES from DEBUG flag - Reject weak SECRET_KEY in production - Add host check on /_next/* proxy routes - Use explicit localhost CORS origins in debug mode - Replace err.Error() with generic messages in all admin error responses Critical fixes: - Rewrite FCM to HTTP v1 API with OAuth 2.0 service account auth - Fix user_customuser -> auth_user table names in raw SQL - Fix dashboard verified query to use UserProfile model - Add escapeLikeWildcards() to prevent SQL wildcard injection Bug fixes: - Add bounds checks for days/expiring_soon query params (1-3650) - Add receipt_data/transaction_id empty-check to RestoreSubscription - Change Active bool -> *bool in device handler - Check all unchecked GORM/FindByIDWithProfile errors - Add validation for notification hour fields (0-23) - Add max=10000 validation on task description updates Transactions & data integrity: - Wrap registration flow in transaction - Wrap QuickComplete in transaction - Move image creation inside completion transaction - Wrap SetSpecialties in transaction - Wrap GetOrCreateToken in transaction - Wrap completion+image deletion in transaction Performance: - Batch completion summaries (2 queries vs 2N) - Reuse single http.Client in IAP validation - Cache dashboard counts (30s TTL) - Batch COUNT queries in admin user list - Add Limit(500) to document queries - Add reminder_stage+due_date filters to reminder queries - Parse AllowedTypes once at init - In-memory user cache in auth middleware (30s TTL) - Timezone change detection cache - Optimize P95 with per-endpoint sorted buffers - Replace crypto/md5 with hash/fnv for ETags Code quality: - Add sync.Once to all monitoring Stop()/Close() methods - Replace 8 fmt.Printf with zerolog in auth service - Log previously discarded errors - Standardize delete response shapes - Route hardcoded English through i18n - Remove FileURL from DocumentResponse (keep MediaURL only) - Thread user timezone through kanban board responses - Initialize empty slices to prevent null JSON - Extract shared field map for task Update/UpdateTx - Delete unused SoftDeleteModel, min(), formatCron, legacy handlers Worker & jobs: - Wire Asynq email infrastructure into worker - Register HandleReminderLogCleanup with daily 3AM cron - Use per-user timezone in HandleSmartReminder - Replace direct DB queries with repository calls - Delete legacy reminder handlers (~200 lines) - Delete unused task type constants Dependencies: - Replace archived jung-kurt/gofpdf with go-pdf/fpdf - Replace unmaintained gomail.v2 with wneessen/go-mail - Add TODO for Echo jwt v3 transitive dep removal Test infrastructure: - Fix MakeRequest/SeedLookupData error handling - Replace os.Exit(0) with t.Skip() in scope/consistency tests - Add 11 new FCM v1 tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -47,7 +47,7 @@ func main() {
|
|||||||
Int("db_port", cfg.Database.Port).
|
Int("db_port", cfg.Database.Port).
|
||||||
Str("db_name", cfg.Database.Database).
|
Str("db_name", cfg.Database.Database).
|
||||||
Str("db_user", cfg.Database.User).
|
Str("db_user", cfg.Database.User).
|
||||||
Str("redis_url", cfg.Redis.URL).
|
Str("redis_url", config.MaskURLCredentials(cfg.Redis.URL)).
|
||||||
Msg("Starting HoneyDue API server")
|
Msg("Starting HoneyDue API server")
|
||||||
|
|
||||||
// Connect to database (retry with backoff)
|
// Connect to database (retry with backoff)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -115,12 +114,6 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check worker kill switch
|
|
||||||
if !cfg.Features.WorkerEnabled {
|
|
||||||
log.Warn().Msg("Worker disabled by FEATURE_WORKER_ENABLED=false, exiting")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create Asynq server
|
// Create Asynq server
|
||||||
srv := asynq.NewServer(
|
srv := asynq.NewServer(
|
||||||
redisOpt,
|
redisOpt,
|
||||||
@@ -151,6 +144,13 @@ func main() {
|
|||||||
mux.HandleFunc(jobs.TypeSendEmail, jobHandler.HandleSendEmail)
|
mux.HandleFunc(jobs.TypeSendEmail, jobHandler.HandleSendEmail)
|
||||||
mux.HandleFunc(jobs.TypeSendPush, jobHandler.HandleSendPush)
|
mux.HandleFunc(jobs.TypeSendPush, jobHandler.HandleSendPush)
|
||||||
mux.HandleFunc(jobs.TypeOnboardingEmails, jobHandler.HandleOnboardingEmails)
|
mux.HandleFunc(jobs.TypeOnboardingEmails, jobHandler.HandleOnboardingEmails)
|
||||||
|
mux.HandleFunc(jobs.TypeReminderLogCleanup, jobHandler.HandleReminderLogCleanup)
|
||||||
|
|
||||||
|
// Register email job handlers (welcome, verification, password reset, password changed)
|
||||||
|
if emailService != nil {
|
||||||
|
emailJobHandler := jobs.NewEmailJobHandler(emailService)
|
||||||
|
emailJobHandler.RegisterHandlers(mux)
|
||||||
|
}
|
||||||
|
|
||||||
// Start scheduler for periodic tasks
|
// Start scheduler for periodic tasks
|
||||||
scheduler := asynq.NewScheduler(redisOpt, nil)
|
scheduler := asynq.NewScheduler(redisOpt, nil)
|
||||||
@@ -177,6 +177,13 @@ func main() {
|
|||||||
}
|
}
|
||||||
log.Info().Str("cron", "0 10 * * *").Msg("Registered onboarding emails job (runs daily at 10:00 AM UTC)")
|
log.Info().Str("cron", "0 10 * * *").Msg("Registered onboarding emails job (runs daily at 10:00 AM UTC)")
|
||||||
|
|
||||||
|
// Schedule reminder log cleanup (runs daily at 3:00 AM UTC)
|
||||||
|
// Removes reminder logs older than 90 days to prevent table bloat
|
||||||
|
if _, err := scheduler.Register("0 3 * * *", asynq.NewTask(jobs.TypeReminderLogCleanup, nil)); err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("Failed to register reminder log cleanup job")
|
||||||
|
}
|
||||||
|
log.Info().Str("cron", "0 3 * * *").Msg("Registered reminder log cleanup job (runs daily at 3:00 AM UTC)")
|
||||||
|
|
||||||
// Handle graceful shutdown
|
// Handle graceful shutdown
|
||||||
quit := make(chan os.Signal, 1)
|
quit := make(chan os.Signal, 1)
|
||||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
@@ -205,8 +212,3 @@ func main() {
|
|||||||
|
|
||||||
log.Info().Msg("Worker stopped")
|
log.Info().Msg("Worker stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatCron creates a cron expression for a specific hour (runs at minute 0)
|
|
||||||
func formatCron(hour int) string {
|
|
||||||
return fmt.Sprintf("0 %02d * * *", hour)
|
|
||||||
}
|
|
||||||
|
|||||||
13
go.mod
13
go.mod
@@ -3,12 +3,12 @@ module github.com/treytartt/honeydue-api
|
|||||||
go 1.24.0
|
go 1.24.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/go-pdf/fpdf v0.9.0
|
||||||
github.com/go-playground/validator/v10 v10.23.0
|
github.com/go-playground/validator/v10 v10.23.0
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/hibiken/asynq v0.25.1
|
github.com/hibiken/asynq v0.25.1
|
||||||
github.com/jung-kurt/gofpdf v1.16.2
|
|
||||||
github.com/labstack/echo/v4 v4.11.4
|
github.com/labstack/echo/v4 v4.11.4
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.6.0
|
github.com/nicksnyder/go-i18n/v2 v2.6.0
|
||||||
github.com/redis/go-redis/v9 v9.17.1
|
github.com/redis/go-redis/v9 v9.17.1
|
||||||
@@ -18,11 +18,14 @@ require (
|
|||||||
github.com/sideshow/apns2 v0.25.0
|
github.com/sideshow/apns2 v0.25.0
|
||||||
github.com/spf13/viper v1.20.1
|
github.com/spf13/viper v1.20.1
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
|
github.com/stripe/stripe-go/v81 v81.4.0
|
||||||
|
github.com/wneessen/go-mail v0.7.2
|
||||||
golang.org/x/crypto v0.45.0
|
golang.org/x/crypto v0.45.0
|
||||||
golang.org/x/oauth2 v0.34.0
|
golang.org/x/oauth2 v0.34.0
|
||||||
golang.org/x/text v0.31.0
|
golang.org/x/text v0.31.0
|
||||||
|
golang.org/x/time v0.14.0
|
||||||
google.golang.org/api v0.257.0
|
google.golang.org/api v0.257.0
|
||||||
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gorm.io/driver/postgres v1.6.0
|
gorm.io/driver/postgres v1.6.0
|
||||||
gorm.io/driver/sqlite v1.6.0
|
gorm.io/driver/sqlite v1.6.0
|
||||||
gorm.io/gorm v1.31.1
|
gorm.io/gorm v1.31.1
|
||||||
@@ -44,7 +47,7 @@ require (
|
|||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect; TODO(S-19): Pulled by echo/v4 middleware — upgrade Echo to v4.12+ which removes built-in JWT middleware (uses echo-jwt/v4 with jwt/v5 instead), eliminating this vulnerable transitive dep
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
|
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
|
||||||
github.com/google/s2a-go v0.1.9 // indirect
|
github.com/google/s2a-go v0.1.9 // indirect
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
|
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
|
||||||
@@ -71,7 +74,6 @@ require (
|
|||||||
github.com/spf13/afero v1.14.0 // indirect
|
github.com/spf13/afero v1.14.0 // indirect
|
||||||
github.com/spf13/cast v1.10.0 // indirect
|
github.com/spf13/cast v1.10.0 // indirect
|
||||||
github.com/spf13/pflag v1.0.10 // indirect
|
github.com/spf13/pflag v1.0.10 // indirect
|
||||||
github.com/stripe/stripe-go/v81 v81.4.0 // indirect
|
|
||||||
github.com/subosito/gotenv v1.6.0 // indirect
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||||
@@ -86,10 +88,7 @@ require (
|
|||||||
golang.org/x/net v0.47.0 // indirect
|
golang.org/x/net v0.47.0 // indirect
|
||||||
golang.org/x/sync v0.18.0 // indirect
|
golang.org/x/sync v0.18.0 // indirect
|
||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/time v0.14.0 // indirect
|
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect
|
||||||
google.golang.org/grpc v1.77.0 // indirect
|
google.golang.org/grpc v1.77.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.10 // indirect
|
google.golang.org/protobuf v1.36.10 // indirect
|
||||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
|
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
18
go.sum
18
go.sum
@@ -8,7 +8,6 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg
|
|||||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/units v0.0.0-20201120081800-1786d5ef83d4/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE=
|
github.com/alecthomas/units v0.0.0-20201120081800-1786d5ef83d4/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE=
|
||||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||||
@@ -36,6 +35,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
|||||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
|
github.com/go-pdf/fpdf v0.9.0 h1:PPvSaUuo1iMi9KkaAn90NuKi+P4gwMedWPHhj8YlJQw=
|
||||||
|
github.com/go-pdf/fpdf v0.9.0/go.mod h1:oO8N111TkmKb9D7VvWGLvLJlaZUQVPM+6V42pp3iV4Y=
|
||||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
@@ -83,9 +84,6 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
|
|||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
|
||||||
github.com/jung-kurt/gofpdf v1.16.2 h1:jgbatWHfRlPYiK85qgevsZTHviWXKwB1TTiKdz5PtRc=
|
|
||||||
github.com/jung-kurt/gofpdf v1.16.2/go.mod h1:1hl7y57EsiPAkLbOwzpzqgx1A30nQCk/YmFV8S2vmK0=
|
|
||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
@@ -110,8 +108,6 @@ github.com/nicksnyder/go-i18n/v2 v2.6.0 h1:C/m2NNWNiTB6SK4Ao8df5EWm3JETSTIGNXBpM
|
|||||||
github.com/nicksnyder/go-i18n/v2 v2.6.0/go.mod h1:88sRqr0C6OPyJn0/KRNaEz1uWorjxIKP7rUUcvycecE=
|
github.com/nicksnyder/go-i18n/v2 v2.6.0/go.mod h1:88sRqr0C6OPyJn0/KRNaEz1uWorjxIKP7rUUcvycecE=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
github.com/phpdave11/gofpdi v1.0.7/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
@@ -126,7 +122,6 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
|
|||||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||||
github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
|
|
||||||
github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k=
|
github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k=
|
||||||
github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk=
|
github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk=
|
||||||
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
|
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
|
||||||
@@ -150,7 +145,6 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A
|
|||||||
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
|
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
|
||||||
github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
|
github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
@@ -168,6 +162,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
|
|||||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
|
github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
|
||||||
github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||||
|
github.com/wneessen/go-mail v0.7.2 h1:xxPnhZ6IZLSgxShebmZ6DPKh1b6OJcoHfzy7UjOkzS8=
|
||||||
|
github.com/wneessen/go-mail v0.7.2/go.mod h1:+TkW6QP3EVkgTEqHtVmnAE/1MRhmzb8Y9/W3pweuS+k=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||||
@@ -189,7 +185,6 @@ go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
|||||||
golang.org/x/crypto v0.0.0-20170512130425-ab89591268e0/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
golang.org/x/crypto v0.0.0-20170512130425-ab89591268e0/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||||
golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
|
||||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
golang.org/x/net v0.0.0-20220403103023-749bd193bc2b/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
golang.org/x/net v0.0.0-20220403103023-749bd193bc2b/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||||
@@ -213,7 +208,6 @@ golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
|||||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
|
||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||||
@@ -237,13 +231,9 @@ google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHh
|
|||||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk=
|
|
||||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE=
|
|
||||||
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw=
|
|
||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import "github.com/treytartt/honeydue-api/internal/middleware"
|
import (
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
// PaginationParams holds pagination query parameters
|
// PaginationParams holds pagination query parameters
|
||||||
type PaginationParams struct {
|
type PaginationParams struct {
|
||||||
@@ -115,9 +118,9 @@ type UpdateResidenceRequest struct {
|
|||||||
YearBuilt *int `json:"year_built"`
|
YearBuilt *int `json:"year_built"`
|
||||||
Description *string `json:"description"`
|
Description *string `json:"description"`
|
||||||
PurchaseDate *string `json:"purchase_date"`
|
PurchaseDate *string `json:"purchase_date"`
|
||||||
PurchasePrice *float64 `json:"purchase_price"`
|
PurchasePrice *decimal.Decimal `json:"purchase_price"`
|
||||||
IsActive *bool `json:"is_active"`
|
IsActive *bool `json:"is_active"`
|
||||||
IsPrimary *bool `json:"is_primary"`
|
IsPrimary *bool `json:"is_primary"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TaskFilters holds task-specific filter parameters
|
// TaskFilters holds task-specific filter parameters
|
||||||
@@ -144,8 +147,8 @@ type UpdateTaskRequest struct {
|
|||||||
InProgress *bool `json:"in_progress"`
|
InProgress *bool `json:"in_progress"`
|
||||||
DueDate *string `json:"due_date"`
|
DueDate *string `json:"due_date"`
|
||||||
NextDueDate *string `json:"next_due_date"`
|
NextDueDate *string `json:"next_due_date"`
|
||||||
EstimatedCost *float64 `json:"estimated_cost"`
|
EstimatedCost *decimal.Decimal `json:"estimated_cost"`
|
||||||
ActualCost *float64 `json:"actual_cost"`
|
ActualCost *decimal.Decimal `json:"actual_cost"`
|
||||||
ContractorID *uint `json:"contractor_id"`
|
ContractorID *uint `json:"contractor_id"`
|
||||||
ParentTaskID *uint `json:"parent_task_id"`
|
ParentTaskID *uint `json:"parent_task_id"`
|
||||||
IsCancelled *bool `json:"is_cancelled"`
|
IsCancelled *bool `json:"is_cancelled"`
|
||||||
@@ -201,8 +204,8 @@ type UpdateDocumentRequest struct {
|
|||||||
MimeType *string `json:"mime_type" validate:"omitempty,max=100"`
|
MimeType *string `json:"mime_type" validate:"omitempty,max=100"`
|
||||||
PurchaseDate *string `json:"purchase_date"`
|
PurchaseDate *string `json:"purchase_date"`
|
||||||
ExpiryDate *string `json:"expiry_date"`
|
ExpiryDate *string `json:"expiry_date"`
|
||||||
PurchasePrice *float64 `json:"purchase_price"`
|
PurchasePrice *decimal.Decimal `json:"purchase_price"`
|
||||||
Vendor *string `json:"vendor" validate:"omitempty,max=200"`
|
Vendor *string `json:"vendor" validate:"omitempty,max=200"`
|
||||||
SerialNumber *string `json:"serial_number" validate:"omitempty,max=100"`
|
SerialNumber *string `json:"serial_number" validate:"omitempty,max=100"`
|
||||||
ModelNumber *string `json:"model_number" validate:"omitempty,max=100"`
|
ModelNumber *string `json:"model_number" validate:"omitempty,max=100"`
|
||||||
Provider *string `json:"provider" validate:"omitempty,max=200"`
|
Provider *string `json:"provider" validate:"omitempty,max=200"`
|
||||||
@@ -292,9 +295,9 @@ type CreateTaskRequest struct {
|
|||||||
FrequencyID *uint `json:"frequency_id"`
|
FrequencyID *uint `json:"frequency_id"`
|
||||||
InProgress bool `json:"in_progress"`
|
InProgress bool `json:"in_progress"`
|
||||||
AssignedToID *uint `json:"assigned_to_id"`
|
AssignedToID *uint `json:"assigned_to_id"`
|
||||||
DueDate *string `json:"due_date"`
|
DueDate *string `json:"due_date"`
|
||||||
EstimatedCost *float64 `json:"estimated_cost"`
|
EstimatedCost *decimal.Decimal `json:"estimated_cost"`
|
||||||
ContractorID *uint `json:"contractor_id"`
|
ContractorID *uint `json:"contractor_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateContractorRequest for creating a new contractor
|
// CreateContractorRequest for creating a new contractor
|
||||||
@@ -328,8 +331,8 @@ type CreateDocumentRequest struct {
|
|||||||
MimeType string `json:"mime_type" validate:"max=100"`
|
MimeType string `json:"mime_type" validate:"max=100"`
|
||||||
PurchaseDate *string `json:"purchase_date"`
|
PurchaseDate *string `json:"purchase_date"`
|
||||||
ExpiryDate *string `json:"expiry_date"`
|
ExpiryDate *string `json:"expiry_date"`
|
||||||
PurchasePrice *float64 `json:"purchase_price"`
|
PurchasePrice *decimal.Decimal `json:"purchase_price"`
|
||||||
Vendor string `json:"vendor" validate:"max=200"`
|
Vendor string `json:"vendor" validate:"max=200"`
|
||||||
SerialNumber string `json:"serial_number" validate:"max=100"`
|
SerialNumber string `json:"serial_number" validate:"max=100"`
|
||||||
ModelNumber string `json:"model_number" validate:"max=100"`
|
ModelNumber string `json:"model_number" validate:"max=100"`
|
||||||
TaskID *uint `json:"task_id"`
|
TaskID *uint `json:"task_id"`
|
||||||
|
|||||||
@@ -33,21 +33,21 @@ type AdminUserFilters struct {
|
|||||||
|
|
||||||
// CreateAdminUserRequest for creating a new admin user
|
// CreateAdminUserRequest for creating a new admin user
|
||||||
type CreateAdminUserRequest struct {
|
type CreateAdminUserRequest struct {
|
||||||
Email string `json:"email" binding:"required,email"`
|
Email string `json:"email" validate:"required,email"`
|
||||||
Password string `json:"password" binding:"required,min=8"`
|
Password string `json:"password" validate:"required,min=8"`
|
||||||
FirstName string `json:"first_name" binding:"max=100"`
|
FirstName string `json:"first_name" validate:"max=100"`
|
||||||
LastName string `json:"last_name" binding:"max=100"`
|
LastName string `json:"last_name" validate:"max=100"`
|
||||||
Role string `json:"role" binding:"omitempty,oneof=admin super_admin"`
|
Role string `json:"role" validate:"omitempty,oneof=admin super_admin"`
|
||||||
IsActive *bool `json:"is_active"`
|
IsActive *bool `json:"is_active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAdminUserRequest for updating an admin user
|
// UpdateAdminUserRequest for updating an admin user
|
||||||
type UpdateAdminUserRequest struct {
|
type UpdateAdminUserRequest struct {
|
||||||
Email *string `json:"email" binding:"omitempty,email"`
|
Email *string `json:"email" validate:"omitempty,email"`
|
||||||
Password *string `json:"password" binding:"omitempty,min=8"`
|
Password *string `json:"password" validate:"omitempty,min=8"`
|
||||||
FirstName *string `json:"first_name" binding:"omitempty,max=100"`
|
FirstName *string `json:"first_name" validate:"omitempty,max=100"`
|
||||||
LastName *string `json:"last_name" binding:"omitempty,max=100"`
|
LastName *string `json:"last_name" validate:"omitempty,max=100"`
|
||||||
Role *string `json:"role" binding:"omitempty,oneof=admin super_admin"`
|
Role *string `json:"role" validate:"omitempty,oneof=admin super_admin"`
|
||||||
IsActive *bool `json:"is_active"`
|
IsActive *bool `json:"is_active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ type UpdateAdminUserRequest struct {
|
|||||||
func (h *AdminUserManagementHandler) List(c echo.Context) error {
|
func (h *AdminUserManagementHandler) List(c echo.Context) error {
|
||||||
var filters AdminUserFilters
|
var filters AdminUserFilters
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var adminUsers []models.AdminUser
|
var adminUsers []models.AdminUser
|
||||||
@@ -134,7 +134,10 @@ func (h *AdminUserManagementHandler) Create(c echo.Context) error {
|
|||||||
|
|
||||||
var req CreateAdminUserRequest
|
var req CreateAdminUserRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if email already exists
|
// Check if email already exists
|
||||||
@@ -199,7 +202,10 @@ func (h *AdminUserManagementHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req UpdateAdminUserRequest
|
var req UpdateAdminUserRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Email != nil {
|
if req.Email != nil {
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ type UpdateAppleSocialAuthRequest struct {
|
|||||||
func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error {
|
func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var entries []models.AppleSocialAuth
|
var entries []models.AppleSocialAuth
|
||||||
@@ -139,7 +139,7 @@ func (h *AdminAppleSocialAuthHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req UpdateAppleSocialAuthRequest
|
var req UpdateAppleSocialAuthRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Email != nil {
|
if req.Email != nil {
|
||||||
@@ -183,14 +183,15 @@ func (h *AdminAppleSocialAuthHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminAppleSocialAuthHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminAppleSocialAuthHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Where("id IN ?", req.IDs).Delete(&models.AppleSocialAuth{}).Error; err != nil {
|
result := h.db.Where("id IN ?", req.IDs).Delete(&models.AppleSocialAuth{})
|
||||||
|
if result.Error != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth entries"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth entries"})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entries deleted successfully", "count": len(req.IDs)})
|
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entries deleted successfully", "count": result.RowsAffected})
|
||||||
}
|
}
|
||||||
|
|
||||||
// toResponse converts an AppleSocialAuth model to AppleSocialAuthResponse
|
// toResponse converts an AppleSocialAuth model to AppleSocialAuthResponse
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ func NewAdminAuthHandler(adminRepo *repositories.AdminRepository, cfg *config.Co
|
|||||||
|
|
||||||
// LoginRequest represents the admin login request
|
// LoginRequest represents the admin login request
|
||||||
type LoginRequest struct {
|
type LoginRequest struct {
|
||||||
Email string `json:"email" binding:"required,email"`
|
Email string `json:"email" validate:"required,email"`
|
||||||
Password string `json:"password" binding:"required"`
|
Password string `json:"password" validate:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginResponse represents the admin login response
|
// LoginResponse represents the admin login response
|
||||||
@@ -71,7 +71,10 @@ func NewAdminUserResponse(admin *models.AdminUser) AdminUserResponse {
|
|||||||
func (h *AdminAuthHandler) Login(c echo.Context) error {
|
func (h *AdminAuthHandler) Login(c echo.Context) error {
|
||||||
var req LoginRequest
|
var req LoginRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request: " + err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find admin by email
|
// Find admin by email
|
||||||
@@ -100,7 +103,10 @@ func (h *AdminAuthHandler) Login(c echo.Context) error {
|
|||||||
_ = h.adminRepo.UpdateLastLogin(admin.ID)
|
_ = h.adminRepo.UpdateLastLogin(admin.ID)
|
||||||
|
|
||||||
// Refresh admin data after updating last login
|
// Refresh admin data after updating last login
|
||||||
admin, _ = h.adminRepo.FindByID(admin.ID)
|
admin, err = h.adminRepo.FindByID(admin.ID)
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to refresh admin data"})
|
||||||
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, LoginResponse{
|
return c.JSON(http.StatusOK, LoginResponse{
|
||||||
Token: token,
|
Token: token,
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ type AuthTokenResponse struct {
|
|||||||
func (h *AdminAuthTokenHandler) List(c echo.Context) error {
|
func (h *AdminAuthTokenHandler) List(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokens []models.AuthToken
|
var tokens []models.AuthToken
|
||||||
@@ -132,7 +132,7 @@ func (h *AdminAuthTokenHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminAuthTokenHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminAuthTokenHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
result := h.db.Where("user_id IN ?", req.IDs).Delete(&models.AuthToken{})
|
result := h.db.Where("user_id IN ?", req.IDs).Delete(&models.AuthToken{})
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ type CompletionFilters struct {
|
|||||||
func (h *AdminCompletionHandler) List(c echo.Context) error {
|
func (h *AdminCompletionHandler) List(c echo.Context) error {
|
||||||
var filters CompletionFilters
|
var filters CompletionFilters
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var completions []models.TaskCompletion
|
var completions []models.TaskCompletion
|
||||||
@@ -167,7 +167,7 @@ func (h *AdminCompletionHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminCompletionHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminCompletionHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
result := h.db.Where("id IN ?", req.IDs).Delete(&models.TaskCompletion{})
|
result := h.db.Where("id IN ?", req.IDs).Delete(&models.TaskCompletion{})
|
||||||
@@ -201,7 +201,7 @@ func (h *AdminCompletionHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req UpdateCompletionRequest
|
var req UpdateCompletionRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Notes != nil {
|
if req.Notes != nil {
|
||||||
|
|||||||
@@ -35,8 +35,8 @@ type AdminCompletionImageResponse struct {
|
|||||||
|
|
||||||
// CreateCompletionImageRequest represents the request to create a completion image
|
// CreateCompletionImageRequest represents the request to create a completion image
|
||||||
type CreateCompletionImageRequest struct {
|
type CreateCompletionImageRequest struct {
|
||||||
CompletionID uint `json:"completion_id" binding:"required"`
|
CompletionID uint `json:"completion_id" validate:"required"`
|
||||||
ImageURL string `json:"image_url" binding:"required"`
|
ImageURL string `json:"image_url" validate:"required"`
|
||||||
Caption string `json:"caption"`
|
Caption string `json:"caption"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ type UpdateCompletionImageRequest struct {
|
|||||||
func (h *AdminCompletionImageHandler) List(c echo.Context) error {
|
func (h *AdminCompletionImageHandler) List(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional completion_id filter
|
// Optional completion_id filter
|
||||||
@@ -91,10 +91,39 @@ func (h *AdminCompletionImageHandler) List(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch completion images"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch completion images"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Batch-load completion+task info to avoid N+1 queries
|
||||||
|
completionIDs := make([]uint, 0, len(images))
|
||||||
|
for _, img := range images {
|
||||||
|
completionIDs = append(completionIDs, img.CompletionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
completionMap := make(map[uint]models.TaskCompletion)
|
||||||
|
if len(completionIDs) > 0 {
|
||||||
|
var completions []models.TaskCompletion
|
||||||
|
h.db.Preload("Task").Where("id IN ?", completionIDs).Find(&completions)
|
||||||
|
for _, c := range completions {
|
||||||
|
completionMap[c.ID] = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Build response with task info
|
// Build response with task info
|
||||||
responses := make([]AdminCompletionImageResponse, len(images))
|
responses := make([]AdminCompletionImageResponse, len(images))
|
||||||
for i, image := range images {
|
for i, image := range images {
|
||||||
responses[i] = h.toResponse(&image)
|
response := AdminCompletionImageResponse{
|
||||||
|
ID: image.ID,
|
||||||
|
CompletionID: image.CompletionID,
|
||||||
|
ImageURL: image.ImageURL,
|
||||||
|
Caption: image.Caption,
|
||||||
|
CreatedAt: image.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||||
|
UpdatedAt: image.UpdatedAt.Format("2006-01-02T15:04:05Z"),
|
||||||
|
}
|
||||||
|
if comp, ok := completionMap[image.CompletionID]; ok {
|
||||||
|
response.TaskID = comp.TaskID
|
||||||
|
if comp.Task.ID != 0 {
|
||||||
|
response.TaskTitle = comp.Task.Title
|
||||||
|
}
|
||||||
|
}
|
||||||
|
responses[i] = response
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
|
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
|
||||||
@@ -122,7 +151,10 @@ func (h *AdminCompletionImageHandler) Get(c echo.Context) error {
|
|||||||
func (h *AdminCompletionImageHandler) Create(c echo.Context) error {
|
func (h *AdminCompletionImageHandler) Create(c echo.Context) error {
|
||||||
var req CreateCompletionImageRequest
|
var req CreateCompletionImageRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify completion exists
|
// Verify completion exists
|
||||||
@@ -164,7 +196,7 @@ func (h *AdminCompletionImageHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req UpdateCompletionImageRequest
|
var req UpdateCompletionImageRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.ImageURL != nil {
|
if req.ImageURL != nil {
|
||||||
@@ -207,14 +239,15 @@ func (h *AdminCompletionImageHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminCompletionImageHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminCompletionImageHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Where("id IN ?", req.IDs).Delete(&models.TaskCompletionImage{}).Error; err != nil {
|
result := h.db.Where("id IN ?", req.IDs).Delete(&models.TaskCompletionImage{})
|
||||||
|
if result.Error != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete completion images"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete completion images"})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Completion images deleted successfully", "count": len(req.IDs)})
|
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Completion images deleted successfully", "count": result.RowsAffected})
|
||||||
}
|
}
|
||||||
|
|
||||||
// toResponse converts a TaskCompletionImage model to AdminCompletionImageResponse
|
// toResponse converts a TaskCompletionImage model to AdminCompletionImageResponse
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -11,6 +12,14 @@ import (
|
|||||||
"github.com/treytartt/honeydue-api/internal/models"
|
"github.com/treytartt/honeydue-api/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// maskCode masks a confirmation code, showing only the last 4 characters.
|
||||||
|
func maskCode(code string) string {
|
||||||
|
if len(code) <= 4 {
|
||||||
|
return strings.Repeat("*", len(code))
|
||||||
|
}
|
||||||
|
return strings.Repeat("*", len(code)-4) + code[len(code)-4:]
|
||||||
|
}
|
||||||
|
|
||||||
// AdminConfirmationCodeHandler handles admin confirmation code management endpoints
|
// AdminConfirmationCodeHandler handles admin confirmation code management endpoints
|
||||||
type AdminConfirmationCodeHandler struct {
|
type AdminConfirmationCodeHandler struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
@@ -37,7 +46,7 @@ type ConfirmationCodeResponse struct {
|
|||||||
func (h *AdminConfirmationCodeHandler) List(c echo.Context) error {
|
func (h *AdminConfirmationCodeHandler) List(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var codes []models.ConfirmationCode
|
var codes []models.ConfirmationCode
|
||||||
@@ -79,7 +88,7 @@ func (h *AdminConfirmationCodeHandler) List(c echo.Context) error {
|
|||||||
UserID: code.UserID,
|
UserID: code.UserID,
|
||||||
Username: code.User.Username,
|
Username: code.User.Username,
|
||||||
Email: code.User.Email,
|
Email: code.User.Email,
|
||||||
Code: code.Code,
|
Code: maskCode(code.Code),
|
||||||
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
||||||
IsUsed: code.IsUsed,
|
IsUsed: code.IsUsed,
|
||||||
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||||
@@ -109,7 +118,7 @@ func (h *AdminConfirmationCodeHandler) Get(c echo.Context) error {
|
|||||||
UserID: code.UserID,
|
UserID: code.UserID,
|
||||||
Username: code.User.Username,
|
Username: code.User.Username,
|
||||||
Email: code.User.Email,
|
Email: code.User.Email,
|
||||||
Code: code.Code,
|
Code: maskCode(code.Code),
|
||||||
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
||||||
IsUsed: code.IsUsed,
|
IsUsed: code.IsUsed,
|
||||||
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||||
@@ -141,7 +150,7 @@ func (h *AdminConfirmationCodeHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminConfirmationCodeHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminConfirmationCodeHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
result := h.db.Where("id IN ?", req.IDs).Delete(&models.ConfirmationCode{})
|
result := h.db.Where("id IN ?", req.IDs).Delete(&models.ConfirmationCode{})
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func NewAdminContractorHandler(db *gorm.DB) *AdminContractorHandler {
|
|||||||
func (h *AdminContractorHandler) List(c echo.Context) error {
|
func (h *AdminContractorHandler) List(c echo.Context) error {
|
||||||
var filters dto.ContractorFilters
|
var filters dto.ContractorFilters
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var contractors []models.Contractor
|
var contractors []models.Contractor
|
||||||
@@ -130,7 +130,7 @@ func (h *AdminContractorHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req dto.UpdateContractorRequest
|
var req dto.UpdateContractorRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify residence if changing
|
// Verify residence if changing
|
||||||
@@ -213,7 +213,7 @@ func (h *AdminContractorHandler) Update(c echo.Context) error {
|
|||||||
func (h *AdminContractorHandler) Create(c echo.Context) error {
|
func (h *AdminContractorHandler) Create(c echo.Context) error {
|
||||||
var req dto.CreateContractorRequest
|
var req dto.CreateContractorRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify residence exists if provided
|
// Verify residence exists if provided
|
||||||
@@ -290,7 +290,7 @@ func (h *AdminContractorHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminContractorHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminContractorHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Soft delete - deactivate all
|
// Soft delete - deactivate all
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
@@ -11,6 +12,18 @@ import (
|
|||||||
"github.com/treytartt/honeydue-api/internal/task/scopes"
|
"github.com/treytartt/honeydue-api/internal/task/scopes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// dashboardCache provides short-lived in-memory caching for dashboard stats
|
||||||
|
// to avoid expensive COUNT queries on every request.
|
||||||
|
type dashboardCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
stats *DashboardStats
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var dashCache = &dashboardCache{}
|
||||||
|
|
||||||
|
const dashboardCacheTTL = 30 * time.Second
|
||||||
|
|
||||||
// AdminDashboardHandler handles admin dashboard endpoints
|
// AdminDashboardHandler handles admin dashboard endpoints
|
||||||
type AdminDashboardHandler struct {
|
type AdminDashboardHandler struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
@@ -94,6 +107,15 @@ type SubscriptionStats struct {
|
|||||||
|
|
||||||
// GetStats handles GET /api/admin/dashboard/stats
|
// GetStats handles GET /api/admin/dashboard/stats
|
||||||
func (h *AdminDashboardHandler) GetStats(c echo.Context) error {
|
func (h *AdminDashboardHandler) GetStats(c echo.Context) error {
|
||||||
|
// Check cache first
|
||||||
|
dashCache.mu.RLock()
|
||||||
|
if dashCache.stats != nil && time.Now().Before(dashCache.expiresAt) {
|
||||||
|
cached := *dashCache.stats
|
||||||
|
dashCache.mu.RUnlock()
|
||||||
|
return c.JSON(http.StatusOK, cached)
|
||||||
|
}
|
||||||
|
dashCache.mu.RUnlock()
|
||||||
|
|
||||||
stats := DashboardStats{}
|
stats := DashboardStats{}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
thirtyDaysAgo := now.AddDate(0, 0, -30)
|
thirtyDaysAgo := now.AddDate(0, 0, -30)
|
||||||
@@ -101,7 +123,7 @@ func (h *AdminDashboardHandler) GetStats(c echo.Context) error {
|
|||||||
// User stats
|
// User stats
|
||||||
h.db.Model(&models.User{}).Count(&stats.Users.Total)
|
h.db.Model(&models.User{}).Count(&stats.Users.Total)
|
||||||
h.db.Model(&models.User{}).Where("is_active = ?", true).Count(&stats.Users.Active)
|
h.db.Model(&models.User{}).Where("is_active = ?", true).Count(&stats.Users.Active)
|
||||||
h.db.Model(&models.User{}).Where("verified = ?", true).Count(&stats.Users.Verified)
|
h.db.Model(&models.UserProfile{}).Where("verified = ?", true).Count(&stats.Users.Verified)
|
||||||
h.db.Model(&models.User{}).Where("date_joined >= ?", thirtyDaysAgo).Count(&stats.Users.New30d)
|
h.db.Model(&models.User{}).Where("date_joined >= ?", thirtyDaysAgo).Count(&stats.Users.New30d)
|
||||||
|
|
||||||
// Residence stats
|
// Residence stats
|
||||||
@@ -164,5 +186,11 @@ func (h *AdminDashboardHandler) GetStats(c echo.Context) error {
|
|||||||
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "premium").Count(&stats.Subscriptions.Premium)
|
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "premium").Count(&stats.Subscriptions.Premium)
|
||||||
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "pro").Count(&stats.Subscriptions.Pro)
|
h.db.Model(&models.UserSubscription{}).Where("tier = ?", "pro").Count(&stats.Subscriptions.Pro)
|
||||||
|
|
||||||
|
// Cache the result
|
||||||
|
dashCache.mu.Lock()
|
||||||
|
dashCache.stats = &stats
|
||||||
|
dashCache.expiresAt = time.Now().Add(dashboardCacheTTL)
|
||||||
|
dashCache.mu.Unlock()
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, stats)
|
return c.JSON(http.StatusOK, stats)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ type GCMDeviceResponse struct {
|
|||||||
func (h *AdminDeviceHandler) ListAPNS(c echo.Context) error {
|
func (h *AdminDeviceHandler) ListAPNS(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var devices []models.APNSDevice
|
var devices []models.APNSDevice
|
||||||
@@ -106,7 +106,7 @@ func (h *AdminDeviceHandler) ListAPNS(c echo.Context) error {
|
|||||||
func (h *AdminDeviceHandler) ListGCM(c echo.Context) error {
|
func (h *AdminDeviceHandler) ListGCM(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var devices []models.GCMDevice
|
var devices []models.GCMDevice
|
||||||
@@ -174,13 +174,15 @@ func (h *AdminDeviceHandler) UpdateAPNS(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var req struct {
|
var req struct {
|
||||||
Active bool `json:"active"`
|
Active *bool `json:"active"`
|
||||||
}
|
}
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
device.Active = req.Active
|
if req.Active != nil {
|
||||||
|
device.Active = *req.Active
|
||||||
|
}
|
||||||
if err := h.db.Save(&device).Error; err != nil {
|
if err := h.db.Save(&device).Error; err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update device"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update device"})
|
||||||
}
|
}
|
||||||
@@ -204,13 +206,15 @@ func (h *AdminDeviceHandler) UpdateGCM(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var req struct {
|
var req struct {
|
||||||
Active bool `json:"active"`
|
Active *bool `json:"active"`
|
||||||
}
|
}
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
device.Active = req.Active
|
if req.Active != nil {
|
||||||
|
device.Active = *req.Active
|
||||||
|
}
|
||||||
if err := h.db.Save(&device).Error; err != nil {
|
if err := h.db.Save(&device).Error; err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update device"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update device"})
|
||||||
}
|
}
|
||||||
@@ -260,7 +264,7 @@ func (h *AdminDeviceHandler) DeleteGCM(c echo.Context) error {
|
|||||||
func (h *AdminDeviceHandler) BulkDeleteAPNS(c echo.Context) error {
|
func (h *AdminDeviceHandler) BulkDeleteAPNS(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
result := h.db.Where("id IN ?", req.IDs).Delete(&models.APNSDevice{})
|
result := h.db.Where("id IN ?", req.IDs).Delete(&models.APNSDevice{})
|
||||||
@@ -275,7 +279,7 @@ func (h *AdminDeviceHandler) BulkDeleteAPNS(c echo.Context) error {
|
|||||||
func (h *AdminDeviceHandler) BulkDeleteGCM(c echo.Context) error {
|
func (h *AdminDeviceHandler) BulkDeleteGCM(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
result := h.db.Where("id IN ?", req.IDs).Delete(&models.GCMDevice{})
|
result := h.db.Where("id IN ?", req.IDs).Delete(&models.GCMDevice{})
|
||||||
@@ -307,10 +311,3 @@ func (h *AdminDeviceHandler) GetStats(c echo.Context) error {
|
|||||||
"total": apnsTotal + gcmTotal,
|
"total": apnsTotal + gcmTotal,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func min(a, b int) int {
|
|
||||||
if a < b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/shopspring/decimal"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
||||||
@@ -27,7 +26,7 @@ func NewAdminDocumentHandler(db *gorm.DB) *AdminDocumentHandler {
|
|||||||
func (h *AdminDocumentHandler) List(c echo.Context) error {
|
func (h *AdminDocumentHandler) List(c echo.Context) error {
|
||||||
var filters dto.DocumentFilters
|
var filters dto.DocumentFilters
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var documents []models.Document
|
var documents []models.Document
|
||||||
@@ -132,7 +131,7 @@ func (h *AdminDocumentHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req dto.UpdateDocumentRequest
|
var req dto.UpdateDocumentRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify residence if changing
|
// Verify residence if changing
|
||||||
@@ -183,8 +182,7 @@ func (h *AdminDocumentHandler) Update(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if req.PurchasePrice != nil {
|
if req.PurchasePrice != nil {
|
||||||
d := decimal.NewFromFloat(*req.PurchasePrice)
|
document.PurchasePrice = req.PurchasePrice
|
||||||
document.PurchasePrice = &d
|
|
||||||
}
|
}
|
||||||
if req.Vendor != nil {
|
if req.Vendor != nil {
|
||||||
document.Vendor = *req.Vendor
|
document.Vendor = *req.Vendor
|
||||||
@@ -232,7 +230,7 @@ func (h *AdminDocumentHandler) Update(c echo.Context) error {
|
|||||||
func (h *AdminDocumentHandler) Create(c echo.Context) error {
|
func (h *AdminDocumentHandler) Create(c echo.Context) error {
|
||||||
var req dto.CreateDocumentRequest
|
var req dto.CreateDocumentRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify residence exists
|
// Verify residence exists
|
||||||
@@ -282,8 +280,7 @@ func (h *AdminDocumentHandler) Create(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if req.PurchasePrice != nil {
|
if req.PurchasePrice != nil {
|
||||||
d := decimal.NewFromFloat(*req.PurchasePrice)
|
document.PurchasePrice = req.PurchasePrice
|
||||||
document.PurchasePrice = &d
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Create(&document).Error; err != nil {
|
if err := h.db.Create(&document).Error; err != nil {
|
||||||
@@ -322,7 +319,7 @@ func (h *AdminDocumentHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminDocumentHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminDocumentHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Soft delete - deactivate all
|
// Soft delete - deactivate all
|
||||||
|
|||||||
@@ -36,8 +36,8 @@ type DocumentImageResponse struct {
|
|||||||
|
|
||||||
// CreateDocumentImageRequest represents the request to create a document image
|
// CreateDocumentImageRequest represents the request to create a document image
|
||||||
type CreateDocumentImageRequest struct {
|
type CreateDocumentImageRequest struct {
|
||||||
DocumentID uint `json:"document_id" binding:"required"`
|
DocumentID uint `json:"document_id" validate:"required"`
|
||||||
ImageURL string `json:"image_url" binding:"required"`
|
ImageURL string `json:"image_url" validate:"required"`
|
||||||
Caption string `json:"caption"`
|
Caption string `json:"caption"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ type UpdateDocumentImageRequest struct {
|
|||||||
func (h *AdminDocumentImageHandler) List(c echo.Context) error {
|
func (h *AdminDocumentImageHandler) List(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional document_id filter
|
// Optional document_id filter
|
||||||
@@ -123,7 +123,10 @@ func (h *AdminDocumentImageHandler) Get(c echo.Context) error {
|
|||||||
func (h *AdminDocumentImageHandler) Create(c echo.Context) error {
|
func (h *AdminDocumentImageHandler) Create(c echo.Context) error {
|
||||||
var req CreateDocumentImageRequest
|
var req CreateDocumentImageRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify document exists
|
// Verify document exists
|
||||||
@@ -165,7 +168,7 @@ func (h *AdminDocumentImageHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req UpdateDocumentImageRequest
|
var req UpdateDocumentImageRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.ImageURL != nil {
|
if req.ImageURL != nil {
|
||||||
@@ -208,14 +211,15 @@ func (h *AdminDocumentImageHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminDocumentImageHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminDocumentImageHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Where("id IN ?", req.IDs).Delete(&models.DocumentImage{}).Error; err != nil {
|
result := h.db.Where("id IN ?", req.IDs).Delete(&models.DocumentImage{})
|
||||||
|
if result.Error != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete document images"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete document images"})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document images deleted successfully", "count": len(req.IDs)})
|
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document images deleted successfully", "count": result.RowsAffected})
|
||||||
}
|
}
|
||||||
|
|
||||||
// toResponse converts a DocumentImage model to DocumentImageResponse
|
// toResponse converts a DocumentImage model to DocumentImageResponse
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ type FeatureBenefitResponse struct {
|
|||||||
func (h *AdminFeatureBenefitHandler) List(c echo.Context) error {
|
func (h *AdminFeatureBenefitHandler) List(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var benefits []models.FeatureBenefit
|
var benefits []models.FeatureBenefit
|
||||||
@@ -112,15 +112,18 @@ func (h *AdminFeatureBenefitHandler) Get(c echo.Context) error {
|
|||||||
// Create handles POST /api/admin/feature-benefits
|
// Create handles POST /api/admin/feature-benefits
|
||||||
func (h *AdminFeatureBenefitHandler) Create(c echo.Context) error {
|
func (h *AdminFeatureBenefitHandler) Create(c echo.Context) error {
|
||||||
var req struct {
|
var req struct {
|
||||||
FeatureName string `json:"feature_name" binding:"required"`
|
FeatureName string `json:"feature_name" validate:"required"`
|
||||||
FreeTierText string `json:"free_tier_text" binding:"required"`
|
FreeTierText string `json:"free_tier_text" validate:"required"`
|
||||||
ProTierText string `json:"pro_tier_text" binding:"required"`
|
ProTierText string `json:"pro_tier_text" validate:"required"`
|
||||||
DisplayOrder int `json:"display_order"`
|
DisplayOrder int `json:"display_order"`
|
||||||
IsActive *bool `json:"is_active"`
|
IsActive *bool `json:"is_active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
benefit := models.FeatureBenefit{
|
benefit := models.FeatureBenefit{
|
||||||
@@ -175,7 +178,7 @@ func (h *AdminFeatureBenefitHandler) Update(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.FeatureName != nil {
|
if req.FeatureName != nil {
|
||||||
|
|||||||
@@ -34,7 +34,9 @@ func (h *AdminLimitationsHandler) GetSettings(c echo.Context) error {
|
|||||||
if err == gorm.ErrRecordNotFound {
|
if err == gorm.ErrRecordNotFound {
|
||||||
// Create default settings
|
// Create default settings
|
||||||
settings = models.SubscriptionSettings{ID: 1, EnableLimitations: false}
|
settings = models.SubscriptionSettings{ID: 1, EnableLimitations: false}
|
||||||
h.db.Create(&settings)
|
if err := h.db.Create(&settings).Error; err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default settings"})
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
|
||||||
}
|
}
|
||||||
@@ -54,7 +56,7 @@ type UpdateLimitationsSettingsRequest struct {
|
|||||||
func (h *AdminLimitationsHandler) UpdateSettings(c echo.Context) error {
|
func (h *AdminLimitationsHandler) UpdateSettings(c echo.Context) error {
|
||||||
var req UpdateLimitationsSettingsRequest
|
var req UpdateLimitationsSettingsRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var settings models.SubscriptionSettings
|
var settings models.SubscriptionSettings
|
||||||
@@ -117,8 +119,12 @@ func (h *AdminLimitationsHandler) ListTierLimits(c echo.Context) error {
|
|||||||
if len(limits) == 0 {
|
if len(limits) == 0 {
|
||||||
freeLimits := models.GetDefaultFreeLimits()
|
freeLimits := models.GetDefaultFreeLimits()
|
||||||
proLimits := models.GetDefaultProLimits()
|
proLimits := models.GetDefaultProLimits()
|
||||||
h.db.Create(&freeLimits)
|
if err := h.db.Create(&freeLimits).Error; err != nil {
|
||||||
h.db.Create(&proLimits)
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default free tier limits"})
|
||||||
|
}
|
||||||
|
if err := h.db.Create(&proLimits).Error; err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default pro tier limits"})
|
||||||
|
}
|
||||||
limits = []models.TierLimits{freeLimits, proLimits}
|
limits = []models.TierLimits{freeLimits, proLimits}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,7 +155,9 @@ func (h *AdminLimitationsHandler) GetTierLimits(c echo.Context) error {
|
|||||||
} else {
|
} else {
|
||||||
limits = models.GetDefaultProLimits()
|
limits = models.GetDefaultProLimits()
|
||||||
}
|
}
|
||||||
h.db.Create(&limits)
|
if err := h.db.Create(&limits).Error; err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default tier limits"})
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch tier limits"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch tier limits"})
|
||||||
}
|
}
|
||||||
@@ -175,7 +183,7 @@ func (h *AdminLimitationsHandler) UpdateTierLimits(c echo.Context) error {
|
|||||||
|
|
||||||
var req UpdateTierLimitsRequest
|
var req UpdateTierLimitsRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var limits models.TierLimits
|
var limits models.TierLimits
|
||||||
@@ -188,13 +196,23 @@ func (h *AdminLimitationsHandler) UpdateTierLimits(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update fields - note: we need to handle nil vs zero difference
|
// Update fields only when explicitly provided in the request body.
|
||||||
// A nil pointer in the request means "don't change"
|
// JSON unmarshaling sets *int to nil when the key is absent, and to
|
||||||
// The actual limit value can be nil (unlimited) or a number
|
// a non-nil *int (possibly pointing to 0) when the key is present.
|
||||||
limits.PropertiesLimit = req.PropertiesLimit
|
// We rely on Bind populating these before calling this handler, so
|
||||||
limits.TasksLimit = req.TasksLimit
|
// a nil pointer here means "don't change".
|
||||||
limits.ContractorsLimit = req.ContractorsLimit
|
if req.PropertiesLimit != nil {
|
||||||
limits.DocumentsLimit = req.DocumentsLimit
|
limits.PropertiesLimit = req.PropertiesLimit
|
||||||
|
}
|
||||||
|
if req.TasksLimit != nil {
|
||||||
|
limits.TasksLimit = req.TasksLimit
|
||||||
|
}
|
||||||
|
if req.ContractorsLimit != nil {
|
||||||
|
limits.ContractorsLimit = req.ContractorsLimit
|
||||||
|
}
|
||||||
|
if req.DocumentsLimit != nil {
|
||||||
|
limits.DocumentsLimit = req.DocumentsLimit
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.db.Save(&limits).Error; err != nil {
|
if err := h.db.Save(&limits).Error; err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update tier limits"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update tier limits"})
|
||||||
@@ -297,9 +315,9 @@ func (h *AdminLimitationsHandler) GetUpgradeTrigger(c echo.Context) error {
|
|||||||
|
|
||||||
// CreateUpgradeTriggerRequest represents the create request
|
// CreateUpgradeTriggerRequest represents the create request
|
||||||
type CreateUpgradeTriggerRequest struct {
|
type CreateUpgradeTriggerRequest struct {
|
||||||
TriggerKey string `json:"trigger_key" binding:"required"`
|
TriggerKey string `json:"trigger_key" validate:"required"`
|
||||||
Title string `json:"title" binding:"required"`
|
Title string `json:"title" validate:"required"`
|
||||||
Message string `json:"message" binding:"required"`
|
Message string `json:"message" validate:"required"`
|
||||||
PromoHTML string `json:"promo_html"`
|
PromoHTML string `json:"promo_html"`
|
||||||
ButtonText string `json:"button_text"`
|
ButtonText string `json:"button_text"`
|
||||||
IsActive *bool `json:"is_active"`
|
IsActive *bool `json:"is_active"`
|
||||||
@@ -309,7 +327,10 @@ type CreateUpgradeTriggerRequest struct {
|
|||||||
func (h *AdminLimitationsHandler) CreateUpgradeTrigger(c echo.Context) error {
|
func (h *AdminLimitationsHandler) CreateUpgradeTrigger(c echo.Context) error {
|
||||||
var req CreateUpgradeTriggerRequest
|
var req CreateUpgradeTriggerRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate trigger key
|
// Validate trigger key
|
||||||
@@ -380,7 +401,7 @@ func (h *AdminLimitationsHandler) UpdateUpgradeTrigger(c echo.Context) error {
|
|||||||
|
|
||||||
var req UpdateUpgradeTriggerRequest
|
var req UpdateUpgradeTriggerRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.TriggerKey != nil {
|
if req.TriggerKey != nil {
|
||||||
|
|||||||
@@ -162,10 +162,10 @@ type TaskCategoryResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CreateUpdateCategoryRequest struct {
|
type CreateUpdateCategoryRequest struct {
|
||||||
Name string `json:"name" binding:"required,max=50"`
|
Name string `json:"name" validate:"required,max=50"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Icon string `json:"icon" binding:"max=50"`
|
Icon string `json:"icon" validate:"max=50"`
|
||||||
Color string `json:"color" binding:"max=7"`
|
Color string `json:"color" validate:"max=7"`
|
||||||
DisplayOrder *int `json:"display_order"`
|
DisplayOrder *int `json:"display_order"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,7 +193,10 @@ func (h *AdminLookupHandler) ListCategories(c echo.Context) error {
|
|||||||
func (h *AdminLookupHandler) CreateCategory(c echo.Context) error {
|
func (h *AdminLookupHandler) CreateCategory(c echo.Context) error {
|
||||||
var req CreateUpdateCategoryRequest
|
var req CreateUpdateCategoryRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
category := models.TaskCategory{
|
category := models.TaskCategory{
|
||||||
@@ -239,7 +242,10 @@ func (h *AdminLookupHandler) UpdateCategory(c echo.Context) error {
|
|||||||
|
|
||||||
var req CreateUpdateCategoryRequest
|
var req CreateUpdateCategoryRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
category.Name = req.Name
|
category.Name = req.Name
|
||||||
@@ -301,9 +307,9 @@ type TaskPriorityResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CreateUpdatePriorityRequest struct {
|
type CreateUpdatePriorityRequest struct {
|
||||||
Name string `json:"name" binding:"required,max=20"`
|
Name string `json:"name" validate:"required,max=20"`
|
||||||
Level int `json:"level" binding:"required,min=1,max=10"`
|
Level int `json:"level" validate:"required,min=1,max=10"`
|
||||||
Color string `json:"color" binding:"max=7"`
|
Color string `json:"color" validate:"max=7"`
|
||||||
DisplayOrder *int `json:"display_order"`
|
DisplayOrder *int `json:"display_order"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -330,7 +336,10 @@ func (h *AdminLookupHandler) ListPriorities(c echo.Context) error {
|
|||||||
func (h *AdminLookupHandler) CreatePriority(c echo.Context) error {
|
func (h *AdminLookupHandler) CreatePriority(c echo.Context) error {
|
||||||
var req CreateUpdatePriorityRequest
|
var req CreateUpdatePriorityRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
priority := models.TaskPriority{
|
priority := models.TaskPriority{
|
||||||
@@ -374,7 +383,10 @@ func (h *AdminLookupHandler) UpdatePriority(c echo.Context) error {
|
|||||||
|
|
||||||
var req CreateUpdatePriorityRequest
|
var req CreateUpdatePriorityRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
priority.Name = req.Name
|
priority.Name = req.Name
|
||||||
@@ -434,7 +446,7 @@ type TaskFrequencyResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CreateUpdateFrequencyRequest struct {
|
type CreateUpdateFrequencyRequest struct {
|
||||||
Name string `json:"name" binding:"required,max=20"`
|
Name string `json:"name" validate:"required,max=20"`
|
||||||
Days *int `json:"days"`
|
Days *int `json:"days"`
|
||||||
DisplayOrder *int `json:"display_order"`
|
DisplayOrder *int `json:"display_order"`
|
||||||
}
|
}
|
||||||
@@ -480,7 +492,10 @@ func (h *AdminLookupHandler) ListFrequencies(c echo.Context) error {
|
|||||||
func (h *AdminLookupHandler) CreateFrequency(c echo.Context) error {
|
func (h *AdminLookupHandler) CreateFrequency(c echo.Context) error {
|
||||||
var req CreateUpdateFrequencyRequest
|
var req CreateUpdateFrequencyRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
frequency := models.TaskFrequency{
|
frequency := models.TaskFrequency{
|
||||||
@@ -528,7 +543,10 @@ func (h *AdminLookupHandler) UpdateFrequency(c echo.Context) error {
|
|||||||
|
|
||||||
var req CreateUpdateFrequencyRequest
|
var req CreateUpdateFrequencyRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
frequency.Name = req.Name
|
frequency.Name = req.Name
|
||||||
@@ -588,7 +606,7 @@ type ResidenceTypeResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CreateUpdateResidenceTypeRequest struct {
|
type CreateUpdateResidenceTypeRequest struct {
|
||||||
Name string `json:"name" binding:"required,max=20"`
|
Name string `json:"name" validate:"required,max=20"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AdminLookupHandler) ListResidenceTypes(c echo.Context) error {
|
func (h *AdminLookupHandler) ListResidenceTypes(c echo.Context) error {
|
||||||
@@ -611,7 +629,10 @@ func (h *AdminLookupHandler) ListResidenceTypes(c echo.Context) error {
|
|||||||
func (h *AdminLookupHandler) CreateResidenceType(c echo.Context) error {
|
func (h *AdminLookupHandler) CreateResidenceType(c echo.Context) error {
|
||||||
var req CreateUpdateResidenceTypeRequest
|
var req CreateUpdateResidenceTypeRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
residenceType := models.ResidenceType{Name: req.Name}
|
residenceType := models.ResidenceType{Name: req.Name}
|
||||||
@@ -644,7 +665,10 @@ func (h *AdminLookupHandler) UpdateResidenceType(c echo.Context) error {
|
|||||||
|
|
||||||
var req CreateUpdateResidenceTypeRequest
|
var req CreateUpdateResidenceTypeRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
residenceType.Name = req.Name
|
residenceType.Name = req.Name
|
||||||
@@ -694,9 +718,9 @@ type ContractorSpecialtyResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CreateUpdateSpecialtyRequest struct {
|
type CreateUpdateSpecialtyRequest struct {
|
||||||
Name string `json:"name" binding:"required,max=50"`
|
Name string `json:"name" validate:"required,max=50"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Icon string `json:"icon" binding:"max=50"`
|
Icon string `json:"icon" validate:"max=50"`
|
||||||
DisplayOrder *int `json:"display_order"`
|
DisplayOrder *int `json:"display_order"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -723,7 +747,10 @@ func (h *AdminLookupHandler) ListSpecialties(c echo.Context) error {
|
|||||||
func (h *AdminLookupHandler) CreateSpecialty(c echo.Context) error {
|
func (h *AdminLookupHandler) CreateSpecialty(c echo.Context) error {
|
||||||
var req CreateUpdateSpecialtyRequest
|
var req CreateUpdateSpecialtyRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
specialty := models.ContractorSpecialty{
|
specialty := models.ContractorSpecialty{
|
||||||
@@ -767,7 +794,10 @@ func (h *AdminLookupHandler) UpdateSpecialty(c echo.Context) error {
|
|||||||
|
|
||||||
var req CreateUpdateSpecialtyRequest
|
var req CreateUpdateSpecialtyRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
specialty.Name = req.Name
|
specialty.Name = req.Name
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
||||||
@@ -36,7 +37,7 @@ func NewAdminNotificationHandler(db *gorm.DB, emailService *services.EmailServic
|
|||||||
func (h *AdminNotificationHandler) List(c echo.Context) error {
|
func (h *AdminNotificationHandler) List(c echo.Context) error {
|
||||||
var filters dto.NotificationFilters
|
var filters dto.NotificationFilters
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var notifications []models.Notification
|
var notifications []models.Notification
|
||||||
@@ -151,7 +152,7 @@ func (h *AdminNotificationHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req dto.UpdateNotificationRequest
|
var req dto.UpdateNotificationRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
@@ -235,7 +236,7 @@ func (h *AdminNotificationHandler) toNotificationDetailResponse(notif *models.No
|
|||||||
func (h *AdminNotificationHandler) SendTestNotification(c echo.Context) error {
|
func (h *AdminNotificationHandler) SendTestNotification(c echo.Context) error {
|
||||||
var req dto.SendTestNotificationRequest
|
var req dto.SendTestNotificationRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify user exists
|
// Verify user exists
|
||||||
@@ -294,13 +295,10 @@ func (h *AdminNotificationHandler) SendTestNotification(c echo.Context) error {
|
|||||||
|
|
||||||
err := h.pushClient.SendToAll(ctx, iosTokens, androidTokens, req.Title, req.Body, pushData)
|
err := h.pushClient.SendToAll(ctx, iosTokens, androidTokens, req.Title, req.Body, pushData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Update notification with error
|
// Log the real error for debugging
|
||||||
h.db.Model(¬ification).Updates(map[string]interface{}{
|
log.Error().Err(err).Uint("notification_id", notification.ID).Msg("Failed to send push notification")
|
||||||
"error": err.Error(),
|
|
||||||
})
|
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||||
"error": "Failed to send push notification",
|
"error": "Failed to send push notification",
|
||||||
"details": err.Error(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -327,7 +325,7 @@ func (h *AdminNotificationHandler) SendTestNotification(c echo.Context) error {
|
|||||||
func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error {
|
func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error {
|
||||||
var req dto.SendTestEmailRequest
|
var req dto.SendTestEmailRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify user exists
|
// Verify user exists
|
||||||
@@ -369,9 +367,9 @@ func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error {
|
|||||||
|
|
||||||
err := h.emailService.SendEmail(user.Email, req.Subject, htmlBody, req.Body)
|
err := h.emailService.SendEmail(user.Email, req.Subject, htmlBody, req.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send test email")
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||||
"error": "Failed to send email",
|
"error": "Failed to send email",
|
||||||
"details": err.Error(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -384,11 +382,14 @@ func (h *AdminNotificationHandler) SendTestEmail(c echo.Context) error {
|
|||||||
// SendPostVerificationEmail handles POST /api/admin/emails/send-post-verification
|
// SendPostVerificationEmail handles POST /api/admin/emails/send-post-verification
|
||||||
func (h *AdminNotificationHandler) SendPostVerificationEmail(c echo.Context) error {
|
func (h *AdminNotificationHandler) SendPostVerificationEmail(c echo.Context) error {
|
||||||
var req struct {
|
var req struct {
|
||||||
UserID uint `json:"user_id" binding:"required"`
|
UserID uint `json:"user_id" validate:"required"`
|
||||||
}
|
}
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "user_id is required"})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "user_id is required"})
|
||||||
}
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Verify user exists
|
// Verify user exists
|
||||||
var user models.User
|
var user models.User
|
||||||
@@ -410,9 +411,9 @@ func (h *AdminNotificationHandler) SendPostVerificationEmail(c echo.Context) err
|
|||||||
|
|
||||||
err := h.emailService.SendPostVerificationEmail(user.Email, user.FirstName)
|
err := h.emailService.SendPostVerificationEmail(user.Email, user.FirstName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send post-verification email")
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{
|
||||||
"error": "Failed to send email",
|
"error": "Failed to send email",
|
||||||
"details": err.Error(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ type NotificationPrefResponse struct {
|
|||||||
func (h *AdminNotificationPrefsHandler) List(c echo.Context) error {
|
func (h *AdminNotificationPrefsHandler) List(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var prefs []models.NotificationPreference
|
var prefs []models.NotificationPreference
|
||||||
@@ -212,7 +212,7 @@ func (h *AdminNotificationPrefsHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req UpdateNotificationPrefRequest
|
var req UpdateNotificationPrefRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply updates
|
// Apply updates
|
||||||
|
|||||||
@@ -240,7 +240,7 @@ func (h *AdminOnboardingHandler) Delete(c echo.Context) error {
|
|||||||
// DELETE /api/admin/onboarding-emails/bulk
|
// DELETE /api/admin/onboarding-emails/bulk
|
||||||
func (h *AdminOnboardingHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminOnboardingHandler) BulkDelete(c echo.Context) error {
|
||||||
var req struct {
|
var req struct {
|
||||||
IDs []uint `json:"ids" binding:"required"`
|
IDs []uint `json:"ids" validate:"required"`
|
||||||
}
|
}
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request"})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request"})
|
||||||
@@ -263,8 +263,8 @@ func (h *AdminOnboardingHandler) BulkDelete(c echo.Context) error {
|
|||||||
|
|
||||||
// SendOnboardingEmailRequest represents a request to send an onboarding email
|
// SendOnboardingEmailRequest represents a request to send an onboarding email
|
||||||
type SendOnboardingEmailRequest struct {
|
type SendOnboardingEmailRequest struct {
|
||||||
UserID uint `json:"user_id" binding:"required"`
|
UserID uint `json:"user_id" validate:"required"`
|
||||||
EmailType string `json:"email_type" binding:"required"`
|
EmailType string `json:"email_type" validate:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sends an onboarding email to a specific user
|
// Send sends an onboarding email to a specific user
|
||||||
@@ -278,6 +278,9 @@ func (h *AdminOnboardingHandler) Send(c echo.Context) error {
|
|||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request: user_id and email_type are required"})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request: user_id and email_type are required"})
|
||||||
}
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Validate email type
|
// Validate email type
|
||||||
var emailType models.OnboardingEmailType
|
var emailType models.OnboardingEmailType
|
||||||
@@ -301,7 +304,7 @@ func (h *AdminOnboardingHandler) Send(c echo.Context) error {
|
|||||||
|
|
||||||
// Send the email
|
// Send the email
|
||||||
if err := h.onboardingService.SendOnboardingEmailToUser(req.UserID, emailType); err != nil {
|
if err := h.onboardingService.SendOnboardingEmailToUser(req.UserID, emailType); err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to send onboarding email"})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ type PasswordResetCodeResponse struct {
|
|||||||
func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error {
|
func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var codes []models.PasswordResetCode
|
var codes []models.PasswordResetCode
|
||||||
@@ -147,7 +147,7 @@ func (h *AdminPasswordResetCodeHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminPasswordResetCodeHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminPasswordResetCodeHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
result := h.db.Where("id IN ?", req.IDs).Delete(&models.PasswordResetCode{})
|
result := h.db.Where("id IN ?", req.IDs).Delete(&models.PasswordResetCode{})
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ type PromotionResponse struct {
|
|||||||
func (h *AdminPromotionHandler) List(c echo.Context) error {
|
func (h *AdminPromotionHandler) List(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var promotions []models.Promotion
|
var promotions []models.Promotion
|
||||||
@@ -123,18 +123,21 @@ func (h *AdminPromotionHandler) Get(c echo.Context) error {
|
|||||||
// Create handles POST /api/admin/promotions
|
// Create handles POST /api/admin/promotions
|
||||||
func (h *AdminPromotionHandler) Create(c echo.Context) error {
|
func (h *AdminPromotionHandler) Create(c echo.Context) error {
|
||||||
var req struct {
|
var req struct {
|
||||||
PromotionID string `json:"promotion_id" binding:"required"`
|
PromotionID string `json:"promotion_id" validate:"required"`
|
||||||
Title string `json:"title" binding:"required"`
|
Title string `json:"title" validate:"required"`
|
||||||
Message string `json:"message" binding:"required"`
|
Message string `json:"message" validate:"required"`
|
||||||
Link *string `json:"link"`
|
Link *string `json:"link"`
|
||||||
StartDate string `json:"start_date" binding:"required"`
|
StartDate string `json:"start_date" validate:"required"`
|
||||||
EndDate string `json:"end_date" binding:"required"`
|
EndDate string `json:"end_date" validate:"required"`
|
||||||
TargetTier string `json:"target_tier"`
|
TargetTier string `json:"target_tier"`
|
||||||
IsActive *bool `json:"is_active"`
|
IsActive *bool `json:"is_active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
startDate, err := time.Parse("2006-01-02T15:04:05Z", req.StartDate)
|
startDate, err := time.Parse("2006-01-02T15:04:05Z", req.StartDate)
|
||||||
@@ -219,7 +222,7 @@ func (h *AdminPromotionHandler) Update(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.PromotionID != nil {
|
if req.PromotionID != nil {
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func NewAdminResidenceHandler(db *gorm.DB) *AdminResidenceHandler {
|
|||||||
func (h *AdminResidenceHandler) List(c echo.Context) error {
|
func (h *AdminResidenceHandler) List(c echo.Context) error {
|
||||||
var filters dto.ResidenceFilters
|
var filters dto.ResidenceFilters
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var residences []models.Residence
|
var residences []models.Residence
|
||||||
@@ -143,7 +143,7 @@ func (h *AdminResidenceHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req dto.UpdateResidenceRequest
|
var req dto.UpdateResidenceRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.OwnerID != nil {
|
if req.OwnerID != nil {
|
||||||
@@ -204,8 +204,7 @@ func (h *AdminResidenceHandler) Update(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if req.PurchasePrice != nil {
|
if req.PurchasePrice != nil {
|
||||||
d := decimal.NewFromFloat(*req.PurchasePrice)
|
residence.PurchasePrice = req.PurchasePrice
|
||||||
residence.PurchasePrice = &d
|
|
||||||
}
|
}
|
||||||
if req.IsActive != nil {
|
if req.IsActive != nil {
|
||||||
residence.IsActive = *req.IsActive
|
residence.IsActive = *req.IsActive
|
||||||
@@ -226,7 +225,7 @@ func (h *AdminResidenceHandler) Update(c echo.Context) error {
|
|||||||
func (h *AdminResidenceHandler) Create(c echo.Context) error {
|
func (h *AdminResidenceHandler) Create(c echo.Context) error {
|
||||||
var req dto.CreateResidenceRequest
|
var req dto.CreateResidenceRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify owner exists
|
// Verify owner exists
|
||||||
@@ -300,7 +299,7 @@ func (h *AdminResidenceHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminResidenceHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminResidenceHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Soft delete - deactivate all
|
// Soft delete - deactivate all
|
||||||
|
|||||||
@@ -47,7 +47,9 @@ func (h *AdminSettingsHandler) GetSettings(c echo.Context) error {
|
|||||||
TrialEnabled: true,
|
TrialEnabled: true,
|
||||||
TrialDurationDays: 14,
|
TrialDurationDays: 14,
|
||||||
}
|
}
|
||||||
h.db.Create(&settings)
|
if err := h.db.Create(&settings).Error; err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create default settings"})
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch settings"})
|
||||||
}
|
}
|
||||||
@@ -73,7 +75,7 @@ type UpdateSettingsRequest struct {
|
|||||||
func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
|
func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
|
||||||
var req UpdateSettingsRequest
|
var req UpdateSettingsRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var settings models.SubscriptionSettings
|
var settings models.SubscriptionSettings
|
||||||
@@ -123,12 +125,12 @@ func (h *AdminSettingsHandler) UpdateSettings(c echo.Context) error {
|
|||||||
func (h *AdminSettingsHandler) SeedLookups(c echo.Context) error {
|
func (h *AdminSettingsHandler) SeedLookups(c echo.Context) error {
|
||||||
// First seed lookup tables
|
// First seed lookup tables
|
||||||
if err := h.runSeedFile("001_lookups.sql"); err != nil {
|
if err := h.runSeedFile("001_lookups.sql"); err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed lookups: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed lookups"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then seed task templates
|
// Then seed task templates
|
||||||
if err := h.runSeedFile("003_task_templates.sql"); err != nil {
|
if err := h.runSeedFile("003_task_templates.sql"); err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache all lookups in Redis
|
// Cache all lookups in Redis
|
||||||
@@ -349,7 +351,7 @@ func parseTags(tags string) []string {
|
|||||||
// SeedTestData handles POST /api/admin/settings/seed-test-data
|
// SeedTestData handles POST /api/admin/settings/seed-test-data
|
||||||
func (h *AdminSettingsHandler) SeedTestData(c echo.Context) error {
|
func (h *AdminSettingsHandler) SeedTestData(c echo.Context) error {
|
||||||
if err := h.runSeedFile("002_test_data.sql"); err != nil {
|
if err := h.runSeedFile("002_test_data.sql"); err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed test data: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed test data"})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Test data seeded successfully"})
|
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Test data seeded successfully"})
|
||||||
@@ -358,7 +360,7 @@ func (h *AdminSettingsHandler) SeedTestData(c echo.Context) error {
|
|||||||
// SeedTaskTemplates handles POST /api/admin/settings/seed-task-templates
|
// SeedTaskTemplates handles POST /api/admin/settings/seed-task-templates
|
||||||
func (h *AdminSettingsHandler) SeedTaskTemplates(c echo.Context) error {
|
func (h *AdminSettingsHandler) SeedTaskTemplates(c echo.Context) error {
|
||||||
if err := h.runSeedFile("003_task_templates.sql"); err != nil {
|
if err := h.runSeedFile("003_task_templates.sql"); err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to seed task templates"})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Task templates seeded successfully"})
|
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Task templates seeded successfully"})
|
||||||
@@ -590,38 +592,38 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
|
|||||||
// 1. Delete task completion images
|
// 1. Delete task completion images
|
||||||
if err := tx.Exec("DELETE FROM task_taskcompletionimage").Error; err != nil {
|
if err := tx.Exec("DELETE FROM task_taskcompletionimage").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completion images: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completion images"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Delete task completions
|
// 2. Delete task completions
|
||||||
if err := tx.Exec("DELETE FROM task_taskcompletion").Error; err != nil {
|
if err := tx.Exec("DELETE FROM task_taskcompletion").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completions: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task completions"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Delete notifications (must be before tasks since notifications have task_id FK)
|
// 3. Delete notifications (must be before tasks since notifications have task_id FK)
|
||||||
if len(preservedUserIDs) > 0 {
|
if len(preservedUserIDs) > 0 {
|
||||||
if err := tx.Exec("DELETE FROM notifications_notification WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
if err := tx.Exec("DELETE FROM notifications_notification WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications"})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := tx.Exec("DELETE FROM notifications_notification").Error; err != nil {
|
if err := tx.Exec("DELETE FROM notifications_notification").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notifications"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Delete document images
|
// 4. Delete document images
|
||||||
if err := tx.Exec("DELETE FROM task_documentimage").Error; err != nil {
|
if err := tx.Exec("DELETE FROM task_documentimage").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete document images: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete document images"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Delete documents
|
// 5. Delete documents
|
||||||
if err := tx.Exec("DELETE FROM task_document").Error; err != nil {
|
if err := tx.Exec("DELETE FROM task_document").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete documents: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete documents"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. Delete task reminder logs (must be before tasks since reminder logs have task_id FK)
|
// 6. Delete task reminder logs (must be before tasks since reminder logs have task_id FK)
|
||||||
@@ -631,64 +633,64 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
|
|||||||
if tableExists {
|
if tableExists {
|
||||||
if err := tx.Exec("DELETE FROM task_reminderlog").Error; err != nil {
|
if err := tx.Exec("DELETE FROM task_reminderlog").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task reminder logs: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete task reminder logs"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. Delete tasks (must be before contractors since tasks reference contractors)
|
// 7. Delete tasks (must be before contractors since tasks reference contractors)
|
||||||
if err := tx.Exec("DELETE FROM task_task").Error; err != nil {
|
if err := tx.Exec("DELETE FROM task_task").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete tasks: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete tasks"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 8. Delete contractor specialties (many-to-many)
|
// 8. Delete contractor specialties (many-to-many)
|
||||||
if err := tx.Exec("DELETE FROM task_contractor_specialties").Error; err != nil {
|
if err := tx.Exec("DELETE FROM task_contractor_specialties").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractor specialties: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractor specialties"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 9. Delete contractors
|
// 9. Delete contractors
|
||||||
if err := tx.Exec("DELETE FROM task_contractor").Error; err != nil {
|
if err := tx.Exec("DELETE FROM task_contractor").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractors: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete contractors"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 10. Delete residence_users (many-to-many for shared residences)
|
// 10. Delete residence_users (many-to-many for shared residences)
|
||||||
if err := tx.Exec("DELETE FROM residence_residence_users").Error; err != nil {
|
if err := tx.Exec("DELETE FROM residence_residence_users").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence users: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence users"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 11. Delete residence share codes (must be before residences since share codes have residence_id FK)
|
// 11. Delete residence share codes (must be before residences since share codes have residence_id FK)
|
||||||
if err := tx.Exec("DELETE FROM residence_residencesharecode").Error; err != nil {
|
if err := tx.Exec("DELETE FROM residence_residencesharecode").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence share codes: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residence share codes"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 12. Delete residences
|
// 12. Delete residences
|
||||||
if err := tx.Exec("DELETE FROM residence_residence").Error; err != nil {
|
if err := tx.Exec("DELETE FROM residence_residence").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residences: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete residences"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 13. Delete push devices for non-superusers (both APNS and GCM)
|
// 13. Delete push devices for non-superusers (both APNS and GCM)
|
||||||
if len(preservedUserIDs) > 0 {
|
if len(preservedUserIDs) > 0 {
|
||||||
if err := tx.Exec("DELETE FROM push_notifications_apnsdevice WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
if err := tx.Exec("DELETE FROM push_notifications_apnsdevice WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices"})
|
||||||
}
|
}
|
||||||
if err := tx.Exec("DELETE FROM push_notifications_gcmdevice WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
if err := tx.Exec("DELETE FROM push_notifications_gcmdevice WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices"})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := tx.Exec("DELETE FROM push_notifications_apnsdevice").Error; err != nil {
|
if err := tx.Exec("DELETE FROM push_notifications_apnsdevice").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete APNS devices"})
|
||||||
}
|
}
|
||||||
if err := tx.Exec("DELETE FROM push_notifications_gcmdevice").Error; err != nil {
|
if err := tx.Exec("DELETE FROM push_notifications_gcmdevice").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete GCM devices"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -696,12 +698,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
|
|||||||
if len(preservedUserIDs) > 0 {
|
if len(preservedUserIDs) > 0 {
|
||||||
if err := tx.Exec("DELETE FROM notifications_notificationpreference WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
if err := tx.Exec("DELETE FROM notifications_notificationpreference WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences"})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := tx.Exec("DELETE FROM notifications_notificationpreference").Error; err != nil {
|
if err := tx.Exec("DELETE FROM notifications_notificationpreference").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete notification preferences"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -709,12 +711,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
|
|||||||
if len(preservedUserIDs) > 0 {
|
if len(preservedUserIDs) > 0 {
|
||||||
if err := tx.Exec("DELETE FROM subscription_usersubscription WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
if err := tx.Exec("DELETE FROM subscription_usersubscription WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions"})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := tx.Exec("DELETE FROM subscription_usersubscription").Error; err != nil {
|
if err := tx.Exec("DELETE FROM subscription_usersubscription").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete subscriptions"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -722,12 +724,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
|
|||||||
if len(preservedUserIDs) > 0 {
|
if len(preservedUserIDs) > 0 {
|
||||||
if err := tx.Exec("DELETE FROM user_passwordresetcode WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
if err := tx.Exec("DELETE FROM user_passwordresetcode WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes"})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := tx.Exec("DELETE FROM user_passwordresetcode").Error; err != nil {
|
if err := tx.Exec("DELETE FROM user_passwordresetcode").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -735,12 +737,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
|
|||||||
if len(preservedUserIDs) > 0 {
|
if len(preservedUserIDs) > 0 {
|
||||||
if err := tx.Exec("DELETE FROM user_confirmationcode WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
if err := tx.Exec("DELETE FROM user_confirmationcode WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes"})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := tx.Exec("DELETE FROM user_confirmationcode").Error; err != nil {
|
if err := tx.Exec("DELETE FROM user_confirmationcode").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -748,12 +750,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
|
|||||||
if len(preservedUserIDs) > 0 {
|
if len(preservedUserIDs) > 0 {
|
||||||
if err := tx.Exec("DELETE FROM user_authtoken WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
if err := tx.Exec("DELETE FROM user_authtoken WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens"})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := tx.Exec("DELETE FROM user_authtoken").Error; err != nil {
|
if err := tx.Exec("DELETE FROM user_authtoken").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete auth tokens"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -761,12 +763,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
|
|||||||
if len(preservedUserIDs) > 0 {
|
if len(preservedUserIDs) > 0 {
|
||||||
if err := tx.Exec("DELETE FROM user_applesocialauth WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
if err := tx.Exec("DELETE FROM user_applesocialauth WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth"})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := tx.Exec("DELETE FROM user_applesocialauth").Error; err != nil {
|
if err := tx.Exec("DELETE FROM user_applesocialauth").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -774,12 +776,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
|
|||||||
if len(preservedUserIDs) > 0 {
|
if len(preservedUserIDs) > 0 {
|
||||||
if err := tx.Exec("DELETE FROM user_userprofile WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
if err := tx.Exec("DELETE FROM user_userprofile WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles"})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := tx.Exec("DELETE FROM user_userprofile").Error; err != nil {
|
if err := tx.Exec("DELETE FROM user_userprofile").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -790,12 +792,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
|
|||||||
if len(preservedUserIDs) > 0 {
|
if len(preservedUserIDs) > 0 {
|
||||||
if err := tx.Exec("DELETE FROM onboarding_emails WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
if err := tx.Exec("DELETE FROM onboarding_emails WHERE user_id NOT IN (?)", preservedUserIDs).Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails"})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := tx.Exec("DELETE FROM onboarding_emails").Error; err != nil {
|
if err := tx.Exec("DELETE FROM onboarding_emails").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete onboarding emails"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -804,12 +806,12 @@ func (h *AdminSettingsHandler) ClearAllData(c echo.Context) error {
|
|||||||
// Always filter by is_superuser to be safe, regardless of preservedUserIDs
|
// Always filter by is_superuser to be safe, regardless of preservedUserIDs
|
||||||
if err := tx.Exec("DELETE FROM auth_user WHERE is_superuser = false").Error; err != nil {
|
if err := tx.Exec("DELETE FROM auth_user WHERE is_superuser = false").Error; err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete users: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete users"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit the transaction
|
// Commit the transaction
|
||||||
if err := tx.Commit().Error; err != nil {
|
if err := tx.Commit().Error; err != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to commit transaction: " + err.Error()})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to commit transaction"})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, ClearAllDataResponse{
|
return c.JSON(http.StatusOK, ClearAllDataResponse{
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ type ShareCodeResponse struct {
|
|||||||
func (h *AdminShareCodeHandler) List(c echo.Context) error {
|
func (h *AdminShareCodeHandler) List(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var codes []models.ResidenceShareCode
|
var codes []models.ResidenceShareCode
|
||||||
@@ -156,7 +156,7 @@ func (h *AdminShareCodeHandler) Update(c echo.Context) error {
|
|||||||
IsActive *bool `json:"is_active"`
|
IsActive *bool `json:"is_active"`
|
||||||
}
|
}
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only update IsActive when explicitly provided (non-nil).
|
// Only update IsActive when explicitly provided (non-nil).
|
||||||
@@ -216,7 +216,7 @@ func (h *AdminShareCodeHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminShareCodeHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminShareCodeHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
result := h.db.Where("id IN ?", req.IDs).Delete(&models.ResidenceShareCode{})
|
result := h.db.Where("id IN ?", req.IDs).Delete(&models.ResidenceShareCode{})
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func NewAdminSubscriptionHandler(db *gorm.DB) *AdminSubscriptionHandler {
|
|||||||
func (h *AdminSubscriptionHandler) List(c echo.Context) error {
|
func (h *AdminSubscriptionHandler) List(c echo.Context) error {
|
||||||
var filters dto.SubscriptionFilters
|
var filters dto.SubscriptionFilters
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var subscriptions []models.UserSubscription
|
var subscriptions []models.UserSubscription
|
||||||
@@ -38,8 +38,8 @@ func (h *AdminSubscriptionHandler) List(c echo.Context) error {
|
|||||||
// Apply search (search by user email)
|
// Apply search (search by user email)
|
||||||
if filters.Search != "" {
|
if filters.Search != "" {
|
||||||
search := "%" + filters.Search + "%"
|
search := "%" + filters.Search + "%"
|
||||||
query = query.Joins("JOIN users ON users.id = subscription_usersubscription.user_id").
|
query = query.Joins("JOIN auth_user ON auth_user.id = subscription_usersubscription.user_id").
|
||||||
Where("users.email ILIKE ? OR users.username ILIKE ?", search, search)
|
Where("auth_user.email ILIKE ? OR auth_user.username ILIKE ?", search, search)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply filters
|
// Apply filters
|
||||||
@@ -140,7 +140,7 @@ func (h *AdminSubscriptionHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req dto.UpdateSubscriptionRequest
|
var req dto.UpdateSubscriptionRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Tier != nil {
|
if req.Tier != nil {
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/shopspring/decimal"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
||||||
@@ -27,7 +26,7 @@ func NewAdminTaskHandler(db *gorm.DB) *AdminTaskHandler {
|
|||||||
func (h *AdminTaskHandler) List(c echo.Context) error {
|
func (h *AdminTaskHandler) List(c echo.Context) error {
|
||||||
var filters dto.TaskFilters
|
var filters dto.TaskFilters
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var tasks []models.Task
|
var tasks []models.Task
|
||||||
@@ -149,7 +148,7 @@ func (h *AdminTaskHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req dto.UpdateTaskRequest
|
var req dto.UpdateTaskRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify residence if changing
|
// Verify residence if changing
|
||||||
@@ -216,10 +215,10 @@ func (h *AdminTaskHandler) Update(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if req.EstimatedCost != nil {
|
if req.EstimatedCost != nil {
|
||||||
updates["estimated_cost"] = decimal.NewFromFloat(*req.EstimatedCost)
|
updates["estimated_cost"] = *req.EstimatedCost
|
||||||
}
|
}
|
||||||
if req.ActualCost != nil {
|
if req.ActualCost != nil {
|
||||||
updates["actual_cost"] = decimal.NewFromFloat(*req.ActualCost)
|
updates["actual_cost"] = *req.ActualCost
|
||||||
}
|
}
|
||||||
if req.ContractorID != nil {
|
if req.ContractorID != nil {
|
||||||
updates["contractor_id"] = *req.ContractorID
|
updates["contractor_id"] = *req.ContractorID
|
||||||
@@ -248,7 +247,7 @@ func (h *AdminTaskHandler) Update(c echo.Context) error {
|
|||||||
func (h *AdminTaskHandler) Create(c echo.Context) error {
|
func (h *AdminTaskHandler) Create(c echo.Context) error {
|
||||||
var req dto.CreateTaskRequest
|
var req dto.CreateTaskRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify residence exists
|
// Verify residence exists
|
||||||
@@ -285,8 +284,7 @@ func (h *AdminTaskHandler) Create(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if req.EstimatedCost != nil {
|
if req.EstimatedCost != nil {
|
||||||
d := decimal.NewFromFloat(*req.EstimatedCost)
|
task.EstimatedCost = req.EstimatedCost
|
||||||
task.EstimatedCost = &d
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Create(&task).Error; err != nil {
|
if err := h.db.Create(&task).Error; err != nil {
|
||||||
@@ -326,7 +324,7 @@ func (h *AdminTaskHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminTaskHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminTaskHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Soft delete - archive and cancel all
|
// Soft delete - archive and cancel all
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ func NewAdminTaskTemplateHandler(db *gorm.DB) *AdminTaskTemplateHandler {
|
|||||||
func (h *AdminTaskTemplateHandler) refreshTaskTemplatesCache(ctx context.Context) {
|
func (h *AdminTaskTemplateHandler) refreshTaskTemplatesCache(ctx context.Context) {
|
||||||
cache := services.GetCache()
|
cache := services.GetCache()
|
||||||
if cache == nil {
|
if cache == nil {
|
||||||
|
log.Warn().Msg("Cache service unavailable, skipping task templates cache refresh")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var templates []models.TaskTemplate
|
var templates []models.TaskTemplate
|
||||||
@@ -68,12 +70,12 @@ type TaskTemplateResponse struct {
|
|||||||
|
|
||||||
// CreateUpdateTaskTemplateRequest represents the request body for creating/updating templates
|
// CreateUpdateTaskTemplateRequest represents the request body for creating/updating templates
|
||||||
type CreateUpdateTaskTemplateRequest struct {
|
type CreateUpdateTaskTemplateRequest struct {
|
||||||
Title string `json:"title" binding:"required,max=200"`
|
Title string `json:"title" validate:"required,max=200"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
CategoryID *uint `json:"category_id"`
|
CategoryID *uint `json:"category_id"`
|
||||||
FrequencyID *uint `json:"frequency_id"`
|
FrequencyID *uint `json:"frequency_id"`
|
||||||
IconIOS string `json:"icon_ios" binding:"max=100"`
|
IconIOS string `json:"icon_ios" validate:"max=100"`
|
||||||
IconAndroid string `json:"icon_android" binding:"max=100"`
|
IconAndroid string `json:"icon_android" validate:"max=100"`
|
||||||
Tags string `json:"tags"`
|
Tags string `json:"tags"`
|
||||||
DisplayOrder *int `json:"display_order"`
|
DisplayOrder *int `json:"display_order"`
|
||||||
IsActive *bool `json:"is_active"`
|
IsActive *bool `json:"is_active"`
|
||||||
@@ -140,7 +142,10 @@ func (h *AdminTaskTemplateHandler) GetTemplate(c echo.Context) error {
|
|||||||
func (h *AdminTaskTemplateHandler) CreateTemplate(c echo.Context) error {
|
func (h *AdminTaskTemplateHandler) CreateTemplate(c echo.Context) error {
|
||||||
var req CreateUpdateTaskTemplateRequest
|
var req CreateUpdateTaskTemplateRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
template := models.TaskTemplate{
|
template := models.TaskTemplate{
|
||||||
@@ -191,7 +196,10 @@ func (h *AdminTaskTemplateHandler) UpdateTemplate(c echo.Context) error {
|
|||||||
|
|
||||||
var req CreateUpdateTaskTemplateRequest
|
var req CreateUpdateTaskTemplateRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
template.Title = req.Title
|
template.Title = req.Title
|
||||||
@@ -271,10 +279,13 @@ func (h *AdminTaskTemplateHandler) ToggleActive(c echo.Context) error {
|
|||||||
// BulkCreate handles POST /admin/api/task-templates/bulk/
|
// BulkCreate handles POST /admin/api/task-templates/bulk/
|
||||||
func (h *AdminTaskTemplateHandler) BulkCreate(c echo.Context) error {
|
func (h *AdminTaskTemplateHandler) BulkCreate(c echo.Context) error {
|
||||||
var req struct {
|
var req struct {
|
||||||
Templates []CreateUpdateTaskTemplateRequest `json:"templates" binding:"required,dive"`
|
Templates []CreateUpdateTaskTemplateRequest `json:"templates" validate:"required,dive"`
|
||||||
}
|
}
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
|
}
|
||||||
|
if err := c.Validate(&req); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
templates := make([]models.TaskTemplate, len(req.Templates))
|
templates := make([]models.TaskTemplate, len(req.Templates))
|
||||||
|
|||||||
@@ -3,14 +3,25 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
||||||
"github.com/treytartt/honeydue-api/internal/models"
|
"github.com/treytartt/honeydue-api/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// escapeLikeWildcards escapes SQL LIKE wildcards (%, _) in user input
|
||||||
|
// to prevent wildcard injection in LIKE queries.
|
||||||
|
func escapeLikeWildcards(s string) string {
|
||||||
|
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||||
|
s = strings.ReplaceAll(s, "%", "\\%")
|
||||||
|
s = strings.ReplaceAll(s, "_", "\\_")
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// AdminUserHandler handles admin user management endpoints
|
// AdminUserHandler handles admin user management endpoints
|
||||||
type AdminUserHandler struct {
|
type AdminUserHandler struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
@@ -25,7 +36,7 @@ func NewAdminUserHandler(db *gorm.DB) *AdminUserHandler {
|
|||||||
func (h *AdminUserHandler) List(c echo.Context) error {
|
func (h *AdminUserHandler) List(c echo.Context) error {
|
||||||
var filters dto.UserFilters
|
var filters dto.UserFilters
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var users []models.User
|
var users []models.User
|
||||||
@@ -35,7 +46,7 @@ func (h *AdminUserHandler) List(c echo.Context) error {
|
|||||||
|
|
||||||
// Apply search
|
// Apply search
|
||||||
if filters.Search != "" {
|
if filters.Search != "" {
|
||||||
search := "%" + filters.Search + "%"
|
search := "%" + escapeLikeWildcards(filters.Search) + "%"
|
||||||
query = query.Where(
|
query = query.Where(
|
||||||
"username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?",
|
"username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?",
|
||||||
search, search, search, search,
|
search, search, search, search,
|
||||||
@@ -70,10 +81,49 @@ func (h *AdminUserHandler) List(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch users"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch users"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Batch COUNT queries for residence and task counts instead of N+1
|
||||||
|
userIDs := make([]uint, len(users))
|
||||||
|
for i, user := range users {
|
||||||
|
userIDs[i] = user.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
type countResult struct {
|
||||||
|
OwnerID uint
|
||||||
|
Count int64
|
||||||
|
}
|
||||||
|
|
||||||
|
residenceCounts := make(map[uint]int64)
|
||||||
|
taskCounts := make(map[uint]int64)
|
||||||
|
|
||||||
|
if len(userIDs) > 0 {
|
||||||
|
var resCounts []countResult
|
||||||
|
h.db.Model(&models.Residence{}).
|
||||||
|
Select("owner_id, COUNT(*) as count").
|
||||||
|
Where("owner_id IN ?", userIDs).
|
||||||
|
Group("owner_id").
|
||||||
|
Scan(&resCounts)
|
||||||
|
for _, rc := range resCounts {
|
||||||
|
residenceCounts[rc.OwnerID] = rc.Count
|
||||||
|
}
|
||||||
|
|
||||||
|
var tskCounts []struct {
|
||||||
|
CreatedByID uint
|
||||||
|
Count int64
|
||||||
|
}
|
||||||
|
h.db.Model(&models.Task{}).
|
||||||
|
Select("created_by_id, COUNT(*) as count").
|
||||||
|
Where("created_by_id IN ?", userIDs).
|
||||||
|
Group("created_by_id").
|
||||||
|
Scan(&tskCounts)
|
||||||
|
for _, tc := range tskCounts {
|
||||||
|
taskCounts[tc.CreatedByID] = tc.Count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Build response
|
// Build response
|
||||||
responses := make([]dto.UserResponse, len(users))
|
responses := make([]dto.UserResponse, len(users))
|
||||||
for i, user := range users {
|
for i, user := range users {
|
||||||
responses[i] = h.toUserResponse(&user)
|
responses[i] = h.toUserResponseWithCounts(&user, int(residenceCounts[user.ID]), int(taskCounts[user.ID]))
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
|
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
|
||||||
@@ -122,7 +172,7 @@ func (h *AdminUserHandler) Get(c echo.Context) error {
|
|||||||
func (h *AdminUserHandler) Create(c echo.Context) error {
|
func (h *AdminUserHandler) Create(c echo.Context) error {
|
||||||
var req dto.CreateUserRequest
|
var req dto.CreateUserRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if username exists
|
// Check if username exists
|
||||||
@@ -170,7 +220,9 @@ func (h *AdminUserHandler) Create(c echo.Context) error {
|
|||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
PhoneNumber: req.PhoneNumber,
|
PhoneNumber: req.PhoneNumber,
|
||||||
}
|
}
|
||||||
h.db.Create(&profile)
|
if err := h.db.Create(&profile).Error; err != nil {
|
||||||
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create user profile")
|
||||||
|
}
|
||||||
|
|
||||||
// Reload with profile
|
// Reload with profile
|
||||||
h.db.Preload("Profile").First(&user, user.ID)
|
h.db.Preload("Profile").First(&user, user.ID)
|
||||||
@@ -194,7 +246,7 @@ func (h *AdminUserHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req dto.UpdateUserRequest
|
var req dto.UpdateUserRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check username uniqueness if changing
|
// Check username uniqueness if changing
|
||||||
@@ -298,7 +350,7 @@ func (h *AdminUserHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminUserHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminUserHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Soft delete - deactivate all
|
// Soft delete - deactivate all
|
||||||
@@ -309,8 +361,30 @@ func (h *AdminUserHandler) BulkDelete(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Users deactivated successfully", "count": len(req.IDs)})
|
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Users deactivated successfully", "count": len(req.IDs)})
|
||||||
}
|
}
|
||||||
|
|
||||||
// toUserResponse converts a User model to UserResponse DTO
|
// toUserResponseWithCounts converts a User model to UserResponse DTO with pre-computed counts
|
||||||
|
func (h *AdminUserHandler) toUserResponseWithCounts(user *models.User, residenceCount, taskCount int) dto.UserResponse {
|
||||||
|
response := h.buildUserResponse(user)
|
||||||
|
response.ResidenceCount = residenceCount
|
||||||
|
response.TaskCount = taskCount
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
|
||||||
|
// toUserResponse converts a User model to UserResponse DTO (fetches counts individually)
|
||||||
func (h *AdminUserHandler) toUserResponse(user *models.User) dto.UserResponse {
|
func (h *AdminUserHandler) toUserResponse(user *models.User) dto.UserResponse {
|
||||||
|
response := h.buildUserResponse(user)
|
||||||
|
|
||||||
|
// Get counts individually (used for single-user views)
|
||||||
|
var residenceCount, taskCount int64
|
||||||
|
h.db.Model(&models.Residence{}).Where("owner_id = ?", user.ID).Count(&residenceCount)
|
||||||
|
h.db.Model(&models.Task{}).Where("created_by_id = ?", user.ID).Count(&taskCount)
|
||||||
|
response.ResidenceCount = int(residenceCount)
|
||||||
|
response.TaskCount = int(taskCount)
|
||||||
|
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildUserResponse creates the base UserResponse without counts
|
||||||
|
func (h *AdminUserHandler) buildUserResponse(user *models.User) dto.UserResponse {
|
||||||
response := dto.UserResponse{
|
response := dto.UserResponse{
|
||||||
ID: user.ID,
|
ID: user.ID,
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
@@ -335,12 +409,5 @@ func (h *AdminUserHandler) toUserResponse(user *models.User) dto.UserResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get counts
|
|
||||||
var residenceCount, taskCount int64
|
|
||||||
h.db.Model(&models.Residence{}).Where("owner_id = ?", user.ID).Count(&residenceCount)
|
|
||||||
h.db.Model(&models.Task{}).Where("created_by_id = ?", user.ID).Count(&taskCount)
|
|
||||||
response.ResidenceCount = int(residenceCount)
|
|
||||||
response.TaskCount = int(taskCount)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ type UpdateUserProfileRequest struct {
|
|||||||
func (h *AdminUserProfileHandler) List(c echo.Context) error {
|
func (h *AdminUserProfileHandler) List(c echo.Context) error {
|
||||||
var filters dto.PaginationParams
|
var filters dto.PaginationParams
|
||||||
if err := c.Bind(&filters); err != nil {
|
if err := c.Bind(&filters); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
var profiles []models.UserProfile
|
var profiles []models.UserProfile
|
||||||
@@ -144,7 +144,7 @@ func (h *AdminUserProfileHandler) Update(c echo.Context) error {
|
|||||||
|
|
||||||
var req UpdateUserProfileRequest
|
var req UpdateUserProfileRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Verified != nil {
|
if req.Verified != nil {
|
||||||
@@ -205,14 +205,15 @@ func (h *AdminUserProfileHandler) Delete(c echo.Context) error {
|
|||||||
func (h *AdminUserProfileHandler) BulkDelete(c echo.Context) error {
|
func (h *AdminUserProfileHandler) BulkDelete(c echo.Context) error {
|
||||||
var req dto.BulkDeleteRequest
|
var req dto.BulkDeleteRequest
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": err.Error()})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.Where("id IN ?", req.IDs).Delete(&models.UserProfile{}).Error; err != nil {
|
result := h.db.Where("id IN ?", req.IDs).Delete(&models.UserProfile{})
|
||||||
|
if result.Error != nil {
|
||||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles"})
|
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete user profiles"})
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "User profiles deleted successfully", "count": len(req.IDs)})
|
return c.JSON(http.StatusOK, map[string]interface{}{"message": "User profiles deleted successfully", "count": result.RowsAffected})
|
||||||
}
|
}
|
||||||
|
|
||||||
// toProfileResponse converts a UserProfile model to UserProfileResponse
|
// toProfileResponse converts a UserProfile model to UserProfileResponse
|
||||||
|
|||||||
@@ -379,9 +379,10 @@ func SetupRoutes(router *echo.Echo, db *gorm.DB, cfg *config.Config, deps *Depen
|
|||||||
documentImages.DELETE("/:id", documentImageHandler.Delete)
|
documentImages.DELETE("/:id", documentImageHandler.Delete)
|
||||||
}
|
}
|
||||||
|
|
||||||
// System settings management
|
// System settings management (super admin only)
|
||||||
settingsHandler := handlers.NewAdminSettingsHandler(db)
|
settingsHandler := handlers.NewAdminSettingsHandler(db)
|
||||||
settings := protected.Group("/settings")
|
settings := protected.Group("/settings")
|
||||||
|
settings.Use(middleware.RequireSuperAdmin())
|
||||||
{
|
{
|
||||||
settings.GET("", settingsHandler.GetSettings)
|
settings.GET("", settingsHandler.GetSettings)
|
||||||
settings.PUT("", settingsHandler.UpdateSettings)
|
settings.PUT("", settingsHandler.UpdateSettings)
|
||||||
@@ -500,9 +501,21 @@ func setupAdminProxy(router *echo.Echo) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proxy Next.js static assets (served from /_next/ regardless of host)
|
// Proxy Next.js static assets — only when the request comes from a
|
||||||
|
// known host to prevent SSRF via crafted Host headers.
|
||||||
|
var allowedProxyHosts []string
|
||||||
|
if adminHost != "" {
|
||||||
|
allowedProxyHosts = append(allowedProxyHosts, adminHost)
|
||||||
|
}
|
||||||
|
// Also allow localhost variants for development
|
||||||
|
allowedProxyHosts = append(allowedProxyHosts,
|
||||||
|
"localhost:3001", "127.0.0.1:3001",
|
||||||
|
"localhost:8000", "127.0.0.1:8000",
|
||||||
|
)
|
||||||
|
hostCheck := middleware.HostCheck(allowedProxyHosts)
|
||||||
|
|
||||||
router.Any("/_next/*", func(c echo.Context) error {
|
router.Any("/_next/*", func(c echo.Context) error {
|
||||||
proxy.ServeHTTP(c.Response(), c.Request())
|
proxy.ServeHTTP(c.Response(), c.Request())
|
||||||
return nil
|
return nil
|
||||||
})
|
}, hostCheck)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
@@ -32,6 +33,7 @@ type Config struct {
|
|||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Port int
|
Port int
|
||||||
Debug bool
|
Debug bool
|
||||||
|
DebugFixedCodes bool // Separate from Debug: enables fixed confirmation codes for local testing
|
||||||
AllowedHosts []string
|
AllowedHosts []string
|
||||||
CorsAllowedOrigins []string // Comma-separated origins for CORS (production only; debug uses wildcard)
|
CorsAllowedOrigins []string // Comma-separated origins for CORS (production only; debug uses wildcard)
|
||||||
Timezone string
|
Timezone string
|
||||||
@@ -75,7 +77,12 @@ type PushConfig struct {
|
|||||||
APNSSandbox bool
|
APNSSandbox bool
|
||||||
APNSProduction bool // If true, use production APNs; if false, use sandbox
|
APNSProduction bool // If true, use production APNs; if false, use sandbox
|
||||||
|
|
||||||
// FCM (Android) - uses direct HTTP to FCM legacy API
|
// FCM (Android) - uses FCM HTTP v1 API with OAuth 2.0
|
||||||
|
FCMProjectID string // Firebase project ID (required for v1 API)
|
||||||
|
FCMServiceAccountPath string // Path to Google service account JSON file
|
||||||
|
FCMServiceAccountJSON string // Raw service account JSON (alternative to path, e.g. for env var injection)
|
||||||
|
|
||||||
|
// Deprecated: FCMServerKey is for the legacy HTTP API. Use FCMProjectID + service account instead.
|
||||||
FCMServerKey string
|
FCMServerKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,135 +154,166 @@ type FeatureFlags struct {
|
|||||||
WorkerEnabled bool // FEATURE_WORKER_ENABLED (default: true)
|
WorkerEnabled bool // FEATURE_WORKER_ENABLED (default: true)
|
||||||
}
|
}
|
||||||
|
|
||||||
var cfg *Config
|
var (
|
||||||
|
cfg *Config
|
||||||
|
cfgOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
// knownWeakSecretKeys contains well-known default or placeholder secret keys
|
||||||
|
// that must not be used in production.
|
||||||
|
var knownWeakSecretKeys = map[string]bool{
|
||||||
|
"secret": true,
|
||||||
|
"changeme": true,
|
||||||
|
"change-me": true,
|
||||||
|
"password": true,
|
||||||
|
"change-me-in-production-secret-key-12345": true,
|
||||||
|
}
|
||||||
|
|
||||||
// Load reads configuration from environment variables
|
// Load reads configuration from environment variables
|
||||||
func Load() (*Config, error) {
|
func Load() (*Config, error) {
|
||||||
viper.SetEnvPrefix("")
|
var loadErr error
|
||||||
viper.AutomaticEnv()
|
|
||||||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
|
||||||
|
|
||||||
// Set defaults
|
cfgOnce.Do(func() {
|
||||||
setDefaults()
|
viper.SetEnvPrefix("")
|
||||||
|
viper.AutomaticEnv()
|
||||||
|
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||||
|
|
||||||
// Parse DATABASE_URL if set (Dokku-style)
|
// Set defaults
|
||||||
dbConfig := DatabaseConfig{
|
setDefaults()
|
||||||
Host: viper.GetString("DB_HOST"),
|
|
||||||
Port: viper.GetInt("DB_PORT"),
|
|
||||||
User: viper.GetString("POSTGRES_USER"),
|
|
||||||
Password: viper.GetString("POSTGRES_PASSWORD"),
|
|
||||||
Database: viper.GetString("POSTGRES_DB"),
|
|
||||||
SSLMode: viper.GetString("DB_SSLMODE"),
|
|
||||||
MaxOpenConns: viper.GetInt("DB_MAX_OPEN_CONNS"),
|
|
||||||
MaxIdleConns: viper.GetInt("DB_MAX_IDLE_CONNS"),
|
|
||||||
MaxLifetime: viper.GetDuration("DB_MAX_LIFETIME"),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Override with DATABASE_URL if present
|
// Parse DATABASE_URL if set (Dokku-style)
|
||||||
if databaseURL := viper.GetString("DATABASE_URL"); databaseURL != "" {
|
dbConfig := DatabaseConfig{
|
||||||
parsed, err := parseDatabaseURL(databaseURL)
|
Host: viper.GetString("DB_HOST"),
|
||||||
if err == nil {
|
Port: viper.GetInt("DB_PORT"),
|
||||||
dbConfig.Host = parsed.Host
|
User: viper.GetString("POSTGRES_USER"),
|
||||||
dbConfig.Port = parsed.Port
|
Password: viper.GetString("POSTGRES_PASSWORD"),
|
||||||
dbConfig.User = parsed.User
|
Database: viper.GetString("POSTGRES_DB"),
|
||||||
dbConfig.Password = parsed.Password
|
SSLMode: viper.GetString("DB_SSLMODE"),
|
||||||
dbConfig.Database = parsed.Database
|
MaxOpenConns: viper.GetInt("DB_MAX_OPEN_CONNS"),
|
||||||
if parsed.SSLMode != "" {
|
MaxIdleConns: viper.GetInt("DB_MAX_IDLE_CONNS"),
|
||||||
dbConfig.SSLMode = parsed.SSLMode
|
MaxLifetime: viper.GetDuration("DB_MAX_LIFETIME"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override with DATABASE_URL if present (F-16: log warning on parse failure)
|
||||||
|
if databaseURL := viper.GetString("DATABASE_URL"); databaseURL != "" {
|
||||||
|
parsed, err := parseDatabaseURL(databaseURL)
|
||||||
|
if err != nil {
|
||||||
|
maskedURL := MaskURLCredentials(databaseURL)
|
||||||
|
fmt.Printf("WARNING: Failed to parse DATABASE_URL (%s): %v — falling back to individual DB_* env vars\n", maskedURL, err)
|
||||||
|
} else {
|
||||||
|
dbConfig.Host = parsed.Host
|
||||||
|
dbConfig.Port = parsed.Port
|
||||||
|
dbConfig.User = parsed.User
|
||||||
|
dbConfig.Password = parsed.Password
|
||||||
|
dbConfig.Database = parsed.Database
|
||||||
|
if parsed.SSLMode != "" {
|
||||||
|
dbConfig.SSLMode = parsed.SSLMode
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
cfg = &Config{
|
cfg = &Config{
|
||||||
Server: ServerConfig{
|
Server: ServerConfig{
|
||||||
Port: viper.GetInt("PORT"),
|
Port: viper.GetInt("PORT"),
|
||||||
Debug: viper.GetBool("DEBUG"),
|
Debug: viper.GetBool("DEBUG"),
|
||||||
AllowedHosts: strings.Split(viper.GetString("ALLOWED_HOSTS"), ","),
|
DebugFixedCodes: viper.GetBool("DEBUG_FIXED_CODES"),
|
||||||
CorsAllowedOrigins: parseCorsOrigins(viper.GetString("CORS_ALLOWED_ORIGINS")),
|
AllowedHosts: strings.Split(viper.GetString("ALLOWED_HOSTS"), ","),
|
||||||
Timezone: viper.GetString("TIMEZONE"),
|
CorsAllowedOrigins: parseCorsOrigins(viper.GetString("CORS_ALLOWED_ORIGINS")),
|
||||||
StaticDir: viper.GetString("STATIC_DIR"),
|
Timezone: viper.GetString("TIMEZONE"),
|
||||||
BaseURL: viper.GetString("BASE_URL"),
|
StaticDir: viper.GetString("STATIC_DIR"),
|
||||||
},
|
BaseURL: viper.GetString("BASE_URL"),
|
||||||
Database: dbConfig,
|
},
|
||||||
Redis: RedisConfig{
|
Database: dbConfig,
|
||||||
URL: viper.GetString("REDIS_URL"),
|
Redis: RedisConfig{
|
||||||
Password: viper.GetString("REDIS_PASSWORD"),
|
URL: viper.GetString("REDIS_URL"),
|
||||||
DB: viper.GetInt("REDIS_DB"),
|
Password: viper.GetString("REDIS_PASSWORD"),
|
||||||
},
|
DB: viper.GetInt("REDIS_DB"),
|
||||||
Email: EmailConfig{
|
},
|
||||||
Host: viper.GetString("EMAIL_HOST"),
|
Email: EmailConfig{
|
||||||
Port: viper.GetInt("EMAIL_PORT"),
|
Host: viper.GetString("EMAIL_HOST"),
|
||||||
User: viper.GetString("EMAIL_HOST_USER"),
|
Port: viper.GetInt("EMAIL_PORT"),
|
||||||
Password: viper.GetString("EMAIL_HOST_PASSWORD"),
|
User: viper.GetString("EMAIL_HOST_USER"),
|
||||||
From: viper.GetString("DEFAULT_FROM_EMAIL"),
|
Password: viper.GetString("EMAIL_HOST_PASSWORD"),
|
||||||
UseTLS: viper.GetBool("EMAIL_USE_TLS"),
|
From: viper.GetString("DEFAULT_FROM_EMAIL"),
|
||||||
},
|
UseTLS: viper.GetBool("EMAIL_USE_TLS"),
|
||||||
Push: PushConfig{
|
},
|
||||||
APNSKeyPath: viper.GetString("APNS_AUTH_KEY_PATH"),
|
Push: PushConfig{
|
||||||
APNSKeyID: viper.GetString("APNS_AUTH_KEY_ID"),
|
APNSKeyPath: viper.GetString("APNS_AUTH_KEY_PATH"),
|
||||||
APNSTeamID: viper.GetString("APNS_TEAM_ID"),
|
APNSKeyID: viper.GetString("APNS_AUTH_KEY_ID"),
|
||||||
APNSTopic: viper.GetString("APNS_TOPIC"),
|
APNSTeamID: viper.GetString("APNS_TEAM_ID"),
|
||||||
APNSSandbox: viper.GetBool("APNS_USE_SANDBOX"),
|
APNSTopic: viper.GetString("APNS_TOPIC"),
|
||||||
APNSProduction: viper.GetBool("APNS_PRODUCTION"),
|
APNSSandbox: viper.GetBool("APNS_USE_SANDBOX"),
|
||||||
FCMServerKey: viper.GetString("FCM_SERVER_KEY"),
|
APNSProduction: viper.GetBool("APNS_PRODUCTION"),
|
||||||
},
|
FCMProjectID: viper.GetString("FCM_PROJECT_ID"),
|
||||||
Worker: WorkerConfig{
|
FCMServiceAccountPath: viper.GetString("FCM_SERVICE_ACCOUNT_PATH"),
|
||||||
TaskReminderHour: viper.GetInt("TASK_REMINDER_HOUR"),
|
FCMServiceAccountJSON: viper.GetString("FCM_SERVICE_ACCOUNT_JSON"),
|
||||||
OverdueReminderHour: viper.GetInt("OVERDUE_REMINDER_HOUR"),
|
FCMServerKey: viper.GetString("FCM_SERVER_KEY"),
|
||||||
DailyNotifHour: viper.GetInt("DAILY_DIGEST_HOUR"),
|
},
|
||||||
},
|
Worker: WorkerConfig{
|
||||||
Security: SecurityConfig{
|
TaskReminderHour: viper.GetInt("TASK_REMINDER_HOUR"),
|
||||||
SecretKey: viper.GetString("SECRET_KEY"),
|
OverdueReminderHour: viper.GetInt("OVERDUE_REMINDER_HOUR"),
|
||||||
TokenCacheTTL: 5 * time.Minute,
|
DailyNotifHour: viper.GetInt("DAILY_DIGEST_HOUR"),
|
||||||
PasswordResetExpiry: 15 * time.Minute,
|
},
|
||||||
ConfirmationExpiry: 24 * time.Hour,
|
Security: SecurityConfig{
|
||||||
MaxPasswordResetRate: 3,
|
SecretKey: viper.GetString("SECRET_KEY"),
|
||||||
},
|
TokenCacheTTL: 5 * time.Minute,
|
||||||
Storage: StorageConfig{
|
PasswordResetExpiry: 15 * time.Minute,
|
||||||
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
|
ConfirmationExpiry: 24 * time.Hour,
|
||||||
BaseURL: viper.GetString("STORAGE_BASE_URL"),
|
MaxPasswordResetRate: 3,
|
||||||
MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"),
|
},
|
||||||
AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"),
|
Storage: StorageConfig{
|
||||||
},
|
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
|
||||||
AppleAuth: AppleAuthConfig{
|
BaseURL: viper.GetString("STORAGE_BASE_URL"),
|
||||||
ClientID: viper.GetString("APPLE_CLIENT_ID"),
|
MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"),
|
||||||
TeamID: viper.GetString("APPLE_TEAM_ID"),
|
AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"),
|
||||||
},
|
},
|
||||||
GoogleAuth: GoogleAuthConfig{
|
AppleAuth: AppleAuthConfig{
|
||||||
ClientID: viper.GetString("GOOGLE_CLIENT_ID"),
|
ClientID: viper.GetString("APPLE_CLIENT_ID"),
|
||||||
AndroidClientID: viper.GetString("GOOGLE_ANDROID_CLIENT_ID"),
|
TeamID: viper.GetString("APPLE_TEAM_ID"),
|
||||||
IOSClientID: viper.GetString("GOOGLE_IOS_CLIENT_ID"),
|
},
|
||||||
},
|
GoogleAuth: GoogleAuthConfig{
|
||||||
AppleIAP: AppleIAPConfig{
|
ClientID: viper.GetString("GOOGLE_CLIENT_ID"),
|
||||||
KeyPath: viper.GetString("APPLE_IAP_KEY_PATH"),
|
AndroidClientID: viper.GetString("GOOGLE_ANDROID_CLIENT_ID"),
|
||||||
KeyID: viper.GetString("APPLE_IAP_KEY_ID"),
|
IOSClientID: viper.GetString("GOOGLE_IOS_CLIENT_ID"),
|
||||||
IssuerID: viper.GetString("APPLE_IAP_ISSUER_ID"),
|
},
|
||||||
BundleID: viper.GetString("APPLE_IAP_BUNDLE_ID"),
|
AppleIAP: AppleIAPConfig{
|
||||||
Sandbox: viper.GetBool("APPLE_IAP_SANDBOX"),
|
KeyPath: viper.GetString("APPLE_IAP_KEY_PATH"),
|
||||||
},
|
KeyID: viper.GetString("APPLE_IAP_KEY_ID"),
|
||||||
GoogleIAP: GoogleIAPConfig{
|
IssuerID: viper.GetString("APPLE_IAP_ISSUER_ID"),
|
||||||
ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"),
|
BundleID: viper.GetString("APPLE_IAP_BUNDLE_ID"),
|
||||||
PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"),
|
Sandbox: viper.GetBool("APPLE_IAP_SANDBOX"),
|
||||||
},
|
},
|
||||||
Stripe: StripeConfig{
|
GoogleIAP: GoogleIAPConfig{
|
||||||
SecretKey: viper.GetString("STRIPE_SECRET_KEY"),
|
ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"),
|
||||||
WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"),
|
PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"),
|
||||||
PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"),
|
},
|
||||||
PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"),
|
Stripe: StripeConfig{
|
||||||
},
|
SecretKey: viper.GetString("STRIPE_SECRET_KEY"),
|
||||||
Features: FeatureFlags{
|
WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"),
|
||||||
PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"),
|
PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"),
|
||||||
EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"),
|
PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"),
|
||||||
WebhooksEnabled: viper.GetBool("FEATURE_WEBHOOKS_ENABLED"),
|
},
|
||||||
OnboardingEmailsEnabled: viper.GetBool("FEATURE_ONBOARDING_EMAILS_ENABLED"),
|
Features: FeatureFlags{
|
||||||
PDFReportsEnabled: viper.GetBool("FEATURE_PDF_REPORTS_ENABLED"),
|
PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"),
|
||||||
WorkerEnabled: viper.GetBool("FEATURE_WORKER_ENABLED"),
|
EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"),
|
||||||
},
|
WebhooksEnabled: viper.GetBool("FEATURE_WEBHOOKS_ENABLED"),
|
||||||
}
|
OnboardingEmailsEnabled: viper.GetBool("FEATURE_ONBOARDING_EMAILS_ENABLED"),
|
||||||
|
PDFReportsEnabled: viper.GetBool("FEATURE_PDF_REPORTS_ENABLED"),
|
||||||
|
WorkerEnabled: viper.GetBool("FEATURE_WORKER_ENABLED"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Validate required fields
|
// Validate required fields
|
||||||
if err := validate(cfg); err != nil {
|
if err := validate(cfg); err != nil {
|
||||||
return nil, err
|
loadErr = err
|
||||||
|
// Reset so a subsequent call can retry after env is fixed
|
||||||
|
cfg = nil
|
||||||
|
cfgOnce = sync.Once{}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if loadErr != nil {
|
||||||
|
return nil, loadErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
@@ -290,6 +328,7 @@ func setDefaults() {
|
|||||||
// Server defaults
|
// Server defaults
|
||||||
viper.SetDefault("PORT", 8000)
|
viper.SetDefault("PORT", 8000)
|
||||||
viper.SetDefault("DEBUG", false)
|
viper.SetDefault("DEBUG", false)
|
||||||
|
viper.SetDefault("DEBUG_FIXED_CODES", false) // Separate flag for fixed confirmation codes
|
||||||
viper.SetDefault("ALLOWED_HOSTS", "localhost,127.0.0.1")
|
viper.SetDefault("ALLOWED_HOSTS", "localhost,127.0.0.1")
|
||||||
viper.SetDefault("TIMEZONE", "UTC")
|
viper.SetDefault("TIMEZONE", "UTC")
|
||||||
viper.SetDefault("STATIC_DIR", "/app/static")
|
viper.SetDefault("STATIC_DIR", "/app/static")
|
||||||
@@ -347,7 +386,13 @@ func setDefaults() {
|
|||||||
viper.SetDefault("FEATURE_WORKER_ENABLED", true)
|
viper.SetDefault("FEATURE_WORKER_ENABLED", true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isWeakSecretKey checks if the provided key is a known weak/default value.
|
||||||
|
func isWeakSecretKey(key string) bool {
|
||||||
|
return knownWeakSecretKeys[strings.ToLower(strings.TrimSpace(key))]
|
||||||
|
}
|
||||||
|
|
||||||
func validate(cfg *Config) error {
|
func validate(cfg *Config) error {
|
||||||
|
// S-08: Validate SECRET_KEY against known weak defaults
|
||||||
if cfg.Security.SecretKey == "" {
|
if cfg.Security.SecretKey == "" {
|
||||||
if cfg.Server.Debug {
|
if cfg.Server.Debug {
|
||||||
// In debug mode, use a default key with a warning for local development
|
// In debug mode, use a default key with a warning for local development
|
||||||
@@ -358,9 +403,12 @@ func validate(cfg *Config) error {
|
|||||||
// In production, refuse to start without a proper secret key
|
// In production, refuse to start without a proper secret key
|
||||||
return fmt.Errorf("FATAL: SECRET_KEY environment variable is required in production (DEBUG=false)")
|
return fmt.Errorf("FATAL: SECRET_KEY environment variable is required in production (DEBUG=false)")
|
||||||
}
|
}
|
||||||
} else if cfg.Security.SecretKey == "change-me-in-production-secret-key-12345" {
|
} else if isWeakSecretKey(cfg.Security.SecretKey) {
|
||||||
// Warn if someone explicitly set the well-known debug key
|
if cfg.Server.Debug {
|
||||||
fmt.Println("WARNING: SECRET_KEY is set to the well-known debug default. Change it for production use.")
|
fmt.Printf("WARNING: SECRET_KEY is set to a well-known weak value (%q). Change it for production use.\n", cfg.Security.SecretKey)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("FATAL: SECRET_KEY is set to a well-known weak value (%q). Use a strong, unique secret in production", cfg.Security.SecretKey)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Database password might come from DATABASE_URL, don't require it separately
|
// Database password might come from DATABASE_URL, don't require it separately
|
||||||
@@ -369,6 +417,21 @@ func validate(cfg *Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MaskURLCredentials parses a URL and replaces any password with "***".
|
||||||
|
// If parsing fails, it returns the string "<unparseable-url>" to avoid leaking credentials.
|
||||||
|
func MaskURLCredentials(rawURL string) string {
|
||||||
|
u, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return "<unparseable-url>"
|
||||||
|
}
|
||||||
|
if u.User != nil {
|
||||||
|
if _, hasPassword := u.User.Password(); hasPassword {
|
||||||
|
u.User = url.UserPassword(u.User.Username(), "***")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return u.Redacted()
|
||||||
|
}
|
||||||
|
|
||||||
// DSN returns the database connection string
|
// DSN returns the database connection string
|
||||||
func (d *DatabaseConfig) DSN() string {
|
func (d *DatabaseConfig) DSN() string {
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
@@ -376,13 +378,23 @@ func migrateGoAdmin() error {
|
|||||||
}
|
}
|
||||||
db.Exec(`CREATE INDEX IF NOT EXISTS idx_goadmin_site_key ON goadmin_site(key)`)
|
db.Exec(`CREATE INDEX IF NOT EXISTS idx_goadmin_site_key ON goadmin_site(key)`)
|
||||||
|
|
||||||
// Seed default admin user only on first run (ON CONFLICT DO NOTHING).
|
// Seed default GoAdmin user only on first run (ON CONFLICT DO NOTHING).
|
||||||
// Password is NOT reset on subsequent migrations to preserve operator changes.
|
// Password is NOT reset on subsequent migrations to preserve operator changes.
|
||||||
db.Exec(`
|
goAdminUsername := viper.GetString("GOADMIN_ADMIN_USERNAME")
|
||||||
INSERT INTO goadmin_users (username, password, name, avatar)
|
goAdminPassword := viper.GetString("GOADMIN_ADMIN_PASSWORD")
|
||||||
VALUES ('admin', '$2a$10$t.GCU24EqIWLSl7F51Hdz.IkkgFK.Qa9/BzEc5Bi2C/I2bXf1nJgm', 'Administrator', '')
|
if goAdminUsername == "" || goAdminPassword == "" {
|
||||||
ON CONFLICT DO NOTHING
|
log.Warn().Msg("GOADMIN_ADMIN_USERNAME and/or GOADMIN_ADMIN_PASSWORD not set; skipping GoAdmin admin user seed")
|
||||||
`)
|
} else {
|
||||||
|
goAdminHash, err := bcrypt.GenerateFromPassword([]byte(goAdminPassword), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to hash GoAdmin admin password: %w", err)
|
||||||
|
}
|
||||||
|
db.Exec(`
|
||||||
|
INSERT INTO goadmin_users (username, password, name, avatar)
|
||||||
|
VALUES (?, ?, 'Administrator', '')
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
`, goAdminUsername, string(goAdminHash))
|
||||||
|
}
|
||||||
|
|
||||||
// Seed default roles
|
// Seed default roles
|
||||||
db.Exec(`INSERT INTO goadmin_roles (name, slug) VALUES ('Administrator', 'administrator') ON CONFLICT DO NOTHING`)
|
db.Exec(`INSERT INTO goadmin_roles (name, slug) VALUES ('Administrator', 'administrator') ON CONFLICT DO NOTHING`)
|
||||||
@@ -393,15 +405,17 @@ func migrateGoAdmin() error {
|
|||||||
db.Exec(`INSERT INTO goadmin_permissions (name, slug, http_method, http_path) VALUES ('Dashboard', 'dashboard', 'GET', '/') ON CONFLICT DO NOTHING`)
|
db.Exec(`INSERT INTO goadmin_permissions (name, slug, http_method, http_path) VALUES ('Dashboard', 'dashboard', 'GET', '/') ON CONFLICT DO NOTHING`)
|
||||||
|
|
||||||
// Assign admin user to administrator role (if not already assigned)
|
// Assign admin user to administrator role (if not already assigned)
|
||||||
db.Exec(`
|
if goAdminUsername != "" {
|
||||||
INSERT INTO goadmin_role_users (role_id, user_id)
|
db.Exec(`
|
||||||
SELECT r.id, u.id FROM goadmin_roles r, goadmin_users u
|
INSERT INTO goadmin_role_users (role_id, user_id)
|
||||||
WHERE r.slug = 'administrator' AND u.username = 'admin'
|
SELECT r.id, u.id FROM goadmin_roles r, goadmin_users u
|
||||||
AND NOT EXISTS (
|
WHERE r.slug = 'administrator' AND u.username = ?
|
||||||
SELECT 1 FROM goadmin_role_users ru
|
AND NOT EXISTS (
|
||||||
WHERE ru.role_id = r.id AND ru.user_id = u.id
|
SELECT 1 FROM goadmin_role_users ru
|
||||||
)
|
WHERE ru.role_id = r.id AND ru.user_id = u.id
|
||||||
`)
|
)
|
||||||
|
`, goAdminUsername)
|
||||||
|
}
|
||||||
|
|
||||||
// Assign all permissions to administrator role (if not already assigned)
|
// Assign all permissions to administrator role (if not already assigned)
|
||||||
db.Exec(`
|
db.Exec(`
|
||||||
@@ -448,15 +462,25 @@ func migrateGoAdmin() error {
|
|||||||
|
|
||||||
// Seed default Next.js admin user only on first run.
|
// Seed default Next.js admin user only on first run.
|
||||||
// Password is NOT reset on subsequent migrations to preserve operator changes.
|
// Password is NOT reset on subsequent migrations to preserve operator changes.
|
||||||
var adminCount int64
|
adminEmail := viper.GetString("ADMIN_EMAIL")
|
||||||
db.Raw(`SELECT COUNT(*) FROM admin_users WHERE email = 'admin@honeydue.com'`).Scan(&adminCount)
|
adminPassword := viper.GetString("ADMIN_PASSWORD")
|
||||||
if adminCount == 0 {
|
if adminEmail == "" || adminPassword == "" {
|
||||||
log.Info().Msg("Seeding default admin user for Next.js admin panel...")
|
log.Warn().Msg("ADMIN_EMAIL and/or ADMIN_PASSWORD not set; skipping Next.js admin user seed")
|
||||||
db.Exec(`
|
} else {
|
||||||
INSERT INTO admin_users (email, password, first_name, last_name, role, is_active, created_at, updated_at)
|
var adminCount int64
|
||||||
VALUES ('admin@honeydue.com', '$2a$10$t5hGjdXQLxr9Z0193qx.Tef6hd1vYI3JvrfX/piKx2qS9UvQ41I9O', 'Admin', 'User', 'super_admin', true, NOW(), NOW())
|
db.Raw(`SELECT COUNT(*) FROM admin_users WHERE email = ?`, adminEmail).Scan(&adminCount)
|
||||||
`)
|
if adminCount == 0 {
|
||||||
log.Info().Msg("Default admin user created: admin@honeydue.com")
|
log.Info().Str("email", adminEmail).Msg("Seeding default admin user for Next.js admin panel...")
|
||||||
|
adminHash, err := bcrypt.GenerateFromPassword([]byte(adminPassword), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to hash admin password: %w", err)
|
||||||
|
}
|
||||||
|
db.Exec(`
|
||||||
|
INSERT INTO admin_users (email, password, first_name, last_name, role, is_active, created_at, updated_at)
|
||||||
|
VALUES (?, ?, 'Admin', 'User', 'super_admin', true, NOW(), NOW())
|
||||||
|
`, adminEmail, string(adminHash))
|
||||||
|
log.Info().Str("email", adminEmail).Msg("Default admin user created")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ type CreateTaskRequest struct {
|
|||||||
// UpdateTaskRequest represents the request to update a task
|
// UpdateTaskRequest represents the request to update a task
|
||||||
type UpdateTaskRequest struct {
|
type UpdateTaskRequest struct {
|
||||||
Title *string `json:"title" validate:"omitempty,min=1,max=200"`
|
Title *string `json:"title" validate:"omitempty,min=1,max=200"`
|
||||||
Description *string `json:"description"`
|
Description *string `json:"description" validate:"omitempty,max=10000"`
|
||||||
CategoryID *uint `json:"category_id"`
|
CategoryID *uint `json:"category_id"`
|
||||||
PriorityID *uint `json:"priority_id"`
|
PriorityID *uint `json:"priority_id"`
|
||||||
FrequencyID *uint `json:"frequency_id"`
|
FrequencyID *uint `json:"frequency_id"`
|
||||||
|
|||||||
@@ -21,8 +21,9 @@ type DocumentUserResponse struct {
|
|||||||
type DocumentImageResponse struct {
|
type DocumentImageResponse struct {
|
||||||
ID uint `json:"id"`
|
ID uint `json:"id"`
|
||||||
ImageURL string `json:"image_url"`
|
ImageURL string `json:"image_url"`
|
||||||
MediaURL string `json:"media_url"` // Authenticated endpoint: /api/media/document-image/{id}
|
MediaURL string `json:"media_url"` // Authenticated endpoint: /api/media/document-image/{id}
|
||||||
Caption string `json:"caption"`
|
Caption string `json:"caption"`
|
||||||
|
Error string `json:"error,omitempty"` // Non-empty when the image could not be resolved
|
||||||
}
|
}
|
||||||
|
|
||||||
// DocumentResponse represents a document in the API response
|
// DocumentResponse represents a document in the API response
|
||||||
@@ -35,7 +36,6 @@ type DocumentResponse struct {
|
|||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
DocumentType models.DocumentType `json:"document_type"`
|
DocumentType models.DocumentType `json:"document_type"`
|
||||||
FileURL string `json:"file_url"`
|
|
||||||
MediaURL string `json:"media_url"` // Authenticated endpoint: /api/media/document/{id}
|
MediaURL string `json:"media_url"` // Authenticated endpoint: /api/media/document/{id}
|
||||||
FileName string `json:"file_name"`
|
FileName string `json:"file_name"`
|
||||||
FileSize *int64 `json:"file_size"`
|
FileSize *int64 `json:"file_size"`
|
||||||
@@ -80,7 +80,6 @@ func NewDocumentResponse(d *models.Document) DocumentResponse {
|
|||||||
Title: d.Title,
|
Title: d.Title,
|
||||||
Description: d.Description,
|
Description: d.Description,
|
||||||
DocumentType: d.DocumentType,
|
DocumentType: d.DocumentType,
|
||||||
FileURL: d.FileURL,
|
|
||||||
MediaURL: fmt.Sprintf("/api/media/document/%d", d.ID), // Authenticated endpoint
|
MediaURL: fmt.Sprintf("/api/media/document/%d", d.ID), // Authenticated endpoint
|
||||||
FileName: d.FileName,
|
FileName: d.FileName,
|
||||||
FileSize: d.FileSize,
|
FileSize: d.FileSize,
|
||||||
@@ -104,12 +103,16 @@ func NewDocumentResponse(d *models.Document) DocumentResponse {
|
|||||||
|
|
||||||
// Convert images with authenticated media URLs
|
// Convert images with authenticated media URLs
|
||||||
for _, img := range d.Images {
|
for _, img := range d.Images {
|
||||||
resp.Images = append(resp.Images, DocumentImageResponse{
|
imgResp := DocumentImageResponse{
|
||||||
ID: img.ID,
|
ID: img.ID,
|
||||||
ImageURL: img.ImageURL,
|
ImageURL: img.ImageURL,
|
||||||
MediaURL: fmt.Sprintf("/api/media/document-image/%d", img.ID), // Authenticated endpoint
|
MediaURL: fmt.Sprintf("/api/media/document-image/%d", img.ID), // Authenticated endpoint
|
||||||
Caption: img.Caption,
|
Caption: img.Caption,
|
||||||
})
|
}
|
||||||
|
if img.ImageURL == "" {
|
||||||
|
imgResp.Error = "image source URL is missing"
|
||||||
|
}
|
||||||
|
resp.Images = append(resp.Images, imgResp)
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|||||||
@@ -281,13 +281,15 @@ func NewTaskListResponse(tasks []models.Task) []TaskResponse {
|
|||||||
return results
|
return results
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKanbanBoardResponse creates a KanbanBoardResponse from a KanbanBoard model
|
// NewKanbanBoardResponse creates a KanbanBoardResponse from a KanbanBoard model.
|
||||||
func NewKanbanBoardResponse(board *models.KanbanBoard, residenceID uint) KanbanBoardResponse {
|
// The `now` parameter should be the start of day in the user's timezone so that
|
||||||
|
// individual task kanban columns are categorized consistently with the board query.
|
||||||
|
func NewKanbanBoardResponse(board *models.KanbanBoard, residenceID uint, now time.Time) KanbanBoardResponse {
|
||||||
columns := make([]KanbanColumnResponse, len(board.Columns))
|
columns := make([]KanbanColumnResponse, len(board.Columns))
|
||||||
for i, col := range board.Columns {
|
for i, col := range board.Columns {
|
||||||
tasks := make([]TaskResponse, len(col.Tasks))
|
tasks := make([]TaskResponse, len(col.Tasks))
|
||||||
for j, t := range col.Tasks {
|
for j, t := range col.Tasks {
|
||||||
tasks[j] = NewTaskResponse(&t)
|
tasks[j] = NewTaskResponseWithTime(&t, board.DaysThreshold, now)
|
||||||
}
|
}
|
||||||
columns[i] = KanbanColumnResponse{
|
columns[i] = KanbanColumnResponse{
|
||||||
Name: col.Name,
|
Name: col.Name,
|
||||||
@@ -306,13 +308,15 @@ func NewKanbanBoardResponse(board *models.KanbanBoard, residenceID uint) KanbanB
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKanbanBoardResponseForAll creates a KanbanBoardResponse for all residences (no specific residence ID)
|
// NewKanbanBoardResponseForAll creates a KanbanBoardResponse for all residences (no specific residence ID).
|
||||||
func NewKanbanBoardResponseForAll(board *models.KanbanBoard) KanbanBoardResponse {
|
// The `now` parameter should be the start of day in the user's timezone so that
|
||||||
|
// individual task kanban columns are categorized consistently with the board query.
|
||||||
|
func NewKanbanBoardResponseForAll(board *models.KanbanBoard, now time.Time) KanbanBoardResponse {
|
||||||
columns := make([]KanbanColumnResponse, len(board.Columns))
|
columns := make([]KanbanColumnResponse, len(board.Columns))
|
||||||
for i, col := range board.Columns {
|
for i, col := range board.Columns {
|
||||||
tasks := make([]TaskResponse, len(col.Tasks))
|
tasks := make([]TaskResponse, len(col.Tasks))
|
||||||
for j, t := range col.Tasks {
|
for j, t := range col.Tasks {
|
||||||
tasks[j] = NewTaskResponse(&t)
|
tasks[j] = NewTaskResponseWithTime(&t, board.DaysThreshold, now)
|
||||||
}
|
}
|
||||||
columns[i] = KanbanColumnResponse{
|
columns[i] = KanbanColumnResponse{
|
||||||
Name: col.Name,
|
Name: col.Name,
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||||
"github.com/treytartt/honeydue-api/internal/dto/requests"
|
"github.com/treytartt/honeydue-api/internal/dto/requests"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/i18n"
|
||||||
"github.com/treytartt/honeydue-api/internal/middleware"
|
"github.com/treytartt/honeydue-api/internal/middleware"
|
||||||
"github.com/treytartt/honeydue-api/internal/services"
|
"github.com/treytartt/honeydue-api/internal/services"
|
||||||
)
|
)
|
||||||
@@ -115,7 +117,7 @@ func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Contractor deleted successfully"})
|
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.contractor_deleted")})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToggleFavorite handles POST /api/contractors/:id/toggle-favorite/
|
// ToggleFavorite handles POST /api/contractors/:id/toggle-favorite/
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||||
"github.com/treytartt/honeydue-api/internal/dto/requests"
|
"github.com/treytartt/honeydue-api/internal/dto/requests"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/i18n"
|
||||||
"github.com/treytartt/honeydue-api/internal/middleware"
|
"github.com/treytartt/honeydue-api/internal/middleware"
|
||||||
"github.com/treytartt/honeydue-api/internal/models"
|
"github.com/treytartt/honeydue-api/internal/models"
|
||||||
"github.com/treytartt/honeydue-api/internal/repositories"
|
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||||
@@ -60,6 +62,9 @@ func (h *DocumentHandler) ListDocuments(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
if es := c.QueryParam("expiring_soon"); es != "" {
|
if es := c.QueryParam("expiring_soon"); es != "" {
|
||||||
if parsed, err := strconv.Atoi(es); err == nil {
|
if parsed, err := strconv.Atoi(es); err == nil {
|
||||||
|
if parsed < 1 || parsed > 3650 {
|
||||||
|
return apperrors.BadRequest("error.days_out_of_range")
|
||||||
|
}
|
||||||
filter.ExpiringSoon = &parsed
|
filter.ExpiringSoon = &parsed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -192,7 +197,10 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if uploadedFile != nil && h.storageService != nil {
|
if uploadedFile != nil {
|
||||||
|
if h.storageService == nil {
|
||||||
|
return apperrors.Internal(nil)
|
||||||
|
}
|
||||||
result, err := h.storageService.Upload(uploadedFile, "documents")
|
result, err := h.storageService.Upload(uploadedFile, "documents")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return apperrors.BadRequest("error.failed_to_upload_file")
|
return apperrors.BadRequest("error.failed_to_upload_file")
|
||||||
@@ -262,7 +270,7 @@ func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document deleted successfully"})
|
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.document_deleted")})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ActivateDocument handles POST /api/documents/:id/activate/
|
// ActivateDocument handles POST /api/documents/:id/activate/
|
||||||
@@ -280,7 +288,7 @@ func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document activated successfully", "document": response})
|
return c.JSON(http.StatusOK, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeactivateDocument handles POST /api/documents/:id/deactivate/
|
// DeactivateDocument handles POST /api/documents/:id/deactivate/
|
||||||
@@ -298,7 +306,7 @@ func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Document deactivated successfully", "document": response})
|
return c.JSON(http.StatusOK, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UploadDocumentImage handles POST /api/documents/:id/images/
|
// UploadDocumentImage handles POST /api/documents/:id/images/
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/i18n"
|
||||||
"github.com/treytartt/honeydue-api/internal/middleware"
|
"github.com/treytartt/honeydue-api/internal/middleware"
|
||||||
"github.com/treytartt/honeydue-api/internal/services"
|
"github.com/treytartt/honeydue-api/internal/services"
|
||||||
)
|
)
|
||||||
@@ -87,7 +89,7 @@ func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.notification_marked_read"})
|
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.notification_marked_read")})
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkAllAsRead handles POST /api/notifications/mark-all-read/
|
// MarkAllAsRead handles POST /api/notifications/mark-all-read/
|
||||||
@@ -102,7 +104,7 @@ func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.all_notifications_marked_read"})
|
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.all_notifications_marked_read")})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPreferences handles GET /api/notifications/preferences/
|
// GetPreferences handles GET /api/notifications/preferences/
|
||||||
@@ -200,7 +202,10 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.registration_id_required")
|
return apperrors.BadRequest("error.registration_id_required")
|
||||||
}
|
}
|
||||||
if req.Platform == "" {
|
if req.Platform == "" {
|
||||||
req.Platform = "ios" // Default to iOS
|
return apperrors.BadRequest("error.platform_required")
|
||||||
|
}
|
||||||
|
if req.Platform != "ios" && req.Platform != "android" {
|
||||||
|
return apperrors.BadRequest("error.invalid_platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
|
err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
|
||||||
@@ -208,7 +213,7 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.device_unregistered"})
|
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.device_removed")})
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDevice handles DELETE /api/notifications/devices/:id/
|
// DeleteDevice handles DELETE /api/notifications/devices/:id/
|
||||||
@@ -225,7 +230,10 @@ func (h *NotificationHandler) DeleteDevice(c echo.Context) error {
|
|||||||
|
|
||||||
platform := c.QueryParam("platform")
|
platform := c.QueryParam("platform")
|
||||||
if platform == "" {
|
if platform == "" {
|
||||||
platform = "ios" // Default to iOS
|
return apperrors.BadRequest("error.platform_required")
|
||||||
|
}
|
||||||
|
if platform != "ios" && platform != "android" {
|
||||||
|
return apperrors.BadRequest("error.invalid_platform")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.notificationService.DeleteDevice(uint(deviceID), platform, user.ID)
|
err = h.notificationService.DeleteDevice(uint(deviceID), platform, user.ID)
|
||||||
@@ -233,5 +241,5 @@ func (h *NotificationHandler) DeleteDevice(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "message.device_removed"})
|
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.device_removed")})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/i18n"
|
||||||
"github.com/treytartt/honeydue-api/internal/middleware"
|
"github.com/treytartt/honeydue-api/internal/middleware"
|
||||||
"github.com/treytartt/honeydue-api/internal/services"
|
"github.com/treytartt/honeydue-api/internal/services"
|
||||||
)
|
)
|
||||||
@@ -139,7 +140,7 @@ func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||||
"message": "message.subscription_upgraded",
|
"message": i18n.LocalizedMessage(c, "message.subscription_upgraded"),
|
||||||
"subscription": subscription,
|
"subscription": subscription,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -157,7 +158,7 @@ func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||||
"message": "message.subscription_cancelled",
|
"message": i18n.LocalizedMessage(c, "message.subscription_cancelled"),
|
||||||
"subscription": subscription,
|
"subscription": subscription,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -182,8 +183,15 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
|
|||||||
|
|
||||||
switch req.Platform {
|
switch req.Platform {
|
||||||
case "ios":
|
case "ios":
|
||||||
|
// B-14: Validate that at least one of receipt_data or transaction_id is provided
|
||||||
|
if req.ReceiptData == "" && req.TransactionID == "" {
|
||||||
|
return apperrors.BadRequest("error.receipt_data_required")
|
||||||
|
}
|
||||||
subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID)
|
subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData, req.TransactionID)
|
||||||
case "android":
|
case "android":
|
||||||
|
if req.PurchaseToken == "" {
|
||||||
|
return apperrors.BadRequest("error.purchase_token_required")
|
||||||
|
}
|
||||||
subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID)
|
subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken, req.ProductID)
|
||||||
default:
|
default:
|
||||||
return apperrors.BadRequest("error.invalid_platform")
|
return apperrors.BadRequest("error.invalid_platform")
|
||||||
@@ -194,7 +202,7 @@ func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||||
"message": "message.subscription_restored",
|
"message": i18n.LocalizedMessage(c, "message.subscription_restored"),
|
||||||
"subscription": subscription,
|
"subscription": subscription,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,15 +2,18 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
@@ -23,6 +26,11 @@ import (
|
|||||||
"github.com/treytartt/honeydue-api/internal/services"
|
"github.com/treytartt/honeydue-api/internal/services"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// maxWebhookBodySize is the maximum allowed request body size for webhook
|
||||||
|
// payloads (1 MB). This prevents a malicious or misbehaving sender from
|
||||||
|
// forcing the server to allocate unbounded memory.
|
||||||
|
const maxWebhookBodySize = 1 << 20 // 1 MB
|
||||||
|
|
||||||
// SubscriptionWebhookHandler handles subscription webhook callbacks
|
// SubscriptionWebhookHandler handles subscription webhook callbacks
|
||||||
type SubscriptionWebhookHandler struct {
|
type SubscriptionWebhookHandler struct {
|
||||||
subscriptionRepo *repositories.SubscriptionRepository
|
subscriptionRepo *repositories.SubscriptionRepository
|
||||||
@@ -112,7 +120,7 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
|
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(c.Request().Body)
|
body, err := io.ReadAll(io.LimitReader(c.Request().Body, maxWebhookBodySize))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Apple Webhook: Failed to read body")
|
log.Error().Err(err).Msg("Apple Webhook: Failed to read body")
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
||||||
@@ -211,13 +219,22 @@ func (h *SubscriptionWebhookHandler) decodeAppleSignedPayload(signedPayload stri
|
|||||||
return ¬ification, nil
|
return ¬ification, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodeAppleTransaction decodes a signed transaction info JWS
|
// decodeAppleTransaction decodes and verifies a signed transaction info JWS.
|
||||||
|
// The inner JWS signature is verified using the same Apple certificate chain
|
||||||
|
// validation as the outer notification payload.
|
||||||
func (h *SubscriptionWebhookHandler) decodeAppleTransaction(signedTransaction string) (*AppleTransactionInfo, error) {
|
func (h *SubscriptionWebhookHandler) decodeAppleTransaction(signedTransaction string) (*AppleTransactionInfo, error) {
|
||||||
parts := strings.Split(signedTransaction, ".")
|
parts := strings.Split(signedTransaction, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return nil, fmt.Errorf("invalid JWS format")
|
return nil, fmt.Errorf("invalid JWS format")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// S-16: Verify the inner JWS signature for signedTransactionInfo.
|
||||||
|
// Apple signs each inner JWS independently with the same x5c certificate
|
||||||
|
// chain as the outer notification, so the same verification applies.
|
||||||
|
if err := h.VerifyAppleSignature(signedTransaction); err != nil {
|
||||||
|
return nil, fmt.Errorf("Apple transaction JWS signature verification failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to decode payload: %w", err)
|
return nil, fmt.Errorf("failed to decode payload: %w", err)
|
||||||
@@ -231,13 +248,20 @@ func (h *SubscriptionWebhookHandler) decodeAppleTransaction(signedTransaction st
|
|||||||
return &info, nil
|
return &info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodeAppleRenewalInfo decodes signed renewal info JWS
|
// decodeAppleRenewalInfo decodes and verifies a signed renewal info JWS.
|
||||||
|
// The inner JWS signature is verified using the same Apple certificate chain
|
||||||
|
// validation as the outer notification payload.
|
||||||
func (h *SubscriptionWebhookHandler) decodeAppleRenewalInfo(signedRenewal string) (*AppleRenewalInfo, error) {
|
func (h *SubscriptionWebhookHandler) decodeAppleRenewalInfo(signedRenewal string) (*AppleRenewalInfo, error) {
|
||||||
parts := strings.Split(signedRenewal, ".")
|
parts := strings.Split(signedRenewal, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return nil, fmt.Errorf("invalid JWS format")
|
return nil, fmt.Errorf("invalid JWS format")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// S-16: Verify the inner JWS signature for signedRenewalInfo.
|
||||||
|
if err := h.VerifyAppleSignature(signedRenewal); err != nil {
|
||||||
|
return nil, fmt.Errorf("Apple renewal JWS signature verification failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to decode payload: %w", err)
|
return nil, fmt.Errorf("failed to decode payload: %w", err)
|
||||||
@@ -484,7 +508,12 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
|
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(c.Request().Body)
|
// C-01: Verify the Google Pub/Sub push authentication token before processing
|
||||||
|
if !h.VerifyGooglePubSubToken(c) {
|
||||||
|
return c.JSON(http.StatusUnauthorized, map[string]interface{}{"error": "unauthorized"})
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(io.LimitReader(c.Request().Body, maxWebhookBodySize))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Google Webhook: Failed to read body")
|
log.Error().Err(err).Msg("Google Webhook: Failed to read body")
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
||||||
@@ -781,7 +810,7 @@ func (h *SubscriptionWebhookHandler) HandleStripeWebhook(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "not_configured"})
|
return c.JSON(http.StatusOK, map[string]interface{}{"status": "not_configured"})
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(c.Request().Body)
|
body, err := io.ReadAll(io.LimitReader(c.Request().Body, maxWebhookBodySize))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Stripe Webhook: Failed to read body")
|
log.Error().Err(err).Msg("Stripe Webhook: Failed to read body")
|
||||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
||||||
@@ -884,10 +913,109 @@ func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// googleOIDCCertsURL is the endpoint that serves Google's public OAuth2
|
||||||
|
// certificates used to verify JWTs issued by accounts.google.com (including
|
||||||
|
// Pub/Sub push tokens).
|
||||||
|
const googleOIDCCertsURL = "https://www.googleapis.com/oauth2/v3/certs"
|
||||||
|
|
||||||
|
// googleJWKSCache caches the fetched Google JWKS keys so we don't hit the
|
||||||
|
// network on every webhook request. Keys are refreshed when the cache expires.
|
||||||
|
var (
|
||||||
|
googleJWKSCache map[string]*rsa.PublicKey
|
||||||
|
googleJWKSCacheMu sync.RWMutex
|
||||||
|
googleJWKSCacheTime time.Time
|
||||||
|
googleJWKSCacheTTL = 1 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
// googleJWKSResponse represents the JSON Web Key Set response from Google.
|
||||||
|
type googleJWKSResponse struct {
|
||||||
|
Keys []googleJWK `json:"keys"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// googleJWK represents a single JSON Web Key from Google's OIDC endpoint.
|
||||||
|
type googleJWK struct {
|
||||||
|
Kid string `json:"kid"` // Key ID
|
||||||
|
Kty string `json:"kty"` // Key type (RSA)
|
||||||
|
Alg string `json:"alg"` // Algorithm (RS256)
|
||||||
|
N string `json:"n"` // RSA modulus (base64url)
|
||||||
|
E string `json:"e"` // RSA exponent (base64url)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchGoogleOIDCKeys fetches Google's public OIDC keys from their well-known
|
||||||
|
// endpoint, returning a map of key-id to RSA public key. Results are cached
|
||||||
|
// for googleJWKSCacheTTL to avoid excessive network calls.
|
||||||
|
func fetchGoogleOIDCKeys() (map[string]*rsa.PublicKey, error) {
|
||||||
|
googleJWKSCacheMu.RLock()
|
||||||
|
if googleJWKSCache != nil && time.Since(googleJWKSCacheTime) < googleJWKSCacheTTL {
|
||||||
|
cached := googleJWKSCache
|
||||||
|
googleJWKSCacheMu.RUnlock()
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
googleJWKSCacheMu.RUnlock()
|
||||||
|
|
||||||
|
googleJWKSCacheMu.Lock()
|
||||||
|
defer googleJWKSCacheMu.Unlock()
|
||||||
|
|
||||||
|
// Double-check after acquiring write lock
|
||||||
|
if googleJWKSCache != nil && time.Since(googleJWKSCacheTime) < googleJWKSCacheTTL {
|
||||||
|
return googleJWKSCache, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
resp, err := client.Get(googleOIDCCertsURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch Google OIDC keys: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("Google OIDC keys endpoint returned status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var jwks googleJWKSResponse
|
||||||
|
if err := json.NewDecoder(io.LimitReader(resp.Body, maxWebhookBodySize)).Decode(&jwks); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode Google OIDC keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make(map[string]*rsa.PublicKey, len(jwks.Keys))
|
||||||
|
for _, k := range jwks.Keys {
|
||||||
|
if k.Kty != "RSA" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nBytes, err := base64.RawURLEncoding.DecodeString(k.N)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Str("kid", k.Kid).Err(err).Msg("Google OIDC: failed to decode modulus")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
eBytes, err := base64.RawURLEncoding.DecodeString(k.E)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Str("kid", k.Kid).Err(err).Msg("Google OIDC: failed to decode exponent")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
n := new(big.Int).SetBytes(nBytes)
|
||||||
|
e := 0
|
||||||
|
for _, b := range eBytes {
|
||||||
|
e = e<<8 + int(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys[k.Kid] = &rsa.PublicKey{N: n, E: e}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil, fmt.Errorf("no usable RSA keys found in Google OIDC response")
|
||||||
|
}
|
||||||
|
|
||||||
|
googleJWKSCache = keys
|
||||||
|
googleJWKSCacheTime = time.Now()
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
// VerifyGooglePubSubToken verifies the Pub/Sub push authentication token.
|
// VerifyGooglePubSubToken verifies the Pub/Sub push authentication token.
|
||||||
// Returns false (deny) when the Authorization header is missing or the token
|
// The token is a JWT signed by Google (accounts.google.com). This function
|
||||||
// cannot be validated. This prevents unauthenticated callers from injecting
|
// verifies the signature against Google's published OIDC public keys, checks
|
||||||
// webhook events.
|
// the issuer claim, and validates the email claim is a Google service account.
|
||||||
|
// Returns false (deny) when verification fails for any reason.
|
||||||
func (h *SubscriptionWebhookHandler) VerifyGooglePubSubToken(c echo.Context) bool {
|
func (h *SubscriptionWebhookHandler) VerifyGooglePubSubToken(c echo.Context) bool {
|
||||||
authHeader := c.Request().Header.Get("Authorization")
|
authHeader := c.Request().Header.Get("Authorization")
|
||||||
if authHeader == "" {
|
if authHeader == "" {
|
||||||
@@ -907,12 +1035,52 @@ func (h *SubscriptionWebhookHandler) VerifyGooglePubSubToken(c echo.Context) boo
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the token as a JWT. Google Pub/Sub push tokens are signed JWTs
|
// Fetch Google's OIDC public keys for signature verification
|
||||||
// issued by accounts.google.com. We verify the claims to ensure the
|
googleKeys, err := fetchGoogleOIDCKeys()
|
||||||
// token was intended for our service.
|
|
||||||
token, _, err := jwt.NewParser().ParseUnverified(bearerToken, jwt.MapClaims{})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Msg("Google Webhook: failed to parse Bearer token")
|
log.Error().Err(err).Msg("Google Webhook: failed to fetch OIDC keys, denying request")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and verify the JWT signature against Google's published public keys
|
||||||
|
token, err := jwt.Parse(bearerToken, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
// Ensure the signing method is RSA (Google uses RS256)
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
|
||||||
|
kid, _ := token.Header["kid"].(string)
|
||||||
|
if kid == "" {
|
||||||
|
return nil, fmt.Errorf("missing kid header in token")
|
||||||
|
}
|
||||||
|
|
||||||
|
key, ok := googleKeys[kid]
|
||||||
|
if !ok {
|
||||||
|
// Key may have rotated; try refreshing once
|
||||||
|
googleJWKSCacheMu.Lock()
|
||||||
|
googleJWKSCacheTime = time.Time{} // Force refresh
|
||||||
|
googleJWKSCacheMu.Unlock()
|
||||||
|
|
||||||
|
refreshedKeys, err := fetchGoogleOIDCKeys()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to refresh Google OIDC keys: %w", err)
|
||||||
|
}
|
||||||
|
key, ok = refreshedKeys[kid]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown key ID: %s", kid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}, jwt.WithValidMethods([]string{"RS256"}))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Msg("Google Webhook: JWT signature verification failed")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !token.Valid {
|
||||||
|
log.Warn().Msg("Google Webhook: token is invalid after verification")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,19 +38,29 @@ func (h *TaskHandler) ListTasks(c echo.Context) error {
|
|||||||
userNow := middleware.GetUserNow(c)
|
userNow := middleware.GetUserNow(c)
|
||||||
|
|
||||||
// Auto-capture timezone from header for background job calculations (e.g., daily digest)
|
// Auto-capture timezone from header for background job calculations (e.g., daily digest)
|
||||||
// Runs synchronously — this is a lightweight DB upsert that should complete quickly
|
// Only write to DB if the timezone has actually changed from the cached value
|
||||||
if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" {
|
if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" {
|
||||||
h.taskService.UpdateUserTimezone(user.ID, tzHeader)
|
cachedTZ, _ := c.Get("user_timezone").(string)
|
||||||
|
if cachedTZ != tzHeader {
|
||||||
|
h.taskService.UpdateUserTimezone(user.ID, tzHeader)
|
||||||
|
c.Set("user_timezone", tzHeader)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
daysThreshold := 30
|
daysThreshold := 30
|
||||||
// Support "days" param first, fall back to "days_threshold" for backward compatibility
|
// Support "days" param first, fall back to "days_threshold" for backward compatibility
|
||||||
if d := c.QueryParam("days"); d != "" {
|
if d := c.QueryParam("days"); d != "" {
|
||||||
if parsed, err := strconv.Atoi(d); err == nil {
|
if parsed, err := strconv.Atoi(d); err == nil {
|
||||||
|
if parsed < 1 || parsed > 3650 {
|
||||||
|
return apperrors.BadRequest("error.days_out_of_range")
|
||||||
|
}
|
||||||
daysThreshold = parsed
|
daysThreshold = parsed
|
||||||
}
|
}
|
||||||
} else if d := c.QueryParam("days_threshold"); d != "" {
|
} else if d := c.QueryParam("days_threshold"); d != "" {
|
||||||
if parsed, err := strconv.Atoi(d); err == nil {
|
if parsed, err := strconv.Atoi(d); err == nil {
|
||||||
|
if parsed < 1 || parsed > 3650 {
|
||||||
|
return apperrors.BadRequest("error.days_out_of_range")
|
||||||
|
}
|
||||||
daysThreshold = parsed
|
daysThreshold = parsed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -97,10 +107,16 @@ func (h *TaskHandler) GetTasksByResidence(c echo.Context) error {
|
|||||||
// Support "days" param first, fall back to "days_threshold" for backward compatibility
|
// Support "days" param first, fall back to "days_threshold" for backward compatibility
|
||||||
if d := c.QueryParam("days"); d != "" {
|
if d := c.QueryParam("days"); d != "" {
|
||||||
if parsed, err := strconv.Atoi(d); err == nil {
|
if parsed, err := strconv.Atoi(d); err == nil {
|
||||||
|
if parsed < 1 || parsed > 3650 {
|
||||||
|
return apperrors.BadRequest("error.days_out_of_range")
|
||||||
|
}
|
||||||
daysThreshold = parsed
|
daysThreshold = parsed
|
||||||
}
|
}
|
||||||
} else if d := c.QueryParam("days_threshold"); d != "" {
|
} else if d := c.QueryParam("days_threshold"); d != "" {
|
||||||
if parsed, err := strconv.Atoi(d); err == nil {
|
if parsed, err := strconv.Atoi(d); err == nil {
|
||||||
|
if parsed < 1 || parsed > 3650 {
|
||||||
|
return apperrors.BadRequest("error.days_out_of_range")
|
||||||
|
}
|
||||||
daysThreshold = parsed
|
daysThreshold = parsed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,19 +7,31 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/i18n"
|
||||||
"github.com/treytartt/honeydue-api/internal/middleware"
|
"github.com/treytartt/honeydue-api/internal/middleware"
|
||||||
"github.com/treytartt/honeydue-api/internal/models"
|
|
||||||
"github.com/treytartt/honeydue-api/internal/services"
|
"github.com/treytartt/honeydue-api/internal/services"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// FileOwnershipChecker verifies whether a user owns a file referenced by URL.
|
||||||
|
// Implementations should check associated records (e.g., task completion images,
|
||||||
|
// document files, document images) to determine ownership.
|
||||||
|
type FileOwnershipChecker interface {
|
||||||
|
IsFileOwnedByUser(fileURL string, userID uint) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
// UploadHandler handles file upload endpoints
|
// UploadHandler handles file upload endpoints
|
||||||
type UploadHandler struct {
|
type UploadHandler struct {
|
||||||
storageService *services.StorageService
|
storageService *services.StorageService
|
||||||
|
fileOwnershipChecker FileOwnershipChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUploadHandler creates a new upload handler
|
// NewUploadHandler creates a new upload handler
|
||||||
func NewUploadHandler(storageService *services.StorageService) *UploadHandler {
|
func NewUploadHandler(storageService *services.StorageService, fileOwnershipChecker FileOwnershipChecker) *UploadHandler {
|
||||||
return &UploadHandler{storageService: storageService}
|
return &UploadHandler{
|
||||||
|
storageService: storageService,
|
||||||
|
fileOwnershipChecker: fileOwnershipChecker,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UploadImage handles POST /api/uploads/image
|
// UploadImage handles POST /api/uploads/image
|
||||||
@@ -83,13 +95,14 @@ type DeleteFileRequest struct {
|
|||||||
|
|
||||||
// DeleteFile handles DELETE /api/uploads
|
// DeleteFile handles DELETE /api/uploads
|
||||||
// Expects JSON body with "url" field.
|
// Expects JSON body with "url" field.
|
||||||
//
|
// Verifies that the requesting user owns the file by checking associated records
|
||||||
// TODO(SEC-18): Add ownership verification. Currently any authenticated user can delete
|
// (task completion images, document files/images) before allowing deletion.
|
||||||
// any file if they know the URL. The upload system does not track which user uploaded
|
|
||||||
// which file, so a proper fix requires adding an uploads table or file ownership metadata.
|
|
||||||
// For now, deletions are logged with user ID for audit trail, and StorageService.Delete
|
|
||||||
// enforces path containment to prevent deleting files outside the upload directory.
|
|
||||||
func (h *UploadHandler) DeleteFile(c echo.Context) error {
|
func (h *UploadHandler) DeleteFile(c echo.Context) error {
|
||||||
|
user, err := middleware.MustGetAuthUser(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
var req DeleteFileRequest
|
var req DeleteFileRequest
|
||||||
|
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
@@ -100,17 +113,28 @@ func (h *UploadHandler) DeleteFile(c echo.Context) error {
|
|||||||
return apperrors.BadRequest("error.url_required")
|
return apperrors.BadRequest("error.url_required")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log the deletion with user ID for audit trail
|
// Verify ownership: the user must own a record that references this file URL
|
||||||
if user, ok := c.Get(middleware.AuthUserKey).(*models.User); ok {
|
if h.fileOwnershipChecker != nil {
|
||||||
log.Info().
|
owned, err := h.fileOwnershipChecker.IsFileOwnedByUser(req.URL, user.ID)
|
||||||
Uint("user_id", user.ID).
|
if err != nil {
|
||||||
Str("file_url", req.URL).
|
log.Error().Err(err).Uint("user_id", user.ID).Str("file_url", req.URL).Msg("Failed to check file ownership")
|
||||||
Msg("File deletion requested")
|
return apperrors.Internal(err)
|
||||||
|
}
|
||||||
|
if !owned {
|
||||||
|
log.Warn().Uint("user_id", user.ID).Str("file_url", req.URL).Msg("Unauthorized file deletion attempt")
|
||||||
|
return apperrors.Forbidden("error.file_access_denied")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Log the deletion with user ID for audit trail
|
||||||
|
log.Info().
|
||||||
|
Uint("user_id", user.ID).
|
||||||
|
Str("file_url", req.URL).
|
||||||
|
Msg("File deletion requested")
|
||||||
|
|
||||||
if err := h.storageService.Delete(req.URL); err != nil {
|
if err := h.storageService.Delete(req.URL); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "File deleted successfully"})
|
return c.JSON(http.StatusOK, responses.MessageResponse{Message: i18n.LocalizedMessage(c, "message.file_deleted")})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/i18n"
|
"github.com/treytartt/honeydue-api/internal/i18n"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/models"
|
||||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,12 +19,16 @@ func init() {
|
|||||||
func TestDeleteFile_MissingURL_Returns400(t *testing.T) {
|
func TestDeleteFile_MissingURL_Returns400(t *testing.T) {
|
||||||
// Use a test storage service — DeleteFile won't reach storage since validation fails first
|
// Use a test storage service — DeleteFile won't reach storage since validation fails first
|
||||||
storageSvc := newTestStorageService("/var/uploads")
|
storageSvc := newTestStorageService("/var/uploads")
|
||||||
handler := NewUploadHandler(storageSvc)
|
handler := NewUploadHandler(storageSvc, nil)
|
||||||
|
|
||||||
e := testutil.SetupTestRouter()
|
e := testutil.SetupTestRouter()
|
||||||
|
|
||||||
// Register route
|
// Register route with mock auth middleware
|
||||||
e.DELETE("/api/uploads/", handler.DeleteFile)
|
testUser := &models.User{FirstName: "Test", Email: "test@test.com"}
|
||||||
|
testUser.ID = 1
|
||||||
|
authGroup := e.Group("/api")
|
||||||
|
authGroup.Use(testutil.MockAuthMiddleware(testUser))
|
||||||
|
authGroup.DELETE("/uploads/", handler.DeleteFile)
|
||||||
|
|
||||||
// Send request with empty JSON body (url field missing)
|
// Send request with empty JSON body (url field missing)
|
||||||
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{}, "test-token")
|
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{}, "test-token")
|
||||||
@@ -32,10 +37,16 @@ func TestDeleteFile_MissingURL_Returns400(t *testing.T) {
|
|||||||
|
|
||||||
func TestDeleteFile_EmptyURL_Returns400(t *testing.T) {
|
func TestDeleteFile_EmptyURL_Returns400(t *testing.T) {
|
||||||
storageSvc := newTestStorageService("/var/uploads")
|
storageSvc := newTestStorageService("/var/uploads")
|
||||||
handler := NewUploadHandler(storageSvc)
|
handler := NewUploadHandler(storageSvc, nil)
|
||||||
|
|
||||||
e := testutil.SetupTestRouter()
|
e := testutil.SetupTestRouter()
|
||||||
e.DELETE("/api/uploads/", handler.DeleteFile)
|
|
||||||
|
// Register route with mock auth middleware
|
||||||
|
testUser := &models.User{FirstName: "Test", Email: "test@test.com"}
|
||||||
|
testUser.ID = 1
|
||||||
|
authGroup := e.Group("/api")
|
||||||
|
authGroup.Use(testutil.MockAuthMiddleware(testUser))
|
||||||
|
authGroup.DELETE("/uploads/", handler.DeleteFile)
|
||||||
|
|
||||||
// Send request with empty url field
|
// Send request with empty url field
|
||||||
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{"url": ""}, "test-token")
|
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{"url": ""}, "test-token")
|
||||||
|
|||||||
@@ -111,6 +111,11 @@
|
|||||||
"error.purchase_token_required": "purchase_token is required for Android",
|
"error.purchase_token_required": "purchase_token is required for Android",
|
||||||
|
|
||||||
"error.no_file_provided": "No file provided",
|
"error.no_file_provided": "No file provided",
|
||||||
|
"error.url_required": "File URL is required",
|
||||||
|
"error.file_access_denied": "You don't have access to this file",
|
||||||
|
"error.days_out_of_range": "Days parameter must be between 1 and 3650",
|
||||||
|
"error.platform_required": "Platform is required (ios or android)",
|
||||||
|
"error.registration_id_required": "Registration ID is required",
|
||||||
|
|
||||||
"error.failed_to_fetch_residence_types": "Failed to fetch residence types",
|
"error.failed_to_fetch_residence_types": "Failed to fetch residence types",
|
||||||
"error.failed_to_fetch_task_categories": "Failed to fetch task categories",
|
"error.failed_to_fetch_task_categories": "Failed to fetch task categories",
|
||||||
|
|||||||
@@ -25,19 +25,24 @@ const (
|
|||||||
TokenCacheTTL = 5 * time.Minute
|
TokenCacheTTL = 5 * time.Minute
|
||||||
// TokenCachePrefix is the prefix for token cache keys
|
// TokenCachePrefix is the prefix for token cache keys
|
||||||
TokenCachePrefix = "auth_token_"
|
TokenCachePrefix = "auth_token_"
|
||||||
|
// UserCacheTTL is how long full user records are cached in memory to
|
||||||
|
// avoid hitting the database on every authenticated request.
|
||||||
|
UserCacheTTL = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthMiddleware provides token authentication middleware
|
// AuthMiddleware provides token authentication middleware
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
cache *services.CacheService
|
cache *services.CacheService
|
||||||
|
userCache *UserCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthMiddleware creates a new auth middleware instance
|
// NewAuthMiddleware creates a new auth middleware instance
|
||||||
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
|
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
|
||||||
return &AuthMiddleware{
|
return &AuthMiddleware{
|
||||||
db: db,
|
db: db,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
|
userCache: NewUserCache(UserCacheTTL),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,7 +143,8 @@ func extractToken(c echo.Context) (string, error) {
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserFromCache tries to get user from Redis cache
|
// getUserFromCache tries to get user from Redis cache, then from the
|
||||||
|
// in-memory user cache, before falling back to the database.
|
||||||
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
|
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
|
||||||
if m.cache == nil {
|
if m.cache == nil {
|
||||||
return nil, fmt.Errorf("cache not available")
|
return nil, fmt.Errorf("cache not available")
|
||||||
@@ -152,10 +158,20 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get user from database by ID
|
// Try in-memory user cache first to avoid a DB round-trip
|
||||||
|
if cached := m.userCache.Get(userID); cached != nil {
|
||||||
|
if !cached.IsActive {
|
||||||
|
m.cache.InvalidateAuthToken(ctx, token)
|
||||||
|
m.userCache.Invalidate(userID)
|
||||||
|
return nil, fmt.Errorf("user is inactive")
|
||||||
|
}
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// In-memory cache miss — fetch from database
|
||||||
var user models.User
|
var user models.User
|
||||||
if err := m.db.First(&user, userID).Error; err != nil {
|
if err := m.db.First(&user, userID).Error; err != nil {
|
||||||
// User was deleted - invalidate cache
|
// User was deleted - invalidate caches
|
||||||
m.cache.InvalidateAuthToken(ctx, token)
|
m.cache.InvalidateAuthToken(ctx, token)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -166,10 +182,13 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m
|
|||||||
return nil, fmt.Errorf("user is inactive")
|
return nil, fmt.Errorf("user is inactive")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store in in-memory cache for subsequent requests
|
||||||
|
m.userCache.Set(&user)
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserFromDatabase looks up the token in the database
|
// getUserFromDatabase looks up the token in the database and caches the
|
||||||
|
// resulting user record in memory.
|
||||||
func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) {
|
func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) {
|
||||||
var authToken models.AuthToken
|
var authToken models.AuthToken
|
||||||
if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil {
|
if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil {
|
||||||
@@ -181,6 +200,8 @@ func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error)
|
|||||||
return nil, fmt.Errorf("user is inactive")
|
return nil, fmt.Errorf("user is inactive")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store in in-memory cache for subsequent requests
|
||||||
|
m.userCache.Set(&authToken.User)
|
||||||
return &authToken.User, nil
|
return &authToken.User, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,7 +241,11 @@ func GetAuthToken(c echo.Context) string {
|
|||||||
if token == nil {
|
if token == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return token.(string)
|
tokenStr, ok := token.(string)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return tokenStr
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustGetAuthUser retrieves the authenticated user or returns error with 401
|
// MustGetAuthUser retrieves the authenticated user or returns error with 401
|
||||||
|
|||||||
40
internal/middleware/host_check.go
Normal file
40
internal/middleware/host_check.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
|
||||||
|
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HostCheck returns middleware that validates the request Host header against
|
||||||
|
// a set of allowed hosts. This prevents SSRF attacks where an attacker crafts
|
||||||
|
// a request with an arbitrary Host header to reach internal services via the
|
||||||
|
// reverse proxy.
|
||||||
|
//
|
||||||
|
// If allowedHosts is empty the middleware is a no-op (all hosts pass).
|
||||||
|
func HostCheck(allowedHosts []string) echo.MiddlewareFunc {
|
||||||
|
allowed := make(map[string]struct{}, len(allowedHosts))
|
||||||
|
for _, h := range allowedHosts {
|
||||||
|
allowed[h] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
// If no allowed hosts configured, skip the check
|
||||||
|
if len(allowed) == 0 {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
host := c.Request().Host
|
||||||
|
if _, ok := allowed[host]; !ok {
|
||||||
|
return c.JSON(http.StatusForbidden, responses.ErrorResponse{
|
||||||
|
Error: "Forbidden",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
68
internal/middleware/rate_limit.go
Normal file
68
internal/middleware/rate_limit.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/labstack/echo/v4/middleware"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
|
||||||
|
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthRateLimiter returns rate-limiting middleware tuned for authentication
|
||||||
|
// endpoints. It uses Echo's built-in in-memory rate limiter keyed by client
|
||||||
|
// IP address.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ratePerSecond: sustained request rate (e.g., 10/60.0 for ~10 per minute)
|
||||||
|
// - burst: maximum burst size above the sustained rate
|
||||||
|
func AuthRateLimiter(ratePerSecond rate.Limit, burst int) echo.MiddlewareFunc {
|
||||||
|
store := middleware.NewRateLimiterMemoryStoreWithConfig(
|
||||||
|
middleware.RateLimiterMemoryStoreConfig{
|
||||||
|
Rate: ratePerSecond,
|
||||||
|
Burst: burst,
|
||||||
|
ExpiresIn: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{
|
||||||
|
Skipper: middleware.DefaultSkipper,
|
||||||
|
IdentifierExtractor: func(c echo.Context) (string, error) {
|
||||||
|
return c.RealIP(), nil
|
||||||
|
},
|
||||||
|
Store: store,
|
||||||
|
DenyHandler: func(c echo.Context, _ string, _ error) error {
|
||||||
|
return c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{
|
||||||
|
Error: "Too many requests. Please try again later.",
|
||||||
|
})
|
||||||
|
},
|
||||||
|
ErrorHandler: func(c echo.Context, err error) error {
|
||||||
|
return c.JSON(http.StatusForbidden, responses.ErrorResponse{
|
||||||
|
Error: "Unable to process request.",
|
||||||
|
})
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginRateLimiter returns rate-limiting middleware for login endpoints.
|
||||||
|
// Allows 10 requests per minute with a burst of 5.
|
||||||
|
func LoginRateLimiter() echo.MiddlewareFunc {
|
||||||
|
// 10 requests per 60 seconds = ~0.167 req/s, burst 5
|
||||||
|
return AuthRateLimiter(rate.Limit(10.0/60.0), 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegistrationRateLimiter returns rate-limiting middleware for registration
|
||||||
|
// endpoints. Allows 5 requests per minute with a burst of 3.
|
||||||
|
func RegistrationRateLimiter() echo.MiddlewareFunc {
|
||||||
|
// 5 requests per 60 seconds = ~0.083 req/s, burst 3
|
||||||
|
return AuthRateLimiter(rate.Limit(5.0/60.0), 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasswordResetRateLimiter returns rate-limiting middleware for password
|
||||||
|
// reset endpoints. Allows 3 requests per minute with a burst of 2.
|
||||||
|
func PasswordResetRateLimiter() echo.MiddlewareFunc {
|
||||||
|
// 3 requests per 60 seconds = 0.05 req/s, burst 2
|
||||||
|
return AuthRateLimiter(rate.Limit(3.0/60.0), 2)
|
||||||
|
}
|
||||||
@@ -7,14 +7,25 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// TimezoneKey is the key used to store the user's timezone in the context
|
// TimezoneKey is the key used to store the user's timezone *time.Location in the context
|
||||||
TimezoneKey = "user_timezone"
|
TimezoneKey = "user_timezone"
|
||||||
|
// TimezoneNameKey stores the raw IANA timezone string from the request header
|
||||||
|
TimezoneNameKey = "user_timezone_name"
|
||||||
|
// TimezoneChangedKey is a bool context key indicating whether the timezone
|
||||||
|
// differs from the previously cached value for this user. Handlers should
|
||||||
|
// only persist the timezone to DB when this is true.
|
||||||
|
TimezoneChangedKey = "timezone_changed"
|
||||||
// UserNowKey is the key used to store the timezone-aware "now" time in the context
|
// UserNowKey is the key used to store the timezone-aware "now" time in the context
|
||||||
UserNowKey = "user_now"
|
UserNowKey = "user_now"
|
||||||
// TimezoneHeader is the HTTP header name for the user's timezone
|
// TimezoneHeader is the HTTP header name for the user's timezone
|
||||||
TimezoneHeader = "X-Timezone"
|
TimezoneHeader = "X-Timezone"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// package-level timezone cache shared across requests. It is safe for
|
||||||
|
// concurrent use and has no TTL — entries are only updated when a new
|
||||||
|
// timezone value is observed for a given user.
|
||||||
|
var tzCache = NewTimezoneCache()
|
||||||
|
|
||||||
// TimezoneMiddleware extracts the user's timezone from the request header
|
// TimezoneMiddleware extracts the user's timezone from the request header
|
||||||
// and stores a timezone-aware "now" time in the context.
|
// and stores a timezone-aware "now" time in the context.
|
||||||
//
|
//
|
||||||
@@ -22,14 +33,31 @@ const (
|
|||||||
// or a UTC offset (e.g., "-08:00", "+05:30").
|
// or a UTC offset (e.g., "-08:00", "+05:30").
|
||||||
//
|
//
|
||||||
// If no timezone is provided or it's invalid, UTC is used as the default.
|
// If no timezone is provided or it's invalid, UTC is used as the default.
|
||||||
|
//
|
||||||
|
// The middleware also compares the incoming timezone with a cached value per
|
||||||
|
// user and sets TimezoneChangedKey in the context so downstream handlers
|
||||||
|
// know whether a DB write is needed.
|
||||||
func TimezoneMiddleware() echo.MiddlewareFunc {
|
func TimezoneMiddleware() echo.MiddlewareFunc {
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
tzName := c.Request().Header.Get(TimezoneHeader)
|
tzName := c.Request().Header.Get(TimezoneHeader)
|
||||||
loc := parseTimezone(tzName)
|
loc := parseTimezone(tzName)
|
||||||
|
|
||||||
// Store the location and the current time in that timezone
|
// Store the location and the raw name in the context
|
||||||
c.Set(TimezoneKey, loc)
|
c.Set(TimezoneKey, loc)
|
||||||
|
c.Set(TimezoneNameKey, tzName)
|
||||||
|
|
||||||
|
// Determine whether the timezone changed for this user so handlers
|
||||||
|
// can skip unnecessary DB writes.
|
||||||
|
changed := false
|
||||||
|
if tzName != "" {
|
||||||
|
if user := GetAuthUser(c); user != nil {
|
||||||
|
if !tzCache.GetAndCompare(user.ID, tzName) {
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Set(TimezoneChangedKey, changed)
|
||||||
|
|
||||||
// Calculate "now" in the user's timezone, then get start of day
|
// Calculate "now" in the user's timezone, then get start of day
|
||||||
// For date comparisons, we want to compare against the START of the user's current day
|
// For date comparisons, we want to compare against the START of the user's current day
|
||||||
@@ -42,6 +70,20 @@ func TimezoneMiddleware() echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsTimezoneChanged returns true when the user's timezone header differs from
|
||||||
|
// the previously observed value. Handlers should only persist the timezone to
|
||||||
|
// DB when this returns true.
|
||||||
|
func IsTimezoneChanged(c echo.Context) bool {
|
||||||
|
val, ok := c.Get(TimezoneChangedKey).(bool)
|
||||||
|
return ok && val
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTimezoneName returns the raw timezone string from the request header.
|
||||||
|
func GetTimezoneName(c echo.Context) string {
|
||||||
|
val, _ := c.Get(TimezoneNameKey).(string)
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
// parseTimezone parses a timezone string and returns a *time.Location.
|
// parseTimezone parses a timezone string and returns a *time.Location.
|
||||||
// Supports IANA timezone names (e.g., "America/Los_Angeles") and
|
// Supports IANA timezone names (e.g., "America/Los_Angeles") and
|
||||||
// UTC offsets (e.g., "-08:00", "+05:30").
|
// UTC offsets (e.g., "-08:00", "+05:30").
|
||||||
|
|||||||
113
internal/middleware/user_cache.go
Normal file
113
internal/middleware/user_cache.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/treytartt/honeydue-api/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// userCacheEntry holds a cached user record with an expiration time.
|
||||||
|
type userCacheEntry struct {
|
||||||
|
user *models.User
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserCache is a concurrency-safe in-memory cache for User records, keyed by
|
||||||
|
// user ID. Entries expire after a configurable TTL. The cache uses a sync.Map
|
||||||
|
// for lock-free reads on the hot path, with periodic lazy eviction of stale
|
||||||
|
// entries during Set operations.
|
||||||
|
type UserCache struct {
|
||||||
|
store sync.Map
|
||||||
|
ttl time.Duration
|
||||||
|
lastGC time.Time
|
||||||
|
gcMu sync.Mutex
|
||||||
|
gcEvery time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserCache creates a UserCache with the given TTL for entries.
|
||||||
|
func NewUserCache(ttl time.Duration) *UserCache {
|
||||||
|
return &UserCache{
|
||||||
|
ttl: ttl,
|
||||||
|
lastGC: time.Now(),
|
||||||
|
gcEvery: 2 * time.Minute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a cached user by ID, or nil if not found or expired.
|
||||||
|
func (c *UserCache) Get(userID uint) *models.User {
|
||||||
|
val, ok := c.store.Load(userID)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
entry := val.(*userCacheEntry)
|
||||||
|
if time.Now().After(entry.expiresAt) {
|
||||||
|
c.store.Delete(userID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Return a shallow copy so callers cannot mutate the cached value.
|
||||||
|
user := *entry.user
|
||||||
|
return &user
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set stores a user in the cache. It also triggers a background garbage-
|
||||||
|
// collection sweep if enough time has elapsed since the last one.
|
||||||
|
func (c *UserCache) Set(user *models.User) {
|
||||||
|
// Store a copy to prevent external mutation of the cached object.
|
||||||
|
copied := *user
|
||||||
|
c.store.Store(user.ID, &userCacheEntry{
|
||||||
|
user: &copied,
|
||||||
|
expiresAt: time.Now().Add(c.ttl),
|
||||||
|
})
|
||||||
|
c.maybeGC()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidate removes a user from the cache by ID.
|
||||||
|
func (c *UserCache) Invalidate(userID uint) {
|
||||||
|
c.store.Delete(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeGC lazily sweeps expired entries at most once per gcEvery interval.
|
||||||
|
func (c *UserCache) maybeGC() {
|
||||||
|
c.gcMu.Lock()
|
||||||
|
if time.Since(c.lastGC) < c.gcEvery {
|
||||||
|
c.gcMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.lastGC = time.Now()
|
||||||
|
c.gcMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
c.store.Range(func(key, value any) bool {
|
||||||
|
entry := value.(*userCacheEntry)
|
||||||
|
if now.After(entry.expiresAt) {
|
||||||
|
c.store.Delete(key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimezoneCache tracks the last-known timezone per user ID so the timezone
|
||||||
|
// middleware only writes to the database when the value actually changes.
|
||||||
|
type TimezoneCache struct {
|
||||||
|
store sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTimezoneCache creates a new TimezoneCache.
|
||||||
|
func NewTimezoneCache() *TimezoneCache {
|
||||||
|
return &TimezoneCache{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAndCompare returns true if the cached timezone for the user matches tz.
|
||||||
|
// If the timezone is different (or not yet cached), it updates the cache and
|
||||||
|
// returns false, signaling that a DB write is needed.
|
||||||
|
func (tc *TimezoneCache) GetAndCompare(userID uint, tz string) (unchanged bool) {
|
||||||
|
val, loaded := tc.store.Load(userID)
|
||||||
|
if loaded {
|
||||||
|
if cached, ok := val.(string); ok && cached == tz {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tc.store.Store(userID, tz)
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -13,12 +13,6 @@ type BaseModel struct {
|
|||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoftDeleteModel extends BaseModel with soft delete support
|
|
||||||
type SoftDeleteModel struct {
|
|
||||||
BaseModel
|
|
||||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// BeforeCreate sets timestamps before creating a record
|
// BeforeCreate sets timestamps before creating a record
|
||||||
func (b *BaseModel) BeforeCreate(tx *gorm.DB) error {
|
func (b *BaseModel) BeforeCreate(tx *gorm.DB) error {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|||||||
@@ -29,10 +29,10 @@ type NotificationPreference struct {
|
|||||||
|
|
||||||
// Custom notification times (nullable, stored as UTC hour 0-23)
|
// Custom notification times (nullable, stored as UTC hour 0-23)
|
||||||
// When nil, system defaults from config are used
|
// When nil, system defaults from config are used
|
||||||
TaskDueSoonHour *int `gorm:"column:task_due_soon_hour" json:"task_due_soon_hour"`
|
TaskDueSoonHour *int `gorm:"column:task_due_soon_hour" json:"task_due_soon_hour" validate:"omitempty,min=0,max=23"`
|
||||||
TaskOverdueHour *int `gorm:"column:task_overdue_hour" json:"task_overdue_hour"`
|
TaskOverdueHour *int `gorm:"column:task_overdue_hour" json:"task_overdue_hour" validate:"omitempty,min=0,max=23"`
|
||||||
WarrantyExpiringHour *int `gorm:"column:warranty_expiring_hour" json:"warranty_expiring_hour"`
|
WarrantyExpiringHour *int `gorm:"column:warranty_expiring_hour" json:"warranty_expiring_hour" validate:"omitempty,min=0,max=23"`
|
||||||
DailyDigestHour *int `gorm:"column:daily_digest_hour" json:"daily_digest_hour"`
|
DailyDigestHour *int `gorm:"column:daily_digest_hour" json:"daily_digest_hour" validate:"omitempty,min=0,max=23"`
|
||||||
|
|
||||||
// User timezone for background job calculations (IANA name, e.g., "America/Los_Angeles")
|
// User timezone for background job calculations (IANA name, e.g., "America/Los_Angeles")
|
||||||
// Auto-captured from X-Timezone header on API calls
|
// Auto-captured from X-Timezone header on API calls
|
||||||
|
|||||||
@@ -181,88 +181,6 @@ func (t *Task) IsDueSoon(days int) bool {
|
|||||||
return !effectiveDate.Before(now) && effectiveDate.Before(threshold)
|
return !effectiveDate.Before(now) && effectiveDate.Before(threshold)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKanbanColumn returns the kanban column name for this task using the
|
|
||||||
// Chain of Responsibility pattern from the categorization package.
|
|
||||||
// Uses UTC time for categorization.
|
|
||||||
//
|
|
||||||
// For timezone-aware categorization, use GetKanbanColumnWithTimezone.
|
|
||||||
func (t *Task) GetKanbanColumn(daysThreshold int) string {
|
|
||||||
// Import would cause circular dependency, so we inline the logic
|
|
||||||
// This delegates to the categorization package via internal/task re-export
|
|
||||||
return t.GetKanbanColumnWithTimezone(daysThreshold, time.Now().UTC())
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetKanbanColumnWithTimezone returns the kanban column name using a specific
|
|
||||||
// time (in the user's timezone). The time is used to determine "today" for
|
|
||||||
// overdue/due-soon calculations.
|
|
||||||
//
|
|
||||||
// Example: For a user in Tokyo, pass time.Now().In(tokyoLocation) to get
|
|
||||||
// accurate categorization relative to their local date.
|
|
||||||
func (t *Task) GetKanbanColumnWithTimezone(daysThreshold int, now time.Time) string {
|
|
||||||
// Note: We can't import categorization directly due to circular dependency.
|
|
||||||
// Instead, this method implements the categorization logic inline.
|
|
||||||
// The logic MUST match categorization.Chain exactly.
|
|
||||||
|
|
||||||
if daysThreshold <= 0 {
|
|
||||||
daysThreshold = 30
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start of day normalization
|
|
||||||
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
|
||||||
threshold := startOfDay.AddDate(0, 0, daysThreshold)
|
|
||||||
|
|
||||||
// Priority 1: Cancelled
|
|
||||||
if t.IsCancelled {
|
|
||||||
return "cancelled_tasks"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Priority 2: Archived (goes to cancelled column - both are "inactive" states)
|
|
||||||
if t.IsArchived {
|
|
||||||
return "cancelled_tasks"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Priority 3: Completed (NextDueDate nil with completions)
|
|
||||||
hasCompletions := len(t.Completions) > 0 || t.CompletionCount > 0
|
|
||||||
if t.NextDueDate == nil && hasCompletions {
|
|
||||||
return "completed_tasks"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Priority 4: In Progress
|
|
||||||
if t.InProgress {
|
|
||||||
return "in_progress_tasks"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get effective date: NextDueDate ?? DueDate
|
|
||||||
var effectiveDate *time.Time
|
|
||||||
if t.NextDueDate != nil {
|
|
||||||
effectiveDate = t.NextDueDate
|
|
||||||
} else {
|
|
||||||
effectiveDate = t.DueDate
|
|
||||||
}
|
|
||||||
|
|
||||||
if effectiveDate != nil {
|
|
||||||
// Normalize effective date to same timezone for calendar date comparison
|
|
||||||
// Task dates are stored as UTC but represent calendar dates (YYYY-MM-DD)
|
|
||||||
normalizedEffective := time.Date(
|
|
||||||
effectiveDate.Year(), effectiveDate.Month(), effectiveDate.Day(),
|
|
||||||
0, 0, 0, 0, now.Location(),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Priority 5: Overdue (effective date before today)
|
|
||||||
if normalizedEffective.Before(startOfDay) {
|
|
||||||
return "overdue_tasks"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Priority 6: Due Soon (effective date before threshold)
|
|
||||||
if normalizedEffective.Before(threshold) {
|
|
||||||
return "due_soon_tasks"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Priority 7: Upcoming (default)
|
|
||||||
return "upcoming_tasks"
|
|
||||||
}
|
|
||||||
|
|
||||||
// TaskCompletion represents the task_taskcompletion table
|
// TaskCompletion represents the task_taskcompletion table
|
||||||
type TaskCompletion struct {
|
type TaskCompletion struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
|
|||||||
@@ -248,180 +248,10 @@ func TestDocument_JSONSerialization(t *testing.T) {
|
|||||||
assert.Equal(t, "5000", result["purchase_price"]) // Decimal serializes as string
|
assert.Equal(t, "5000", result["purchase_price"]) // Decimal serializes as string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================================
|
|
||||||
// TASK KANBAN COLUMN TESTS
|
|
||||||
// These tests verify GetKanbanColumn and GetKanbanColumnWithTimezone methods
|
|
||||||
// ============================================================================
|
|
||||||
|
|
||||||
func timePtr(t time.Time) *time.Time {
|
func timePtr(t time.Time) *time.Time {
|
||||||
return &t
|
return &t
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTask_GetKanbanColumn_PriorityOrder(t *testing.T) {
|
|
||||||
now := time.Date(2025, 12, 16, 12, 0, 0, 0, time.UTC)
|
|
||||||
yesterday := time.Date(2025, 12, 15, 0, 0, 0, 0, time.UTC)
|
|
||||||
in5Days := time.Date(2025, 12, 21, 0, 0, 0, 0, time.UTC)
|
|
||||||
in60Days := time.Date(2026, 2, 14, 0, 0, 0, 0, time.UTC)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
task *Task
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
// Priority 1: Cancelled
|
|
||||||
{
|
|
||||||
name: "cancelled takes highest priority",
|
|
||||||
task: &Task{
|
|
||||||
IsCancelled: true,
|
|
||||||
NextDueDate: timePtr(yesterday),
|
|
||||||
InProgress: true,
|
|
||||||
},
|
|
||||||
expected: "cancelled_tasks",
|
|
||||||
},
|
|
||||||
|
|
||||||
// Priority 2: Completed
|
|
||||||
{
|
|
||||||
name: "completed: NextDueDate nil with completions",
|
|
||||||
task: &Task{
|
|
||||||
IsCancelled: false,
|
|
||||||
NextDueDate: nil,
|
|
||||||
DueDate: timePtr(yesterday),
|
|
||||||
Completions: []TaskCompletion{{BaseModel: BaseModel{ID: 1}}},
|
|
||||||
},
|
|
||||||
expected: "completed_tasks",
|
|
||||||
},
|
|
||||||
|
|
||||||
// Priority 3: In Progress
|
|
||||||
{
|
|
||||||
name: "in progress takes priority over overdue",
|
|
||||||
task: &Task{
|
|
||||||
IsCancelled: false,
|
|
||||||
NextDueDate: timePtr(yesterday),
|
|
||||||
InProgress: true,
|
|
||||||
},
|
|
||||||
expected: "in_progress_tasks",
|
|
||||||
},
|
|
||||||
|
|
||||||
// Priority 4: Overdue
|
|
||||||
{
|
|
||||||
name: "overdue: effective date in past",
|
|
||||||
task: &Task{
|
|
||||||
IsCancelled: false,
|
|
||||||
NextDueDate: timePtr(yesterday),
|
|
||||||
},
|
|
||||||
expected: "overdue_tasks",
|
|
||||||
},
|
|
||||||
|
|
||||||
// Priority 5: Due Soon
|
|
||||||
{
|
|
||||||
name: "due soon: within 30-day threshold",
|
|
||||||
task: &Task{
|
|
||||||
IsCancelled: false,
|
|
||||||
NextDueDate: timePtr(in5Days),
|
|
||||||
},
|
|
||||||
expected: "due_soon_tasks",
|
|
||||||
},
|
|
||||||
|
|
||||||
// Priority 6: Upcoming
|
|
||||||
{
|
|
||||||
name: "upcoming: beyond threshold",
|
|
||||||
task: &Task{
|
|
||||||
IsCancelled: false,
|
|
||||||
NextDueDate: timePtr(in60Days),
|
|
||||||
},
|
|
||||||
expected: "upcoming_tasks",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "upcoming: no due date",
|
|
||||||
task: &Task{
|
|
||||||
IsCancelled: false,
|
|
||||||
NextDueDate: nil,
|
|
||||||
DueDate: nil,
|
|
||||||
},
|
|
||||||
expected: "upcoming_tasks",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := tt.task.GetKanbanColumnWithTimezone(30, now)
|
|
||||||
assert.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTask_GetKanbanColumnWithTimezone_TimezoneAware(t *testing.T) {
|
|
||||||
// Task due Dec 17, 2025
|
|
||||||
taskDueDate := time.Date(2025, 12, 17, 0, 0, 0, 0, time.UTC)
|
|
||||||
|
|
||||||
task := &Task{
|
|
||||||
NextDueDate: timePtr(taskDueDate),
|
|
||||||
IsCancelled: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
// At 11 PM UTC on Dec 16 (UTC user) - task is tomorrow, due_soon
|
|
||||||
utcDec16Evening := time.Date(2025, 12, 16, 23, 0, 0, 0, time.UTC)
|
|
||||||
result := task.GetKanbanColumnWithTimezone(30, utcDec16Evening)
|
|
||||||
assert.Equal(t, "due_soon_tasks", result, "UTC Dec 16 evening")
|
|
||||||
|
|
||||||
// At 8 AM UTC on Dec 17 (UTC user) - task is today, due_soon
|
|
||||||
utcDec17Morning := time.Date(2025, 12, 17, 8, 0, 0, 0, time.UTC)
|
|
||||||
result = task.GetKanbanColumnWithTimezone(30, utcDec17Morning)
|
|
||||||
assert.Equal(t, "due_soon_tasks", result, "UTC Dec 17 morning")
|
|
||||||
|
|
||||||
// At 8 AM UTC on Dec 18 (UTC user) - task was yesterday, overdue
|
|
||||||
utcDec18Morning := time.Date(2025, 12, 18, 8, 0, 0, 0, time.UTC)
|
|
||||||
result = task.GetKanbanColumnWithTimezone(30, utcDec18Morning)
|
|
||||||
assert.Equal(t, "overdue_tasks", result, "UTC Dec 18 morning")
|
|
||||||
|
|
||||||
// Tokyo user at 11 PM UTC Dec 16 = 8 AM Dec 17 Tokyo
|
|
||||||
// Task due Dec 17 is TODAY for Tokyo user - due_soon
|
|
||||||
tokyo, _ := time.LoadLocation("Asia/Tokyo")
|
|
||||||
tokyoDec17Morning := utcDec16Evening.In(tokyo)
|
|
||||||
result = task.GetKanbanColumnWithTimezone(30, tokyoDec17Morning)
|
|
||||||
assert.Equal(t, "due_soon_tasks", result, "Tokyo Dec 17 morning")
|
|
||||||
|
|
||||||
// Tokyo at 8 AM Dec 18 UTC = 5 PM Dec 18 Tokyo
|
|
||||||
// Task due Dec 17 was YESTERDAY for Tokyo - overdue
|
|
||||||
tokyoDec18 := utcDec18Morning.In(tokyo)
|
|
||||||
result = task.GetKanbanColumnWithTimezone(30, tokyoDec18)
|
|
||||||
assert.Equal(t, "overdue_tasks", result, "Tokyo Dec 18")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTask_GetKanbanColumnWithTimezone_DueSoonThreshold(t *testing.T) {
|
|
||||||
now := time.Date(2025, 12, 16, 12, 0, 0, 0, time.UTC)
|
|
||||||
|
|
||||||
// Task due in 29 days - within 30-day threshold
|
|
||||||
due29Days := time.Date(2026, 1, 14, 0, 0, 0, 0, time.UTC)
|
|
||||||
task29 := &Task{NextDueDate: timePtr(due29Days)}
|
|
||||||
result := task29.GetKanbanColumnWithTimezone(30, now)
|
|
||||||
assert.Equal(t, "due_soon_tasks", result, "29 days should be due_soon")
|
|
||||||
|
|
||||||
// Task due in exactly 30 days - at threshold boundary (upcoming, not due_soon)
|
|
||||||
due30Days := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC)
|
|
||||||
task30 := &Task{NextDueDate: timePtr(due30Days)}
|
|
||||||
result = task30.GetKanbanColumnWithTimezone(30, now)
|
|
||||||
assert.Equal(t, "upcoming_tasks", result, "30 days should be upcoming (at boundary)")
|
|
||||||
|
|
||||||
// Task due in 31 days - beyond threshold
|
|
||||||
due31Days := time.Date(2026, 1, 16, 0, 0, 0, 0, time.UTC)
|
|
||||||
task31 := &Task{NextDueDate: timePtr(due31Days)}
|
|
||||||
result = task31.GetKanbanColumnWithTimezone(30, now)
|
|
||||||
assert.Equal(t, "upcoming_tasks", result, "31 days should be upcoming")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTask_GetKanbanColumn_CompletionCount(t *testing.T) {
|
|
||||||
// Test that CompletionCount is also used for completion detection
|
|
||||||
task := &Task{
|
|
||||||
NextDueDate: nil,
|
|
||||||
CompletionCount: 1, // Using CompletionCount instead of Completions slice
|
|
||||||
Completions: []TaskCompletion{},
|
|
||||||
}
|
|
||||||
|
|
||||||
result := task.GetKanbanColumn(30)
|
|
||||||
assert.Equal(t, "completed_tasks", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTask_IsOverdueAt_DayBased(t *testing.T) {
|
func TestTask_IsOverdueAt_DayBased(t *testing.T) {
|
||||||
// Test that IsOverdueAt uses day-based comparison
|
// Test that IsOverdueAt uses day-based comparison
|
||||||
now := time.Date(2025, 12, 16, 15, 0, 0, 0, time.UTC) // 3 PM UTC
|
now := time.Date(2025, 12, 16, 15, 0, 0, 0, time.UTC) // 3 PM UTC
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ type User struct {
|
|||||||
Username string `gorm:"column:username;uniqueIndex;size:150;not null" json:"username"`
|
Username string `gorm:"column:username;uniqueIndex;size:150;not null" json:"username"`
|
||||||
FirstName string `gorm:"column:first_name;size:150" json:"first_name"`
|
FirstName string `gorm:"column:first_name;size:150" json:"first_name"`
|
||||||
LastName string `gorm:"column:last_name;size:150" json:"last_name"`
|
LastName string `gorm:"column:last_name;size:150" json:"last_name"`
|
||||||
Email string `gorm:"column:email;size:254" json:"email"`
|
Email string `gorm:"column:email;uniqueIndex;size:254" json:"email"`
|
||||||
IsStaff bool `gorm:"column:is_staff;default:false" json:"is_staff"`
|
IsStaff bool `gorm:"column:is_staff;default:false" json:"is_staff"`
|
||||||
IsActive bool `gorm:"column:is_active;default:true" json:"is_active"`
|
IsActive bool `gorm:"column:is_active;default:true" json:"is_active"`
|
||||||
DateJoined time.Time `gorm:"column:date_joined;autoCreateTime" json:"date_joined"`
|
DateJoined time.Time `gorm:"column:date_joined;autoCreateTime" json:"date_joined"`
|
||||||
@@ -142,7 +142,7 @@ func (UserProfile) TableName() string {
|
|||||||
type ConfirmationCode struct {
|
type ConfirmationCode struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
UserID uint `gorm:"column:user_id;index;not null" json:"user_id"`
|
UserID uint `gorm:"column:user_id;index;not null" json:"user_id"`
|
||||||
Code string `gorm:"column:code;size:6;not null" json:"code"`
|
Code string `gorm:"column:code;size:6;not null" json:"-"`
|
||||||
ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"`
|
ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"`
|
||||||
IsUsed bool `gorm:"column:is_used;default:false" json:"is_used"`
|
IsUsed bool `gorm:"column:is_used;default:false" json:"is_used"`
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package monitoring
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hibiken/asynq"
|
"github.com/hibiken/asynq"
|
||||||
@@ -20,6 +21,7 @@ type Collector struct {
|
|||||||
httpCollector *HTTPStatsCollector // nil for worker
|
httpCollector *HTTPStatsCollector // nil for worker
|
||||||
asynqClient *asynq.Inspector // nil for api
|
asynqClient *asynq.Inspector // nil for api
|
||||||
stopChan chan struct{}
|
stopChan chan struct{}
|
||||||
|
stopOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCollector creates a new stats collector
|
// NewCollector creates a new stats collector
|
||||||
@@ -193,7 +195,9 @@ func (c *Collector) publishStats() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the stats publishing
|
// Stop stops the stats publishing. It is safe to call multiple times.
|
||||||
func (c *Collector) Stop() {
|
func (c *Collector) Stop() {
|
||||||
close(c.stopChan)
|
c.stopOnce.Do(func() {
|
||||||
|
close(c.stopChan)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -150,7 +150,10 @@ func (h *Handler) WebSocket(c echo.Context) error {
|
|||||||
defer statsTicker.Stop()
|
defer statsTicker.Stop()
|
||||||
|
|
||||||
// Send initial stats
|
// Send initial stats
|
||||||
h.sendStats(conn, &wsMu)
|
if err := h.sendStats(conn, &wsMu); err != nil {
|
||||||
|
cancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -173,11 +176,16 @@ func (h *Handler) WebSocket(c echo.Context) error {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Msg("WebSocket write error")
|
log.Debug().Err(err).Msg("WebSocket write error")
|
||||||
|
cancel()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-statsTicker.C:
|
case <-statsTicker.C:
|
||||||
// Send periodic stats update
|
// Send periodic stats update
|
||||||
h.sendStats(conn, &wsMu)
|
if err := h.sendStats(conn, &wsMu); err != nil {
|
||||||
|
cancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil
|
return nil
|
||||||
@@ -185,9 +193,11 @@ func (h *Handler) WebSocket(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) sendStats(conn *websocket.Conn, mu *sync.Mutex) {
|
func (h *Handler) sendStats(conn *websocket.Conn, mu *sync.Mutex) error {
|
||||||
allStats, err := h.statsStore.GetAllStats()
|
allStats, err := h.statsStore.GetAllStats()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to send stats")
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
wsMsg := WSMessage{
|
wsMsg := WSMessage{
|
||||||
@@ -196,6 +206,12 @@ func (h *Handler) sendStats(conn *websocket.Conn, mu *sync.Mutex) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
conn.WriteJSON(wsMsg)
|
err = conn.WriteJSON(wsMsg)
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Debug().Err(err).Msg("WebSocket write error sending stats")
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,39 +10,33 @@ import (
|
|||||||
|
|
||||||
// HTTPStatsCollector collects HTTP request metrics
|
// HTTPStatsCollector collects HTTP request metrics
|
||||||
type HTTPStatsCollector struct {
|
type HTTPStatsCollector struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
requests map[string]int64 // endpoint -> count
|
requests map[string]int64 // endpoint -> count
|
||||||
totalLatency map[string]time.Duration // endpoint -> total latency
|
totalLatency map[string]time.Duration // endpoint -> total latency
|
||||||
errors map[string]int64 // endpoint -> error count
|
errors map[string]int64 // endpoint -> error count
|
||||||
byStatus map[int]int64 // status code -> count
|
byStatus map[int]int64 // status code -> count
|
||||||
latencies []latencySample // recent latency samples for P95
|
endpointLatencies map[string][]time.Duration // per-endpoint sorted latency buffers for P95
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
lastReset time.Time
|
lastReset time.Time
|
||||||
}
|
|
||||||
|
|
||||||
type latencySample struct {
|
|
||||||
endpoint string
|
|
||||||
latency time.Duration
|
|
||||||
timestamp time.Time
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxLatencySamples = 1000
|
maxLatencySamplesPerEndpoint = 200 // Max latency samples kept per endpoint
|
||||||
maxEndpoints = 200 // Cap unique endpoints tracked
|
maxEndpoints = 200 // Cap unique endpoints tracked
|
||||||
statsResetPeriod = 1 * time.Hour // Reset stats periodically to prevent unbounded growth
|
statsResetPeriod = 1 * time.Hour // Reset stats periodically to prevent unbounded growth
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewHTTPStatsCollector creates a new HTTP stats collector
|
// NewHTTPStatsCollector creates a new HTTP stats collector
|
||||||
func NewHTTPStatsCollector() *HTTPStatsCollector {
|
func NewHTTPStatsCollector() *HTTPStatsCollector {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return &HTTPStatsCollector{
|
return &HTTPStatsCollector{
|
||||||
requests: make(map[string]int64),
|
requests: make(map[string]int64),
|
||||||
totalLatency: make(map[string]time.Duration),
|
totalLatency: make(map[string]time.Duration),
|
||||||
errors: make(map[string]int64),
|
errors: make(map[string]int64),
|
||||||
byStatus: make(map[int]int64),
|
byStatus: make(map[int]int64),
|
||||||
latencies: make([]latencySample, 0, maxLatencySamples),
|
endpointLatencies: make(map[string][]time.Duration),
|
||||||
startTime: now,
|
startTime: now,
|
||||||
lastReset: now,
|
lastReset: now,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,17 +64,22 @@ func (c *HTTPStatsCollector) Record(endpoint string, latency time.Duration, stat
|
|||||||
c.errors[endpoint]++
|
c.errors[endpoint]++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store latency sample
|
// Insert latency into per-endpoint sorted buffer using binary search
|
||||||
c.latencies = append(c.latencies, latencySample{
|
buf := c.endpointLatencies[endpoint]
|
||||||
endpoint: endpoint,
|
idx := sort.Search(len(buf), func(i int) bool {
|
||||||
latency: latency,
|
return buf[i] >= latency
|
||||||
timestamp: time.Now(),
|
|
||||||
})
|
})
|
||||||
|
buf = append(buf, 0)
|
||||||
|
copy(buf[idx+1:], buf[idx:])
|
||||||
|
buf[idx] = latency
|
||||||
|
|
||||||
// Keep only recent samples
|
// Trim to max samples per endpoint by removing the median element
|
||||||
if len(c.latencies) > maxLatencySamples {
|
// to preserve distribution tails (important for P95 accuracy)
|
||||||
c.latencies = c.latencies[len(c.latencies)-maxLatencySamples:]
|
if len(buf) > maxLatencySamplesPerEndpoint {
|
||||||
|
mid := len(buf) / 2
|
||||||
|
buf = append(buf[:mid], buf[mid+1:]...)
|
||||||
}
|
}
|
||||||
|
c.endpointLatencies[endpoint] = buf
|
||||||
}
|
}
|
||||||
|
|
||||||
// resetLocked resets stats while holding the lock
|
// resetLocked resets stats while holding the lock
|
||||||
@@ -89,7 +88,7 @@ func (c *HTTPStatsCollector) resetLocked() {
|
|||||||
c.totalLatency = make(map[string]time.Duration)
|
c.totalLatency = make(map[string]time.Duration)
|
||||||
c.errors = make(map[string]int64)
|
c.errors = make(map[string]int64)
|
||||||
c.byStatus = make(map[int]int64)
|
c.byStatus = make(map[int]int64)
|
||||||
c.latencies = make([]latencySample, 0, maxLatencySamples)
|
c.endpointLatencies = make(map[string][]time.Duration)
|
||||||
c.lastReset = time.Now()
|
c.lastReset = time.Now()
|
||||||
// Keep startTime for uptime calculation
|
// Keep startTime for uptime calculation
|
||||||
}
|
}
|
||||||
@@ -147,33 +146,23 @@ func (c *HTTPStatsCollector) GetStats() HTTPStats {
|
|||||||
return stats
|
return stats
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculateP95 calculates the 95th percentile latency for an endpoint
|
// calculateP95 calculates the 95th percentile latency for an endpoint.
|
||||||
// Must be called with read lock held
|
// The per-endpoint buffer is maintained in sorted order during insertion,
|
||||||
|
// so this is an O(1) index lookup.
|
||||||
|
// Must be called with read lock held.
|
||||||
func (c *HTTPStatsCollector) calculateP95(endpoint string) float64 {
|
func (c *HTTPStatsCollector) calculateP95(endpoint string) float64 {
|
||||||
var endpointLatencies []time.Duration
|
buf := c.endpointLatencies[endpoint]
|
||||||
|
if len(buf) == 0 {
|
||||||
for _, sample := range c.latencies {
|
|
||||||
if sample.endpoint == endpoint {
|
|
||||||
endpointLatencies = append(endpointLatencies, sample.latency)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(endpointLatencies) == 0 {
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort latencies
|
// Buffer is already sorted; direct index lookup
|
||||||
sort.Slice(endpointLatencies, func(i, j int) bool {
|
p95Index := int(float64(len(buf)) * 0.95)
|
||||||
return endpointLatencies[i] < endpointLatencies[j]
|
if p95Index >= len(buf) {
|
||||||
})
|
p95Index = len(buf) - 1
|
||||||
|
|
||||||
// Calculate P95 index
|
|
||||||
p95Index := int(float64(len(endpointLatencies)) * 0.95)
|
|
||||||
if p95Index >= len(endpointLatencies) {
|
|
||||||
p95Index = len(endpointLatencies) - 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return float64(endpointLatencies[p95Index].Milliseconds())
|
return float64(buf[p95Index].Milliseconds())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset clears all collected stats
|
// Reset clears all collected stats
|
||||||
@@ -185,7 +174,7 @@ func (c *HTTPStatsCollector) Reset() {
|
|||||||
c.totalLatency = make(map[string]time.Duration)
|
c.totalLatency = make(map[string]time.Duration)
|
||||||
c.errors = make(map[string]int64)
|
c.errors = make(map[string]int64)
|
||||||
c.byStatus = make(map[int]int64)
|
c.byStatus = make(map[int]int64)
|
||||||
c.latencies = make([]latencySample, 0, maxLatencySamples)
|
c.endpointLatencies = make(map[string][]time.Duration)
|
||||||
c.startTime = time.Now()
|
c.startTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package monitoring
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hibiken/asynq"
|
"github.com/hibiken/asynq"
|
||||||
@@ -31,6 +32,8 @@ type Service struct {
|
|||||||
logWriter *RedisLogWriter
|
logWriter *RedisLogWriter
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
settingsStopCh chan struct{}
|
settingsStopCh chan struct{}
|
||||||
|
stopOnce sync.Once
|
||||||
|
statsInterval time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config holds configuration for the monitoring service
|
// Config holds configuration for the monitoring service
|
||||||
@@ -71,6 +74,7 @@ func NewService(cfg Config) *Service {
|
|||||||
logWriter: logWriter,
|
logWriter: logWriter,
|
||||||
db: cfg.DB,
|
db: cfg.DB,
|
||||||
settingsStopCh: make(chan struct{}),
|
settingsStopCh: make(chan struct{}),
|
||||||
|
statsInterval: cfg.StatsInterval,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check initial setting from database
|
// Check initial setting from database
|
||||||
@@ -90,11 +94,11 @@ func (s *Service) SetAsynqInspector(inspector *asynq.Inspector) {
|
|||||||
func (s *Service) Start() {
|
func (s *Service) Start() {
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("process", s.process).
|
Str("process", s.process).
|
||||||
Dur("interval", DefaultStatsInterval).
|
Dur("interval", s.statsInterval).
|
||||||
Bool("enabled", s.logWriter.IsEnabled()).
|
Bool("enabled", s.logWriter.IsEnabled()).
|
||||||
Msg("Starting monitoring service")
|
Msg("Starting monitoring service")
|
||||||
|
|
||||||
s.collector.StartPublishing(DefaultStatsInterval)
|
s.collector.StartPublishing(s.statsInterval)
|
||||||
|
|
||||||
// Start settings sync if database is available
|
// Start settings sync if database is available
|
||||||
if s.db != nil {
|
if s.db != nil {
|
||||||
@@ -102,17 +106,19 @@ func (s *Service) Start() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the monitoring service
|
// Stop stops the monitoring service. It is safe to call multiple times.
|
||||||
func (s *Service) Stop() {
|
func (s *Service) Stop() {
|
||||||
// Stop settings sync
|
s.stopOnce.Do(func() {
|
||||||
close(s.settingsStopCh)
|
// Stop settings sync
|
||||||
|
close(s.settingsStopCh)
|
||||||
|
|
||||||
s.collector.Stop()
|
s.collector.Stop()
|
||||||
|
|
||||||
// Flush and close the log writer's background goroutine
|
// Flush and close the log writer's background goroutine
|
||||||
s.logWriter.Close()
|
s.logWriter.Close()
|
||||||
|
|
||||||
log.Info().Str("process", s.process).Msg("Monitoring service stopped")
|
log.Info().Str("process", s.process).Msg("Monitoring service stopped")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// syncSettingsFromDB checks the database for the enable_monitoring setting
|
// syncSettingsFromDB checks the database for the enable_monitoring setting
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package monitoring
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -18,11 +19,12 @@ const (
|
|||||||
// It uses a single background goroutine with a buffered channel instead of
|
// It uses a single background goroutine with a buffered channel instead of
|
||||||
// spawning a new goroutine per log line, preventing unbounded goroutine growth.
|
// spawning a new goroutine per log line, preventing unbounded goroutine growth.
|
||||||
type RedisLogWriter struct {
|
type RedisLogWriter struct {
|
||||||
buffer *LogBuffer
|
buffer *LogBuffer
|
||||||
process string
|
process string
|
||||||
enabled atomic.Bool
|
enabled atomic.Bool
|
||||||
ch chan LogEntry
|
ch chan LogEntry
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRedisLogWriter creates a new writer that captures logs to Redis.
|
// NewRedisLogWriter creates a new writer that captures logs to Redis.
|
||||||
@@ -53,9 +55,12 @@ func (w *RedisLogWriter) drainLoop() {
|
|||||||
|
|
||||||
// Close shuts down the background goroutine. It should be called during
|
// Close shuts down the background goroutine. It should be called during
|
||||||
// graceful shutdown to ensure all buffered entries are flushed.
|
// graceful shutdown to ensure all buffered entries are flushed.
|
||||||
|
// It is safe to call multiple times.
|
||||||
func (w *RedisLogWriter) Close() {
|
func (w *RedisLogWriter) Close() {
|
||||||
close(w.ch)
|
w.closeOnce.Do(func() {
|
||||||
<-w.done // Wait for drain to finish
|
close(w.ch)
|
||||||
|
<-w.done // Wait for drain to finish
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetEnabled enables or disables log capture to Redis
|
// SetEnabled enables or disables log capture to Redis
|
||||||
|
|||||||
@@ -38,16 +38,15 @@ func NewAPNsClient(cfg *config.PushConfig) (*APNsClient, error) {
|
|||||||
TeamID: cfg.APNSTeamID,
|
TeamID: cfg.APNSTeamID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create client - production or sandbox
|
// Create client - sandbox if APNSSandbox is true, production otherwise.
|
||||||
// Use APNSProduction if set, otherwise fall back to inverse of APNSSandbox
|
// APNSSandbox is the single source of truth (defaults to true for safety).
|
||||||
var client *apns2.Client
|
var client *apns2.Client
|
||||||
useProduction := cfg.APNSProduction || !cfg.APNSSandbox
|
if cfg.APNSSandbox {
|
||||||
if useProduction {
|
|
||||||
client = apns2.NewTokenClient(authToken).Production()
|
|
||||||
log.Info().Msg("APNs client configured for PRODUCTION")
|
|
||||||
} else {
|
|
||||||
client = apns2.NewTokenClient(authToken).Development()
|
client = apns2.NewTokenClient(authToken).Development()
|
||||||
log.Info().Msg("APNs client configured for DEVELOPMENT/SANDBOX")
|
log.Info().Msg("APNs client configured for DEVELOPMENT/SANDBOX")
|
||||||
|
} else {
|
||||||
|
client = apns2.NewTokenClient(authToken).Production()
|
||||||
|
log.Info().Msg("APNs client configured for PRODUCTION")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &APNsClient{
|
return &APNsClient{
|
||||||
|
|||||||
@@ -38,17 +38,17 @@ func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) {
|
|||||||
log.Warn().Msg("APNs not configured - iOS push disabled")
|
log.Warn().Msg("APNs not configured - iOS push disabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize FCM client (Android)
|
// Initialize FCM client (Android) - requires project ID + service account credentials
|
||||||
if cfg.FCMServerKey != "" {
|
if cfg.FCMProjectID != "" && (cfg.FCMServiceAccountPath != "" || cfg.FCMServiceAccountJSON != "") {
|
||||||
fcmClient, err := NewFCMClient(cfg)
|
fcmClient, err := NewFCMClient(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Msg("Failed to initialize FCM client - Android push disabled")
|
log.Warn().Err(err).Msg("Failed to initialize FCM v1 client - Android push disabled")
|
||||||
} else {
|
} else {
|
||||||
client.fcm = fcmClient
|
client.fcm = fcmClient
|
||||||
log.Info().Msg("FCM client initialized successfully")
|
log.Info().Msg("FCM v1 client initialized successfully")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Warn().Msg("FCM not configured - Android push disabled")
|
log.Warn().Msg("FCM not configured (need FCM_PROJECT_ID + FCM_SERVICE_ACCOUNT_PATH or FCM_SERVICE_ACCOUNT_JSON) - Android push disabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
return client, nil
|
return client, nil
|
||||||
|
|||||||
@@ -5,138 +5,304 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"golang.org/x/oauth2/google"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/config"
|
"github.com/treytartt/honeydue-api/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
const fcmEndpoint = "https://fcm.googleapis.com/fcm/send"
|
const (
|
||||||
|
// fcmV1EndpointFmt is the FCM HTTP v1 API endpoint template.
|
||||||
|
fcmV1EndpointFmt = "https://fcm.googleapis.com/v1/projects/%s/messages:send"
|
||||||
|
|
||||||
// FCMClient handles direct communication with Firebase Cloud Messaging
|
// fcmScope is the OAuth 2.0 scope required for FCM HTTP v1 API.
|
||||||
|
fcmScope = "https://www.googleapis.com/auth/firebase.messaging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FCMClient handles communication with Firebase Cloud Messaging
|
||||||
|
// using the HTTP v1 API and OAuth 2.0 service account authentication.
|
||||||
type FCMClient struct {
|
type FCMClient struct {
|
||||||
serverKey string
|
projectID string
|
||||||
|
endpoint string
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// FCMMessage represents an FCM message payload
|
// --- Request types (FCM v1 API) ---
|
||||||
type FCMMessage struct {
|
|
||||||
To string `json:"to,omitempty"`
|
// fcmV1Request is the top-level request body for the FCM v1 API.
|
||||||
RegistrationIDs []string `json:"registration_ids,omitempty"`
|
type fcmV1Request struct {
|
||||||
Notification *FCMNotification `json:"notification,omitempty"`
|
Message *fcmV1Message `json:"message"`
|
||||||
Data map[string]string `json:"data,omitempty"`
|
|
||||||
Priority string `json:"priority,omitempty"`
|
|
||||||
ContentAvailable bool `json:"content_available,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FCMNotification represents the notification payload
|
// fcmV1Message represents a single FCM v1 message.
|
||||||
|
type fcmV1Message struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
Notification *FCMNotification `json:"notification,omitempty"`
|
||||||
|
Data map[string]string `json:"data,omitempty"`
|
||||||
|
Android *fcmAndroidConfig `json:"android,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FCMNotification represents the notification payload.
|
||||||
type FCMNotification struct {
|
type FCMNotification struct {
|
||||||
Title string `json:"title,omitempty"`
|
Title string `json:"title,omitempty"`
|
||||||
Body string `json:"body,omitempty"`
|
Body string `json:"body,omitempty"`
|
||||||
Sound string `json:"sound,omitempty"`
|
|
||||||
Badge string `json:"badge,omitempty"`
|
|
||||||
Icon string `json:"icon,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FCMResponse represents the FCM API response
|
// fcmAndroidConfig provides Android-specific message configuration.
|
||||||
type FCMResponse struct {
|
type fcmAndroidConfig struct {
|
||||||
MulticastID int64 `json:"multicast_id"`
|
Priority string `json:"priority,omitempty"`
|
||||||
Success int `json:"success"`
|
|
||||||
Failure int `json:"failure"`
|
|
||||||
CanonicalIDs int `json:"canonical_ids"`
|
|
||||||
Results []FCMResult `json:"results"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FCMResult represents a single result in the FCM response
|
// --- Response types (FCM v1 API) ---
|
||||||
type FCMResult struct {
|
|
||||||
MessageID string `json:"message_id,omitempty"`
|
// fcmV1Response is the successful response from the FCM v1 API.
|
||||||
RegistrationID string `json:"registration_id,omitempty"`
|
type fcmV1Response struct {
|
||||||
Error string `json:"error,omitempty"`
|
Name string `json:"name"` // e.g. "projects/myproject/messages/0:1234567890"
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFCMClient creates a new FCM client
|
// fcmV1ErrorResponse is the error response from the FCM v1 API.
|
||||||
|
type fcmV1ErrorResponse struct {
|
||||||
|
Error fcmV1Error `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// fcmV1Error contains the structured error details.
|
||||||
|
type fcmV1Error struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Details json.RawMessage `json:"details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Error types ---
|
||||||
|
|
||||||
|
// FCMErrorCode represents well-known FCM v1 error codes for programmatic handling.
|
||||||
|
type FCMErrorCode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
FCMErrUnregistered FCMErrorCode = "UNREGISTERED"
|
||||||
|
FCMErrQuotaExceeded FCMErrorCode = "QUOTA_EXCEEDED"
|
||||||
|
FCMErrUnavailable FCMErrorCode = "UNAVAILABLE"
|
||||||
|
FCMErrInternal FCMErrorCode = "INTERNAL"
|
||||||
|
FCMErrInvalidArgument FCMErrorCode = "INVALID_ARGUMENT"
|
||||||
|
FCMErrSenderIDMismatch FCMErrorCode = "SENDER_ID_MISMATCH"
|
||||||
|
FCMErrThirdPartyAuth FCMErrorCode = "THIRD_PARTY_AUTH_ERROR"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FCMSendError is a structured error returned when an individual FCM send fails.
|
||||||
|
type FCMSendError struct {
|
||||||
|
Token string
|
||||||
|
StatusCode int
|
||||||
|
ErrorCode FCMErrorCode
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *FCMSendError) Error() string {
|
||||||
|
return fmt.Sprintf("FCM send failed for token %s: %s (status %d, code %s)",
|
||||||
|
truncateToken(e.Token), e.Message, e.StatusCode, e.ErrorCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUnregistered returns true if the device token is no longer valid and should be removed.
|
||||||
|
func (e *FCMSendError) IsUnregistered() bool {
|
||||||
|
return e.ErrorCode == FCMErrUnregistered
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- OAuth 2.0 transport ---
|
||||||
|
|
||||||
|
// oauth2BearerTransport is an http.RoundTripper that attaches an OAuth 2.0 Bearer
|
||||||
|
// token to every outgoing request. The token source handles refresh automatically.
|
||||||
|
type oauth2BearerTransport struct {
|
||||||
|
base http.RoundTripper
|
||||||
|
getToken func() (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *oauth2BearerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
accessToken, err := t.getToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to obtain OAuth 2.0 token for FCM: %w", err)
|
||||||
|
}
|
||||||
|
r := req.Clone(req.Context())
|
||||||
|
r.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
return t.base.RoundTrip(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Client construction ---
|
||||||
|
|
||||||
|
// NewFCMClient creates a new FCM client using the HTTP v1 API with OAuth 2.0
|
||||||
|
// service account authentication. It accepts either a path to a service account
|
||||||
|
// JSON file or the raw JSON content directly via config.
|
||||||
func NewFCMClient(cfg *config.PushConfig) (*FCMClient, error) {
|
func NewFCMClient(cfg *config.PushConfig) (*FCMClient, error) {
|
||||||
if cfg.FCMServerKey == "" {
|
if cfg.FCMProjectID == "" {
|
||||||
return nil, fmt.Errorf("FCM server key not configured")
|
return nil, fmt.Errorf("FCM project ID not configured (set FCM_PROJECT_ID)")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &FCMClient{
|
credJSON, err := resolveServiceAccountJSON(cfg)
|
||||||
serverKey: cfg.FCMServerKey,
|
if err != nil {
|
||||||
httpClient: &http.Client{
|
return nil, err
|
||||||
Timeout: 30 * time.Second,
|
}
|
||||||
|
|
||||||
|
// Create OAuth 2.0 credentials with the FCM messaging scope.
|
||||||
|
// The google library handles automatic token refresh.
|
||||||
|
creds, err := google.CredentialsFromJSON(context.Background(), credJSON, fcmScope)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse FCM service account credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build an HTTP client that automatically attaches and refreshes OAuth tokens.
|
||||||
|
transport := &oauth2BearerTransport{
|
||||||
|
base: http.DefaultTransport,
|
||||||
|
getToken: func() (string, error) {
|
||||||
|
tok, err := creds.TokenSource.Token()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return tok.AccessToken, nil
|
||||||
},
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := fmt.Sprintf(fcmV1EndpointFmt, cfg.FCMProjectID)
|
||||||
|
|
||||||
|
log.Info().
|
||||||
|
Str("project_id", cfg.FCMProjectID).
|
||||||
|
Str("endpoint", endpoint).
|
||||||
|
Msg("FCM v1 client initialized with OAuth 2.0")
|
||||||
|
|
||||||
|
return &FCMClient{
|
||||||
|
projectID: cfg.FCMProjectID,
|
||||||
|
endpoint: endpoint,
|
||||||
|
httpClient: httpClient,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sends a push notification to Android devices
|
// resolveServiceAccountJSON returns the service account JSON bytes from config.
|
||||||
|
func resolveServiceAccountJSON(cfg *config.PushConfig) ([]byte, error) {
|
||||||
|
if cfg.FCMServiceAccountJSON != "" {
|
||||||
|
return []byte(cfg.FCMServiceAccountJSON), nil
|
||||||
|
}
|
||||||
|
if cfg.FCMServiceAccountPath != "" {
|
||||||
|
data, err := os.ReadFile(cfg.FCMServiceAccountPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read FCM service account file %s: %w", cfg.FCMServiceAccountPath, err)
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("FCM service account not configured (set FCM_SERVICE_ACCOUNT_PATH or FCM_SERVICE_ACCOUNT_JSON)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Sending ---
|
||||||
|
|
||||||
|
// Send sends a push notification to Android devices via the FCM HTTP v1 API.
|
||||||
|
// The v1 API requires one request per device token, so this iterates over all tokens.
|
||||||
|
// The method signature is kept identical to the previous legacy implementation
|
||||||
|
// so callers do not need to change.
|
||||||
func (c *FCMClient) Send(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
|
func (c *FCMClient) Send(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
|
||||||
if len(tokens) == 0 {
|
if len(tokens) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := FCMMessage{
|
var sendErrors []error
|
||||||
RegistrationIDs: tokens,
|
successCount := 0
|
||||||
Notification: &FCMNotification{
|
|
||||||
Title: title,
|
|
||||||
Body: message,
|
|
||||||
Sound: "default",
|
|
||||||
},
|
|
||||||
Data: data,
|
|
||||||
Priority: "high",
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := json.Marshal(msg)
|
for _, token := range tokens {
|
||||||
if err != nil {
|
err := c.sendOne(ctx, token, title, message, data)
|
||||||
return fmt.Errorf("failed to marshal FCM message: %w", err)
|
if err != nil {
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", fcmEndpoint, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create FCM request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "key="+c.serverKey)
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to send FCM request: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("FCM returned status %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
var fcmResp FCMResponse
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&fcmResp); err != nil {
|
|
||||||
return fmt.Errorf("failed to decode FCM response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log individual results
|
|
||||||
for i, result := range fcmResp.Results {
|
|
||||||
if i >= len(tokens) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if result.Error != "" {
|
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("token", truncateToken(tokens[i])).
|
Err(err).
|
||||||
Str("error", result.Error).
|
Str("token", truncateToken(token)).
|
||||||
Msg("FCM notification failed")
|
Msg("FCM v1 notification failed")
|
||||||
|
sendErrors = append(sendErrors, err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
successCount++
|
||||||
|
log.Debug().
|
||||||
|
Str("token", truncateToken(token)).
|
||||||
|
Msg("FCM v1 notification sent successfully")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Int("total", len(tokens)).
|
Int("total", len(tokens)).
|
||||||
Int("success", fcmResp.Success).
|
Int("success", successCount).
|
||||||
Int("failure", fcmResp.Failure).
|
Int("failed", len(sendErrors)).
|
||||||
Msg("FCM batch send complete")
|
Msg("FCM v1 batch send complete")
|
||||||
|
|
||||||
if fcmResp.Success == 0 && fcmResp.Failure > 0 {
|
if len(sendErrors) > 0 && successCount == 0 {
|
||||||
return fmt.Errorf("all FCM notifications failed")
|
return fmt.Errorf("all FCM notifications failed: first error: %w", sendErrors[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sendOne sends a single FCM v1 message to one device token.
|
||||||
|
func (c *FCMClient) sendOne(ctx context.Context, token, title, message string, data map[string]string) error {
|
||||||
|
reqBody := fcmV1Request{
|
||||||
|
Message: &fcmV1Message{
|
||||||
|
Token: token,
|
||||||
|
Notification: &FCMNotification{
|
||||||
|
Title: title,
|
||||||
|
Body: message,
|
||||||
|
},
|
||||||
|
Data: data,
|
||||||
|
Android: &fcmAndroidConfig{
|
||||||
|
Priority: "HIGH",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal FCM v1 message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create FCM v1 request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send FCM v1 request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read FCM v1 response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the error response for structured error information.
|
||||||
|
return parseFCMV1Error(token, resp.StatusCode, respBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseFCMV1Error extracts a structured FCMSendError from the v1 API error response.
|
||||||
|
func parseFCMV1Error(token string, statusCode int, body []byte) *FCMSendError {
|
||||||
|
var errResp fcmV1ErrorResponse
|
||||||
|
if err := json.Unmarshal(body, &errResp); err != nil {
|
||||||
|
return &FCMSendError{
|
||||||
|
Token: token,
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Message: fmt.Sprintf("unparseable error response: %s", string(body)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &FCMSendError{
|
||||||
|
Token: token,
|
||||||
|
StatusCode: statusCode,
|
||||||
|
ErrorCode: FCMErrorCode(errResp.Error.Status),
|
||||||
|
Message: errResp.Error.Message,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,182 +5,266 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// newTestFCMClient creates an FCMClient pointing at the given test server URL.
|
// newTestFCMClient creates an FCMClient whose endpoint points at the given test
|
||||||
|
// server URL. The HTTP client uses no OAuth transport so tests can run without
|
||||||
|
// real Google credentials.
|
||||||
func newTestFCMClient(serverURL string) *FCMClient {
|
func newTestFCMClient(serverURL string) *FCMClient {
|
||||||
return &FCMClient{
|
return &FCMClient{
|
||||||
serverKey: "test-server-key",
|
projectID: "test-project",
|
||||||
httpClient: http.DefaultClient,
|
endpoint: serverURL,
|
||||||
|
httpClient: &http.Client{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// serveFCMResponse creates an httptest.Server that returns the given FCMResponse as JSON.
|
// serveFCMV1Success creates a test server that returns a successful v1 response
|
||||||
func serveFCMResponse(t *testing.T, resp FCMResponse) *httptest.Server {
|
// for every request.
|
||||||
|
func serveFCMV1Success(t *testing.T) *httptest.Server {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Verify it looks like a v1 request.
|
||||||
|
var req fcmV1Request
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, req.Message)
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
err := json.NewEncoder(w).Encode(resp)
|
resp := fcmV1Response{Name: "projects/test-project/messages/0:12345"}
|
||||||
require.NoError(t, err)
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendWithEndpoint is a helper that sends an FCM notification using a custom endpoint
|
// serveFCMV1Error creates a test server that returns the given status code and
|
||||||
// (the test server) instead of the real FCM endpoint. This avoids modifying the
|
// a structured v1 error response for every request.
|
||||||
// production code to be testable and instead temporarily overrides the client's HTTP
|
func serveFCMV1Error(t *testing.T, statusCode int, errStatus string, errMessage string) *httptest.Server {
|
||||||
// transport to redirect requests to our test server.
|
t.Helper()
|
||||||
func sendWithEndpoint(client *FCMClient, server *httptest.Server, ctx context.Context, tokens []string, title, message string, data map[string]string) error {
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Override the HTTP client to redirect all requests to the test server
|
w.Header().Set("Content-Type", "application/json")
|
||||||
client.httpClient = server.Client()
|
w.WriteHeader(statusCode)
|
||||||
|
resp := fcmV1ErrorResponse{
|
||||||
// We need to intercept the request and redirect it to our test server.
|
Error: fcmV1Error{
|
||||||
// Use a custom RoundTripper that rewrites the URL.
|
Code: statusCode,
|
||||||
originalTransport := server.Client().Transport
|
Message: errMessage,
|
||||||
client.httpClient.Transport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
Status: errStatus,
|
||||||
// Rewrite the URL to point to the test server
|
},
|
||||||
req.URL.Scheme = "http"
|
|
||||||
req.URL.Host = server.Listener.Addr().String()
|
|
||||||
if originalTransport != nil {
|
|
||||||
return originalTransport.RoundTrip(req)
|
|
||||||
}
|
}
|
||||||
return http.DefaultTransport.RoundTrip(req)
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
})
|
}))
|
||||||
|
|
||||||
return client.Send(ctx, tokens, title, message, data)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// roundTripFunc is a function that implements http.RoundTripper.
|
func TestFCMV1Send_Success_SingleToken(t *testing.T) {
|
||||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
server := serveFCMV1Success(t)
|
||||||
|
|
||||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
return f(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFCMSend_MoreResultsThanTokens_NoPanic(t *testing.T) {
|
|
||||||
// FCM returns 5 results but we only sent 2 tokens.
|
|
||||||
// Before the bounds check fix, this would panic with index out of range.
|
|
||||||
fcmResp := FCMResponse{
|
|
||||||
MulticastID: 12345,
|
|
||||||
Success: 2,
|
|
||||||
Failure: 3,
|
|
||||||
Results: []FCMResult{
|
|
||||||
{MessageID: "msg1"},
|
|
||||||
{MessageID: "msg2"},
|
|
||||||
{Error: "InvalidRegistration"},
|
|
||||||
{Error: "NotRegistered"},
|
|
||||||
{Error: "InvalidRegistration"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
server := serveFCMResponse(t, fcmResp)
|
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := newTestFCMClient(server.URL)
|
client := newTestFCMClient(server.URL)
|
||||||
tokens := []string{"token-aaa-111", "token-bbb-222"}
|
err := client.Send(context.Background(), []string{"token-aaa-111"}, "Title", "Body", nil)
|
||||||
|
|
||||||
// This must not panic
|
|
||||||
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFCMSend_FewerResultsThanTokens_NoPanic(t *testing.T) {
|
func TestFCMV1Send_Success_MultipleTokens(t *testing.T) {
|
||||||
// FCM returns fewer results than tokens we sent.
|
var mu sync.Mutex
|
||||||
// This is also a malformed response but should not panic.
|
receivedTokens := make([]string, 0)
|
||||||
fcmResp := FCMResponse{
|
|
||||||
MulticastID: 12345,
|
|
||||||
Success: 1,
|
|
||||||
Failure: 0,
|
|
||||||
Results: []FCMResult{
|
|
||||||
{MessageID: "msg1"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
server := serveFCMResponse(t, fcmResp)
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req fcmV1Request
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
receivedTokens = append(receivedTokens, req.Message.Token)
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
resp := fcmV1Response{Name: "projects/test-project/messages/0:12345"}
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := newTestFCMClient(server.URL)
|
client := newTestFCMClient(server.URL)
|
||||||
tokens := []string{"token-aaa-111", "token-bbb-222", "token-ccc-333"}
|
tokens := []string{"token-aaa-111", "token-bbb-222", "token-ccc-333"}
|
||||||
|
|
||||||
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
|
err := client.Send(context.Background(), tokens, "Title", "Body", map[string]string{"key": "value"})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.ElementsMatch(t, tokens, receivedTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFCMSend_EmptyResponse_NoPanic(t *testing.T) {
|
func TestFCMV1Send_EmptyTokens_ReturnsNil(t *testing.T) {
|
||||||
// FCM returns an empty Results slice.
|
|
||||||
fcmResp := FCMResponse{
|
|
||||||
MulticastID: 12345,
|
|
||||||
Success: 0,
|
|
||||||
Failure: 0,
|
|
||||||
Results: []FCMResult{},
|
|
||||||
}
|
|
||||||
|
|
||||||
server := serveFCMResponse(t, fcmResp)
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := newTestFCMClient(server.URL)
|
|
||||||
tokens := []string{"token-aaa-111"}
|
|
||||||
|
|
||||||
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
|
|
||||||
// No panic expected. The function returns nil because fcmResp.Success == 0
|
|
||||||
// and fcmResp.Failure == 0 (the "all failed" check requires Failure > 0).
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFCMSend_NilResultsSlice_NoPanic(t *testing.T) {
|
|
||||||
// FCM returns a response with nil Results (e.g., malformed JSON).
|
|
||||||
fcmResp := FCMResponse{
|
|
||||||
MulticastID: 12345,
|
|
||||||
Success: 0,
|
|
||||||
Failure: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
server := serveFCMResponse(t, fcmResp)
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := newTestFCMClient(server.URL)
|
|
||||||
tokens := []string{"token-aaa-111"}
|
|
||||||
|
|
||||||
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
|
|
||||||
// Should return error because Success == 0 and Failure > 0
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Contains(t, err.Error(), "all FCM notifications failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFCMSend_EmptyTokens_ReturnsNil(t *testing.T) {
|
|
||||||
// Verify the early return for empty tokens.
|
|
||||||
client := &FCMClient{
|
client := &FCMClient{
|
||||||
serverKey: "test-key",
|
projectID: "test-project",
|
||||||
httpClient: http.DefaultClient,
|
endpoint: "http://unused",
|
||||||
|
httpClient: &http.Client{},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := client.Send(context.Background(), []string{}, "Test", "Body", nil)
|
err := client.Send(context.Background(), []string{}, "Title", "Body", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFCMSend_ResultsWithErrorsMatchTokens(t *testing.T) {
|
func TestFCMV1Send_AllFail_ReturnsError(t *testing.T) {
|
||||||
// Normal case: results count matches tokens count, all with errors.
|
server := serveFCMV1Error(t, http.StatusNotFound, "UNREGISTERED", "The registration token is not registered")
|
||||||
fcmResp := FCMResponse{
|
|
||||||
MulticastID: 12345,
|
|
||||||
Success: 0,
|
|
||||||
Failure: 2,
|
|
||||||
Results: []FCMResult{
|
|
||||||
{Error: "InvalidRegistration"},
|
|
||||||
{Error: "NotRegistered"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
server := serveFCMResponse(t, fcmResp)
|
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := newTestFCMClient(server.URL)
|
client := newTestFCMClient(server.URL)
|
||||||
tokens := []string{"token-aaa-111", "token-bbb-222"}
|
tokens := []string{"bad-token-1", "bad-token-2"}
|
||||||
|
|
||||||
err := sendWithEndpoint(client, server, context.Background(), tokens, "Test", "Body", nil)
|
err := client.Send(context.Background(), tokens, "Title", "Body", nil)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "all FCM notifications failed")
|
assert.Contains(t, err.Error(), "all FCM notifications failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFCMV1Send_PartialFailure_ReturnsNil(t *testing.T) {
|
||||||
|
var mu sync.Mutex
|
||||||
|
callCount := 0
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mu.Lock()
|
||||||
|
callCount++
|
||||||
|
n := callCount
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if n == 1 {
|
||||||
|
// First token succeeds
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
resp := fcmV1Response{Name: "projects/test-project/messages/0:12345"}
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
} else {
|
||||||
|
// Second token fails
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
resp := fcmV1ErrorResponse{
|
||||||
|
Error: fcmV1Error{Code: 404, Message: "not registered", Status: "UNREGISTERED"},
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := newTestFCMClient(server.URL)
|
||||||
|
tokens := []string{"good-token", "bad-token"}
|
||||||
|
|
||||||
|
// Partial failure: at least one succeeded, so no error returned.
|
||||||
|
err := client.Send(context.Background(), tokens, "Title", "Body", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFCMV1Send_UnregisteredError(t *testing.T) {
|
||||||
|
server := serveFCMV1Error(t, http.StatusNotFound, "UNREGISTERED", "The registration token is not registered")
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := newTestFCMClient(server.URL)
|
||||||
|
|
||||||
|
err := client.sendOne(context.Background(), "stale-token", "Title", "Body", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var sendErr *FCMSendError
|
||||||
|
require.ErrorAs(t, err, &sendErr)
|
||||||
|
assert.True(t, sendErr.IsUnregistered())
|
||||||
|
assert.Equal(t, FCMErrUnregistered, sendErr.ErrorCode)
|
||||||
|
assert.Equal(t, http.StatusNotFound, sendErr.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFCMV1Send_QuotaExceededError(t *testing.T) {
|
||||||
|
server := serveFCMV1Error(t, http.StatusTooManyRequests, "QUOTA_EXCEEDED", "Sending quota exceeded")
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := newTestFCMClient(server.URL)
|
||||||
|
|
||||||
|
err := client.sendOne(context.Background(), "some-token", "Title", "Body", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var sendErr *FCMSendError
|
||||||
|
require.ErrorAs(t, err, &sendErr)
|
||||||
|
assert.Equal(t, FCMErrQuotaExceeded, sendErr.ErrorCode)
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, sendErr.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFCMV1Send_UnparseableErrorResponse(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte("not json at all"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := newTestFCMClient(server.URL)
|
||||||
|
|
||||||
|
err := client.sendOne(context.Background(), "some-token", "Title", "Body", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var sendErr *FCMSendError
|
||||||
|
require.ErrorAs(t, err, &sendErr)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, sendErr.StatusCode)
|
||||||
|
assert.Contains(t, sendErr.Message, "unparseable error response")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFCMV1Send_RequestPayloadFormat(t *testing.T) {
|
||||||
|
var receivedReq fcmV1Request
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Verify Content-Type header
|
||||||
|
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||||
|
assert.Equal(t, http.MethodPost, r.Method)
|
||||||
|
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&receivedReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = json.NewEncoder(w).Encode(fcmV1Response{Name: "projects/test-project/messages/0:12345"})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := newTestFCMClient(server.URL)
|
||||||
|
data := map[string]string{"task_id": "42", "action": "complete"}
|
||||||
|
err := client.Send(context.Background(), []string{"device-token-xyz"}, "Task Due", "Your task is due today", data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the v1 message structure
|
||||||
|
require.NotNil(t, receivedReq.Message)
|
||||||
|
assert.Equal(t, "device-token-xyz", receivedReq.Message.Token)
|
||||||
|
assert.Equal(t, "Task Due", receivedReq.Message.Notification.Title)
|
||||||
|
assert.Equal(t, "Your task is due today", receivedReq.Message.Notification.Body)
|
||||||
|
assert.Equal(t, "42", receivedReq.Message.Data["task_id"])
|
||||||
|
assert.Equal(t, "complete", receivedReq.Message.Data["action"])
|
||||||
|
assert.Equal(t, "HIGH", receivedReq.Message.Android.Priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFCMSendError_Error(t *testing.T) {
|
||||||
|
sendErr := &FCMSendError{
|
||||||
|
Token: "abcdef1234567890",
|
||||||
|
StatusCode: 404,
|
||||||
|
ErrorCode: FCMErrUnregistered,
|
||||||
|
Message: "token not registered",
|
||||||
|
}
|
||||||
|
|
||||||
|
errStr := sendErr.Error()
|
||||||
|
assert.Contains(t, errStr, "abcdef12...")
|
||||||
|
assert.Contains(t, errStr, "token not registered")
|
||||||
|
assert.Contains(t, errStr, "404")
|
||||||
|
assert.Contains(t, errStr, "UNREGISTERED")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFCMSendError_IsUnregistered(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code FCMErrorCode
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"unregistered", FCMErrUnregistered, true},
|
||||||
|
{"quota_exceeded", FCMErrQuotaExceeded, false},
|
||||||
|
{"internal", FCMErrInternal, false},
|
||||||
|
{"empty", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := &FCMSendError{ErrorCode: tt.code}
|
||||||
|
assert.Equal(t, tt.expected, err.IsUnregistered())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -124,29 +124,33 @@ func (r *ContractorRepository) GetTasksForContractor(contractorID uint) ([]model
|
|||||||
return tasks, err
|
return tasks, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSpecialties sets the specialties for a contractor
|
// SetSpecialties sets the specialties for a contractor.
|
||||||
|
// Wrapped in a transaction so that clearing existing specialties and
|
||||||
|
// appending new ones are atomic -- a failure in either step rolls back both.
|
||||||
func (r *ContractorRepository) SetSpecialties(contractorID uint, specialtyIDs []uint) error {
|
func (r *ContractorRepository) SetSpecialties(contractorID uint, specialtyIDs []uint) error {
|
||||||
var contractor models.Contractor
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
if err := r.db.First(&contractor, contractorID).Error; err != nil {
|
var contractor models.Contractor
|
||||||
return err
|
if err := tx.First(&contractor, contractorID).Error; err != nil {
|
||||||
}
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Clear existing specialties
|
// Clear existing specialties
|
||||||
if err := r.db.Model(&contractor).Association("Specialties").Clear(); err != nil {
|
if err := tx.Model(&contractor).Association("Specialties").Clear(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(specialtyIDs) == 0 {
|
if len(specialtyIDs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new specialties
|
// Add new specialties
|
||||||
var specialties []models.ContractorSpecialty
|
var specialties []models.ContractorSpecialty
|
||||||
if err := r.db.Where("id IN ?", specialtyIDs).Find(&specialties).Error; err != nil {
|
if err := tx.Where("id IN ?", specialtyIDs).Find(&specialties).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.db.Model(&contractor).Association("Specialties").Append(specialties)
|
return tx.Model(&contractor).Association("Specialties").Append(specialties)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountByResidence counts contractors in a residence
|
// CountByResidence counts contractors in a residence
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ func (r *DocumentRepository) FindByUserFiltered(residenceIDs []uint, filter *Doc
|
|||||||
}
|
}
|
||||||
|
|
||||||
var documents []models.Document
|
var documents []models.Document
|
||||||
err := query.Order("created_at DESC").Find(&documents).Error
|
err := query.Order("created_at DESC").Limit(500).Find(&documents).Error
|
||||||
return documents, err
|
return documents, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -88,10 +88,33 @@ func (r *ReminderRepository) HasSentReminderBatch(keys []ReminderKey) (map[int]b
|
|||||||
userIDs = append(userIDs, id)
|
userIDs = append(userIDs, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query all matching reminder logs in one query
|
// Collect unique stages and due dates for tighter SQL filtering
|
||||||
|
stageSet := make(map[models.ReminderStage]bool)
|
||||||
|
dueDateSet := make(map[string]bool)
|
||||||
|
var minDueDate, maxDueDate time.Time
|
||||||
|
for _, k := range keys {
|
||||||
|
stageSet[k.Stage] = true
|
||||||
|
dueDateOnly := time.Date(k.DueDate.Year(), k.DueDate.Month(), k.DueDate.Day(), 0, 0, 0, 0, time.UTC)
|
||||||
|
dueDateSet[dueDateOnly.Format("2006-01-02")] = true
|
||||||
|
if minDueDate.IsZero() || dueDateOnly.Before(minDueDate) {
|
||||||
|
minDueDate = dueDateOnly
|
||||||
|
}
|
||||||
|
if maxDueDate.IsZero() || dueDateOnly.After(maxDueDate) {
|
||||||
|
maxDueDate = dueDateOnly
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stages := make([]models.ReminderStage, 0, len(stageSet))
|
||||||
|
for s := range stageSet {
|
||||||
|
stages = append(stages, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query matching reminder logs with tighter filters to reduce result set.
|
||||||
|
// Filter on reminder_stage and due_date range in addition to task_id/user_id.
|
||||||
var logs []models.TaskReminderLog
|
var logs []models.TaskReminderLog
|
||||||
err := r.db.Where("task_id IN ? AND user_id IN ?", taskIDs, userIDs).
|
err := r.db.Where(
|
||||||
Find(&logs).Error
|
"task_id IN ? AND user_id IN ? AND reminder_stage IN ? AND due_date >= ? AND due_date <= ?",
|
||||||
|
taskIDs, userIDs, stages, minDueDate, maxDueDate,
|
||||||
|
).Find(&logs).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -196,6 +196,20 @@ func (r *ResidenceRepository) CountByOwner(userID uint) (int64, error) {
|
|||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FindResidenceIDsByOwner returns just the IDs of residences a user owns.
|
||||||
|
// This is a lightweight alternative to FindOwnedByUser() when only IDs are needed
|
||||||
|
// for batch queries against related tables (tasks, contractors, documents).
|
||||||
|
func (r *ResidenceRepository) FindResidenceIDsByOwner(userID uint) ([]uint, error) {
|
||||||
|
var ids []uint
|
||||||
|
err := r.db.Model(&models.Residence{}).
|
||||||
|
Where("owner_id = ? AND is_active = ?", userID, true).
|
||||||
|
Pluck("id", &ids).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
// === Share Code Operations ===
|
// === Share Code Operations ===
|
||||||
|
|
||||||
// CreateShareCode creates a new share code for a residence
|
// CreateShareCode creates a new share code for a residence
|
||||||
|
|||||||
@@ -129,12 +129,21 @@ func (r *SubscriptionRepository) UpdatePurchaseToken(userID uint, token string)
|
|||||||
Update("google_purchase_token", token).Error
|
Update("google_purchase_token", token).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindByAppleReceiptContains finds a subscription by Apple transaction ID
|
// FindByAppleReceiptContains finds a subscription by Apple transaction ID.
|
||||||
// Used by webhooks to find the user associated with a transaction
|
// Used by webhooks to find the user associated with a transaction.
|
||||||
|
//
|
||||||
|
// PERFORMANCE NOTE: This uses a LIKE '%...%' scan on apple_receipt_data which
|
||||||
|
// cannot use a B-tree index and results in a full table scan. For better
|
||||||
|
// performance at scale, add a dedicated indexed column:
|
||||||
|
//
|
||||||
|
// AppleTransactionID *string `gorm:"column:apple_transaction_id;size:255;index"`
|
||||||
|
//
|
||||||
|
// Then look up by exact match: WHERE apple_transaction_id = ?
|
||||||
func (r *SubscriptionRepository) FindByAppleReceiptContains(transactionID string) (*models.UserSubscription, error) {
|
func (r *SubscriptionRepository) FindByAppleReceiptContains(transactionID string) (*models.UserSubscription, error) {
|
||||||
var sub models.UserSubscription
|
var sub models.UserSubscription
|
||||||
// Search for transaction ID in the stored receipt data
|
// Escape LIKE wildcards in the transaction ID to prevent wildcard injection
|
||||||
err := r.db.Where("apple_receipt_data LIKE ?", "%"+transactionID+"%").First(&sub).Error
|
escaped := escapeLikeWildcards(transactionID)
|
||||||
|
err := r.db.Where("apple_receipt_data LIKE ?", "%"+escaped+"%").First(&sub).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,29 +38,40 @@ func (r *TaskRepository) CreateCompletionTx(tx *gorm.DB, completion *models.Task
|
|||||||
return tx.Create(completion).Error
|
return tx.Create(completion).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// taskUpdateFields returns the canonical field map used by both Update and UpdateTx.
|
||||||
|
// Centralised here so the two methods never drift out of sync.
|
||||||
|
func taskUpdateFields(t *models.Task) map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"title": t.Title,
|
||||||
|
"description": t.Description,
|
||||||
|
"category_id": t.CategoryID,
|
||||||
|
"priority_id": t.PriorityID,
|
||||||
|
"frequency_id": t.FrequencyID,
|
||||||
|
"custom_interval_days": t.CustomIntervalDays,
|
||||||
|
"in_progress": t.InProgress,
|
||||||
|
"assigned_to_id": t.AssignedToID,
|
||||||
|
"due_date": t.DueDate,
|
||||||
|
"next_due_date": t.NextDueDate,
|
||||||
|
"estimated_cost": t.EstimatedCost,
|
||||||
|
"actual_cost": t.ActualCost,
|
||||||
|
"contractor_id": t.ContractorID,
|
||||||
|
"is_cancelled": t.IsCancelled,
|
||||||
|
"is_archived": t.IsArchived,
|
||||||
|
"version": gorm.Expr("version + 1"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// taskUpdateOmitAssociations lists the association fields to omit during task updates.
|
||||||
|
var taskUpdateOmitAssociations = []string{
|
||||||
|
"Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions",
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateTx updates a task with optimistic locking within an existing transaction.
|
// UpdateTx updates a task with optimistic locking within an existing transaction.
|
||||||
func (r *TaskRepository) UpdateTx(tx *gorm.DB, task *models.Task) error {
|
func (r *TaskRepository) UpdateTx(tx *gorm.DB, task *models.Task) error {
|
||||||
result := tx.Model(task).
|
result := tx.Model(task).
|
||||||
Where("id = ? AND version = ?", task.ID, task.Version).
|
Where("id = ? AND version = ?", task.ID, task.Version).
|
||||||
Omit("Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions").
|
Omit(taskUpdateOmitAssociations...).
|
||||||
Updates(map[string]interface{}{
|
Updates(taskUpdateFields(task))
|
||||||
"title": task.Title,
|
|
||||||
"description": task.Description,
|
|
||||||
"category_id": task.CategoryID,
|
|
||||||
"priority_id": task.PriorityID,
|
|
||||||
"frequency_id": task.FrequencyID,
|
|
||||||
"custom_interval_days": task.CustomIntervalDays,
|
|
||||||
"in_progress": task.InProgress,
|
|
||||||
"assigned_to_id": task.AssignedToID,
|
|
||||||
"due_date": task.DueDate,
|
|
||||||
"next_due_date": task.NextDueDate,
|
|
||||||
"estimated_cost": task.EstimatedCost,
|
|
||||||
"actual_cost": task.ActualCost,
|
|
||||||
"contractor_id": task.ContractorID,
|
|
||||||
"is_cancelled": task.IsCancelled,
|
|
||||||
"is_archived": task.IsArchived,
|
|
||||||
"version": gorm.Expr("version + 1"),
|
|
||||||
})
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
@@ -350,25 +361,8 @@ func (r *TaskRepository) Create(task *models.Task) error {
|
|||||||
func (r *TaskRepository) Update(task *models.Task) error {
|
func (r *TaskRepository) Update(task *models.Task) error {
|
||||||
result := r.db.Model(task).
|
result := r.db.Model(task).
|
||||||
Where("id = ? AND version = ?", task.ID, task.Version).
|
Where("id = ? AND version = ?", task.ID, task.Version).
|
||||||
Omit("Residence", "CreatedBy", "AssignedTo", "Category", "Priority", "Frequency", "ParentTask", "Completions").
|
Omit(taskUpdateOmitAssociations...).
|
||||||
Updates(map[string]interface{}{
|
Updates(taskUpdateFields(task))
|
||||||
"title": task.Title,
|
|
||||||
"description": task.Description,
|
|
||||||
"category_id": task.CategoryID,
|
|
||||||
"priority_id": task.PriorityID,
|
|
||||||
"frequency_id": task.FrequencyID,
|
|
||||||
"custom_interval_days": task.CustomIntervalDays,
|
|
||||||
"in_progress": task.InProgress,
|
|
||||||
"assigned_to_id": task.AssignedToID,
|
|
||||||
"due_date": task.DueDate,
|
|
||||||
"next_due_date": task.NextDueDate,
|
|
||||||
"estimated_cost": task.EstimatedCost,
|
|
||||||
"actual_cost": task.ActualCost,
|
|
||||||
"contractor_id": task.ContractorID,
|
|
||||||
"is_cancelled": task.IsCancelled,
|
|
||||||
"is_archived": task.IsArchived,
|
|
||||||
"version": gorm.Expr("version + 1"),
|
|
||||||
})
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
@@ -728,13 +722,18 @@ func (r *TaskRepository) UpdateCompletion(completion *models.TaskCompletion) err
|
|||||||
return r.db.Omit("Task", "CompletedBy", "Images").Save(completion).Error
|
return r.db.Omit("Task", "CompletedBy", "Images").Save(completion).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteCompletion deletes a task completion
|
// DeleteCompletion deletes a task completion and its associated images atomically.
|
||||||
|
// Wrapped in a transaction so that if the completion delete fails, image
|
||||||
|
// deletions are rolled back as well.
|
||||||
func (r *TaskRepository) DeleteCompletion(id uint) error {
|
func (r *TaskRepository) DeleteCompletion(id uint) error {
|
||||||
// Delete images first
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
if err := r.db.Where("completion_id = ?", id).Delete(&models.TaskCompletionImage{}).Error; err != nil {
|
// Delete images first
|
||||||
log.Error().Err(err).Uint("completion_id", id).Msg("Failed to delete completion images")
|
if err := tx.Where("completion_id = ?", id).Delete(&models.TaskCompletionImage{}).Error; err != nil {
|
||||||
}
|
log.Error().Err(err).Uint("completion_id", id).Msg("Failed to delete completion images")
|
||||||
return r.db.Delete(&models.TaskCompletion{}, id).Error
|
return err
|
||||||
|
}
|
||||||
|
return tx.Delete(&models.TaskCompletion{}, id).Error
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateCompletionImage creates a new completion image
|
// CreateCompletionImage creates a new completion image
|
||||||
@@ -912,3 +911,128 @@ func (r *TaskRepository) GetCompletionSummary(residenceID uint, now time.Time, m
|
|||||||
Months: months,
|
Months: months,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBatchCompletionSummaries returns completion summaries for multiple residences
|
||||||
|
// in two queries total (one for all-time counts, one for monthly breakdowns),
|
||||||
|
// instead of 2*N queries when calling GetCompletionSummary per residence.
|
||||||
|
func (r *TaskRepository) GetBatchCompletionSummaries(residenceIDs []uint, now time.Time, maxPerMonth int) (map[uint]*responses.CompletionSummary, error) {
|
||||||
|
result := make(map[uint]*responses.CompletionSummary, len(residenceIDs))
|
||||||
|
if len(residenceIDs) == 0 {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. Total all-time completions per residence (single query)
|
||||||
|
type allTimeRow struct {
|
||||||
|
ResidenceID uint
|
||||||
|
Count int64
|
||||||
|
}
|
||||||
|
var allTimeRows []allTimeRow
|
||||||
|
err := r.db.Model(&models.TaskCompletion{}).
|
||||||
|
Select("task_task.residence_id, COUNT(*) as count").
|
||||||
|
Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
|
||||||
|
Where("task_task.residence_id IN ?", residenceIDs).
|
||||||
|
Group("task_task.residence_id").
|
||||||
|
Scan(&allTimeRows).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
allTimeMap := make(map[uint]int64, len(allTimeRows))
|
||||||
|
for _, row := range allTimeRows {
|
||||||
|
allTimeMap[row.ResidenceID] = row.Count
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Monthly breakdown for last 12 months across all residences (single query)
|
||||||
|
startDate := time.Date(now.Year()-1, now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||||
|
|
||||||
|
dateExpr := "TO_CHAR(task_taskcompletion.completed_at, 'YYYY-MM')"
|
||||||
|
if r.db.Dialector.Name() == "sqlite" {
|
||||||
|
dateExpr = "strftime('%Y-%m', task_taskcompletion.completed_at)"
|
||||||
|
}
|
||||||
|
|
||||||
|
var rows []completionAggRow
|
||||||
|
err = r.db.Model(&models.TaskCompletion{}).
|
||||||
|
Select(fmt.Sprintf("task_task.residence_id, task_taskcompletion.completed_from_column, %s as completed_month, COUNT(*) as count", dateExpr)).
|
||||||
|
Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
|
||||||
|
Where("task_task.residence_id IN ? AND task_taskcompletion.completed_at >= ?", residenceIDs, startDate).
|
||||||
|
Group(fmt.Sprintf("task_task.residence_id, task_taskcompletion.completed_from_column, %s", dateExpr)).
|
||||||
|
Order("completed_month ASC").
|
||||||
|
Scan(&rows).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Build per-residence summaries
|
||||||
|
type monthData struct {
|
||||||
|
columns map[string]int
|
||||||
|
total int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize all residences with empty month maps
|
||||||
|
residenceMonths := make(map[uint]map[string]*monthData, len(residenceIDs))
|
||||||
|
for _, rid := range residenceIDs {
|
||||||
|
mm := make(map[string]*monthData, 12)
|
||||||
|
for i := 0; i < 12; i++ {
|
||||||
|
m := startDate.AddDate(0, i, 0)
|
||||||
|
key := m.Format("2006-01")
|
||||||
|
mm[key] = &monthData{columns: make(map[string]int)}
|
||||||
|
}
|
||||||
|
residenceMonths[rid] = mm
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate from query results
|
||||||
|
residenceLast12 := make(map[uint]int, len(residenceIDs))
|
||||||
|
for _, row := range rows {
|
||||||
|
mm, ok := residenceMonths[row.ResidenceID]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
md, ok := mm[row.CompletedMonth]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
md.columns[row.CompletedFromColumn] = int(row.Count)
|
||||||
|
md.total += int(row.Count)
|
||||||
|
residenceLast12[row.ResidenceID] += int(row.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to response DTOs per residence
|
||||||
|
for _, rid := range residenceIDs {
|
||||||
|
mm := residenceMonths[rid]
|
||||||
|
months := make([]responses.MonthlyCompletionSummary, 0, 12)
|
||||||
|
for i := 0; i < 12; i++ {
|
||||||
|
m := startDate.AddDate(0, i, 0)
|
||||||
|
key := m.Format("2006-01")
|
||||||
|
md := mm[key]
|
||||||
|
|
||||||
|
completions := make([]responses.ColumnCompletionCount, 0)
|
||||||
|
for col, count := range md.columns {
|
||||||
|
completions = append(completions, responses.ColumnCompletionCount{
|
||||||
|
Column: col,
|
||||||
|
Color: KanbanColumnColor(col),
|
||||||
|
Count: count,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
overflow := 0
|
||||||
|
if md.total > maxPerMonth {
|
||||||
|
overflow = md.total - maxPerMonth
|
||||||
|
}
|
||||||
|
|
||||||
|
months = append(months, responses.MonthlyCompletionSummary{
|
||||||
|
Month: key,
|
||||||
|
Completions: completions,
|
||||||
|
Total: md.total,
|
||||||
|
Overflow: overflow,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
result[rid] = &responses.CompletionSummary{
|
||||||
|
TotalAllTime: int(allTimeMap[rid]),
|
||||||
|
TotalLast12Months: residenceLast12[rid],
|
||||||
|
Months: months,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -34,6 +34,16 @@ func NewUserRepository(db *gorm.DB) *UserRepository {
|
|||||||
return &UserRepository{db: db}
|
return &UserRepository{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transaction runs fn inside a database transaction. The callback receives a
|
||||||
|
// new UserRepository backed by the transaction so all operations within fn
|
||||||
|
// share the same transactional connection.
|
||||||
|
func (r *UserRepository) Transaction(fn func(txRepo *UserRepository) error) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
txRepo := &UserRepository{db: tx}
|
||||||
|
return fn(txRepo)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// FindByID finds a user by ID
|
// FindByID finds a user by ID
|
||||||
func (r *UserRepository) FindByID(id uint) (*models.User, error) {
|
func (r *UserRepository) FindByID(id uint) (*models.User, error) {
|
||||||
var user models.User
|
var user models.User
|
||||||
@@ -130,18 +140,28 @@ func (r *UserRepository) ExistsByEmail(email string) (bool, error) {
|
|||||||
|
|
||||||
// --- Auth Token Methods ---
|
// --- Auth Token Methods ---
|
||||||
|
|
||||||
// GetOrCreateToken gets or creates an auth token for a user
|
// GetOrCreateToken gets or creates an auth token for a user.
|
||||||
|
// Wrapped in a transaction to prevent race conditions where two
|
||||||
|
// concurrent requests could create duplicate tokens for the same user.
|
||||||
func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error) {
|
func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error) {
|
||||||
var token models.AuthToken
|
var token models.AuthToken
|
||||||
result := r.db.Where("user_id = ?", userID).First(&token)
|
|
||||||
|
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
token = models.AuthToken{UserID: userID}
|
result := tx.Where("user_id = ?", userID).First(&token)
|
||||||
if err := r.db.Create(&token).Error; err != nil {
|
|
||||||
return nil, err
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
token = models.AuthToken{UserID: userID}
|
||||||
|
if err := tx.Create(&token).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
}
|
}
|
||||||
} else if result.Error != nil {
|
|
||||||
return nil, result.Error
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &token, nil
|
return &token, nil
|
||||||
@@ -341,7 +361,7 @@ func (r *UserRepository) SearchUsers(query string, limit, offset int) ([]models.
|
|||||||
var users []models.User
|
var users []models.User
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
searchQuery := "%" + strings.ToLower(query) + "%"
|
searchQuery := "%" + escapeLikeWildcards(strings.ToLower(query)) + "%"
|
||||||
|
|
||||||
baseQuery := r.db.Model(&models.User{}).
|
baseQuery := r.db.Model(&models.User{}).
|
||||||
Where("LOWER(username) LIKE ? OR LOWER(email) LIKE ? OR LOWER(first_name) LIKE ? OR LOWER(last_name) LIKE ?",
|
Where("LOWER(username) LIKE ? OR LOWER(email) LIKE ? OR LOWER(first_name) LIKE ? OR LOWER(last_name) LIKE ?",
|
||||||
@@ -384,7 +404,7 @@ func (r *UserRepository) FindUsersInSharedResidences(userID uint) ([]models.User
|
|||||||
// 2. Members of residences owned by current user
|
// 2. Members of residences owned by current user
|
||||||
// 3. Members of residences where current user is also a member
|
// 3. Members of residences where current user is also a member
|
||||||
err := r.db.Raw(`
|
err := r.db.Raw(`
|
||||||
SELECT DISTINCT u.* FROM user_customuser u
|
SELECT DISTINCT u.* FROM auth_user u
|
||||||
WHERE u.id != ? AND u.is_active = true AND (
|
WHERE u.id != ? AND u.is_active = true AND (
|
||||||
-- Users who own residences where current user is a shared user
|
-- Users who own residences where current user is a shared user
|
||||||
u.id IN (
|
u.id IN (
|
||||||
@@ -417,7 +437,7 @@ func (r *UserRepository) FindUserIfSharedResidence(targetUserID, requestingUserI
|
|||||||
var user models.User
|
var user models.User
|
||||||
|
|
||||||
err := r.db.Raw(`
|
err := r.db.Raw(`
|
||||||
SELECT u.* FROM user_customuser u
|
SELECT u.* FROM auth_user u
|
||||||
WHERE u.id = ? AND u.is_active = true AND (
|
WHERE u.id = ? AND u.is_active = true AND (
|
||||||
u.id = ? OR
|
u.id = ? OR
|
||||||
-- Target owns a residence where requester is a member
|
-- Target owns a residence where requester is a member
|
||||||
@@ -460,7 +480,7 @@ func (r *UserRepository) FindProfilesInSharedResidences(userID uint) ([]models.U
|
|||||||
|
|
||||||
err := r.db.Raw(`
|
err := r.db.Raw(`
|
||||||
SELECT p.* FROM user_userprofile p
|
SELECT p.* FROM user_userprofile p
|
||||||
INNER JOIN user_customuser u ON p.user_id = u.id
|
INNER JOIN auth_user u ON p.user_id = u.id
|
||||||
WHERE u.is_active = true AND (
|
WHERE u.is_active = true AND (
|
||||||
u.id = ? OR
|
u.id = ? OR
|
||||||
-- Users who own residences where current user is a shared user
|
-- Users who own residences where current user is a shared user
|
||||||
|
|||||||
@@ -59,6 +59,15 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
|||||||
e.Use(custommiddleware.RequestIDMiddleware())
|
e.Use(custommiddleware.RequestIDMiddleware())
|
||||||
e.Use(utils.EchoRecovery())
|
e.Use(utils.EchoRecovery())
|
||||||
e.Use(custommiddleware.StructuredLogger())
|
e.Use(custommiddleware.StructuredLogger())
|
||||||
|
|
||||||
|
// Security headers (X-Frame-Options, X-Content-Type-Options, X-XSS-Protection, etc.)
|
||||||
|
e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
|
||||||
|
XSSProtection: "1; mode=block",
|
||||||
|
ContentTypeNosniff: "nosniff",
|
||||||
|
XFrameOptions: "SAMEORIGIN",
|
||||||
|
HSTSMaxAge: 31536000, // 1 year in seconds
|
||||||
|
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||||
|
}))
|
||||||
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
|
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
|
||||||
Limit: "1M", // 1MB default for JSON payloads
|
Limit: "1M", // 1MB default for JSON payloads
|
||||||
Skipper: func(c echo.Context) bool {
|
Skipper: func(c echo.Context) bool {
|
||||||
@@ -187,7 +196,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
|||||||
var uploadHandler *handlers.UploadHandler
|
var uploadHandler *handlers.UploadHandler
|
||||||
var mediaHandler *handlers.MediaHandler
|
var mediaHandler *handlers.MediaHandler
|
||||||
if deps.StorageService != nil {
|
if deps.StorageService != nil {
|
||||||
uploadHandler = handlers.NewUploadHandler(deps.StorageService)
|
uploadHandler = handlers.NewUploadHandler(deps.StorageService, services.NewFileOwnershipService(deps.DB))
|
||||||
mediaHandler = handlers.NewMediaHandler(documentRepo, taskRepo, residenceRepo, deps.StorageService)
|
mediaHandler = handlers.NewMediaHandler(documentRepo, taskRepo, residenceRepo, deps.StorageService)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,13 +256,22 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// corsMiddleware configures CORS with restricted origins in production.
|
// corsMiddleware configures CORS with restricted origins in production.
|
||||||
// In debug mode, all origins are allowed for development convenience.
|
// In debug mode, explicit localhost origins are allowed for development.
|
||||||
// In production, origins are read from the CORS_ALLOWED_ORIGINS environment variable
|
// In production, origins are read from the CORS_ALLOWED_ORIGINS environment variable
|
||||||
// (comma-separated), falling back to a restrictive default set.
|
// (comma-separated), falling back to a restrictive default set.
|
||||||
func corsMiddleware(cfg *config.Config) echo.MiddlewareFunc {
|
func corsMiddleware(cfg *config.Config) echo.MiddlewareFunc {
|
||||||
var origins []string
|
var origins []string
|
||||||
if cfg.Server.Debug {
|
if cfg.Server.Debug {
|
||||||
origins = []string{"*"}
|
origins = []string{
|
||||||
|
"http://localhost:3000",
|
||||||
|
"http://localhost:3001",
|
||||||
|
"http://localhost:8080",
|
||||||
|
"http://localhost:8000",
|
||||||
|
"http://127.0.0.1:3000",
|
||||||
|
"http://127.0.0.1:3001",
|
||||||
|
"http://127.0.0.1:8080",
|
||||||
|
"http://127.0.0.1:8000",
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
origins = cfg.Server.CorsAllowedOrigins
|
origins = cfg.Server.CorsAllowedOrigins
|
||||||
if len(origins) == 0 {
|
if len(origins) == 0 {
|
||||||
@@ -286,17 +304,24 @@ func healthCheck(c echo.Context) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupPublicAuthRoutes configures public authentication routes
|
// setupPublicAuthRoutes configures public authentication routes with
|
||||||
|
// per-endpoint rate limiters to mitigate brute-force and credential-stuffing.
|
||||||
func setupPublicAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler) {
|
func setupPublicAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler) {
|
||||||
auth := api.Group("/auth")
|
auth := api.Group("/auth")
|
||||||
|
|
||||||
|
// Rate limiters — created once, shared across requests.
|
||||||
|
loginRL := custommiddleware.LoginRateLimiter() // 10 req/min
|
||||||
|
registerRL := custommiddleware.RegistrationRateLimiter() // 5 req/min
|
||||||
|
passwordRL := custommiddleware.PasswordResetRateLimiter() // 3 req/min
|
||||||
|
|
||||||
{
|
{
|
||||||
auth.POST("/login/", authHandler.Login)
|
auth.POST("/login/", authHandler.Login, loginRL)
|
||||||
auth.POST("/register/", authHandler.Register)
|
auth.POST("/register/", authHandler.Register, registerRL)
|
||||||
auth.POST("/forgot-password/", authHandler.ForgotPassword)
|
auth.POST("/forgot-password/", authHandler.ForgotPassword, passwordRL)
|
||||||
auth.POST("/verify-reset-code/", authHandler.VerifyResetCode)
|
auth.POST("/verify-reset-code/", authHandler.VerifyResetCode, passwordRL)
|
||||||
auth.POST("/reset-password/", authHandler.ResetPassword)
|
auth.POST("/reset-password/", authHandler.ResetPassword, passwordRL)
|
||||||
auth.POST("/apple-sign-in/", authHandler.AppleSignIn)
|
auth.POST("/apple-sign-in/", authHandler.AppleSignIn, loginRL)
|
||||||
auth.POST("/google-sign-in/", authHandler.GoogleSignIn)
|
auth.POST("/google-sign-in/", authHandler.GoogleSignIn, loginRL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||||
@@ -90,7 +91,7 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
|
|||||||
// Update last login
|
// Update last login
|
||||||
if err := s.userRepo.UpdateLastLogin(user.ID); err != nil {
|
if err := s.userRepo.UpdateLastLogin(user.ID); err != nil {
|
||||||
// Log error but don't fail the login
|
// Log error but don't fail the login
|
||||||
fmt.Printf("Failed to update last login: %v\n", err)
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &responses.LoginResponse{
|
return &responses.LoginResponse{
|
||||||
@@ -99,7 +100,9 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register creates a new user account
|
// Register creates a new user account.
|
||||||
|
// F-10: User creation, profile creation, notification preferences, and confirmation code
|
||||||
|
// are wrapped in a transaction for atomicity.
|
||||||
func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) {
|
func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) {
|
||||||
// Check if username exists
|
// Check if username exists
|
||||||
exists, err := s.userRepo.ExistsByUsername(req.Username)
|
exists, err := s.userRepo.ExistsByUsername(req.Username)
|
||||||
@@ -133,43 +136,49 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
|
|||||||
return nil, "", apperrors.Internal(err)
|
return nil, "", apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save user
|
// Generate confirmation code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
|
||||||
if err := s.userRepo.Create(user); err != nil {
|
|
||||||
return nil, "", apperrors.Internal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create user profile
|
|
||||||
if _, err := s.userRepo.GetOrCreateProfile(user.ID); err != nil {
|
|
||||||
// Log error but don't fail registration
|
|
||||||
fmt.Printf("Failed to create user profile: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create notification preferences with all options enabled
|
|
||||||
if s.notificationRepo != nil {
|
|
||||||
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
|
||||||
// Log error but don't fail registration
|
|
||||||
fmt.Printf("Failed to create notification preferences: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create auth token
|
|
||||||
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", apperrors.Internal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate confirmation code - use fixed code in debug mode for easier local testing
|
|
||||||
var code string
|
var code string
|
||||||
if s.cfg.Server.Debug {
|
if s.cfg.Server.DebugFixedCodes {
|
||||||
code = "123456"
|
code = "123456"
|
||||||
} else {
|
} else {
|
||||||
code = generateSixDigitCode()
|
code = generateSixDigitCode()
|
||||||
}
|
}
|
||||||
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
|
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
|
||||||
|
|
||||||
if _, err := s.userRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil {
|
// Wrap user creation + profile + notification preferences + confirmation code in a transaction
|
||||||
// Log error but don't fail registration
|
txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error {
|
||||||
fmt.Printf("Failed to create confirmation code: %v\n", err)
|
// Save user
|
||||||
|
if err := txRepo.Create(user); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create user profile
|
||||||
|
if _, err := txRepo.GetOrCreateProfile(user.ID); err != nil {
|
||||||
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create user profile during registration")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create notification preferences with all options enabled
|
||||||
|
if s.notificationRepo != nil {
|
||||||
|
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
||||||
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create confirmation code
|
||||||
|
if _, err := txRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil {
|
||||||
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create confirmation code during registration")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if txErr != nil {
|
||||||
|
return nil, "", apperrors.Internal(txErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create auth token (outside transaction since token generation is idempotent)
|
||||||
|
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &responses.RegisterResponse{
|
return &responses.RegisterResponse{
|
||||||
@@ -248,8 +257,8 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
|
|||||||
return apperrors.BadRequest("error.email_already_verified")
|
return apperrors.BadRequest("error.email_already_verified")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for test code in debug mode
|
// Check for test code when DEBUG_FIXED_CODES is enabled
|
||||||
if s.cfg.Server.Debug && code == "123456" {
|
if s.cfg.Server.DebugFixedCodes && code == "123456" {
|
||||||
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
|
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
|
||||||
return apperrors.Internal(err)
|
return apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
@@ -294,9 +303,9 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
|
|||||||
return "", apperrors.BadRequest("error.email_already_verified")
|
return "", apperrors.BadRequest("error.email_already_verified")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate new code - use fixed code in debug mode for easier local testing
|
// Generate new code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
|
||||||
var code string
|
var code string
|
||||||
if s.cfg.Server.Debug {
|
if s.cfg.Server.DebugFixedCodes {
|
||||||
code = "123456"
|
code = "123456"
|
||||||
} else {
|
} else {
|
||||||
code = generateSixDigitCode()
|
code = generateSixDigitCode()
|
||||||
@@ -331,9 +340,9 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
|
|||||||
return "", nil, apperrors.TooManyRequests("error.rate_limit_exceeded")
|
return "", nil, apperrors.TooManyRequests("error.rate_limit_exceeded")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate code and reset token - use fixed code in debug mode for easier local testing
|
// Generate code and reset token - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing
|
||||||
var code string
|
var code string
|
||||||
if s.cfg.Server.Debug {
|
if s.cfg.Server.DebugFixedCodes {
|
||||||
code = "123456"
|
code = "123456"
|
||||||
} else {
|
} else {
|
||||||
code = generateSixDigitCode()
|
code = generateSixDigitCode()
|
||||||
@@ -365,8 +374,8 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
|
|||||||
return "", apperrors.Internal(err)
|
return "", apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for test code in debug mode
|
// Check for test code when DEBUG_FIXED_CODES is enabled
|
||||||
if s.cfg.Server.Debug && code == "123456" {
|
if s.cfg.Server.DebugFixedCodes && code == "123456" {
|
||||||
return resetCode.ResetToken, nil
|
return resetCode.ResetToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -422,13 +431,13 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
|
|||||||
// Mark reset code as used
|
// Mark reset code as used
|
||||||
if err := s.userRepo.MarkPasswordResetCodeUsed(resetCode.ID); err != nil {
|
if err := s.userRepo.MarkPasswordResetCodeUsed(resetCode.ID); err != nil {
|
||||||
// Log error but don't fail
|
// Log error but don't fail
|
||||||
fmt.Printf("Failed to mark reset code as used: %v\n", err)
|
log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invalidate all existing tokens for this user (security measure)
|
// Invalidate all existing tokens for this user (security measure)
|
||||||
if err := s.userRepo.DeleteTokenByUserID(user.ID); err != nil {
|
if err := s.userRepo.DeleteTokenByUserID(user.ID); err != nil {
|
||||||
// Log error but don't fail
|
// Log error but don't fail
|
||||||
fmt.Printf("Failed to delete user tokens: %v\n", err)
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -482,6 +491,13 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
|||||||
if email != "" {
|
if email != "" {
|
||||||
existingUser, err := s.userRepo.FindByEmail(email)
|
existingUser, err := s.userRepo.FindByEmail(email)
|
||||||
if err == nil && existingUser != nil {
|
if err == nil && existingUser != nil {
|
||||||
|
// S-06: Log auto-linking of social account to existing user
|
||||||
|
log.Warn().
|
||||||
|
Str("email", email).
|
||||||
|
Str("provider", "apple").
|
||||||
|
Uint("user_id", existingUser.ID).
|
||||||
|
Msg("Auto-linking social account to existing user by email match")
|
||||||
|
|
||||||
// Link Apple ID to existing account
|
// Link Apple ID to existing account
|
||||||
appleAuthRecord := &models.AppleSocialAuth{
|
appleAuthRecord := &models.AppleSocialAuth{
|
||||||
UserID: existingUser.ID,
|
UserID: existingUser.ID,
|
||||||
@@ -505,8 +521,11 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
|||||||
// Update last login
|
// Update last login
|
||||||
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
|
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
|
||||||
|
|
||||||
// Reload user with profile
|
// B-08: Check error from FindByIDWithProfile
|
||||||
existingUser, _ = s.userRepo.FindByIDWithProfile(existingUser.ID)
|
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, apperrors.Internal(err)
|
||||||
|
}
|
||||||
|
|
||||||
return &responses.AppleSignInResponse{
|
return &responses.AppleSignInResponse{
|
||||||
Token: token.Key,
|
Token: token.Key,
|
||||||
@@ -544,8 +563,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
|||||||
// Create notification preferences with all options enabled
|
// Create notification preferences with all options enabled
|
||||||
if s.notificationRepo != nil {
|
if s.notificationRepo != nil {
|
||||||
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
||||||
// Log error but don't fail registration
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user")
|
||||||
fmt.Printf("Failed to create notification preferences: %v\n", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -566,8 +584,11 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
|||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload user with profile
|
// B-08: Check error from FindByIDWithProfile
|
||||||
user, _ = s.userRepo.FindByIDWithProfile(user.ID)
|
user, err = s.userRepo.FindByIDWithProfile(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, apperrors.Internal(err)
|
||||||
|
}
|
||||||
|
|
||||||
return &responses.AppleSignInResponse{
|
return &responses.AppleSignInResponse{
|
||||||
Token: token.Key,
|
Token: token.Key,
|
||||||
@@ -623,6 +644,13 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
|||||||
if email != "" {
|
if email != "" {
|
||||||
existingUser, err := s.userRepo.FindByEmail(email)
|
existingUser, err := s.userRepo.FindByEmail(email)
|
||||||
if err == nil && existingUser != nil {
|
if err == nil && existingUser != nil {
|
||||||
|
// S-06: Log auto-linking of social account to existing user
|
||||||
|
log.Warn().
|
||||||
|
Str("email", email).
|
||||||
|
Str("provider", "google").
|
||||||
|
Uint("user_id", existingUser.ID).
|
||||||
|
Msg("Auto-linking social account to existing user by email match")
|
||||||
|
|
||||||
// Link Google ID to existing account
|
// Link Google ID to existing account
|
||||||
googleAuthRecord := &models.GoogleSocialAuth{
|
googleAuthRecord := &models.GoogleSocialAuth{
|
||||||
UserID: existingUser.ID,
|
UserID: existingUser.ID,
|
||||||
@@ -649,8 +677,11 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
|||||||
// Update last login
|
// Update last login
|
||||||
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
|
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
|
||||||
|
|
||||||
// Reload user with profile
|
// B-08: Check error from FindByIDWithProfile
|
||||||
existingUser, _ = s.userRepo.FindByIDWithProfile(existingUser.ID)
|
existingUser, err = s.userRepo.FindByIDWithProfile(existingUser.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, apperrors.Internal(err)
|
||||||
|
}
|
||||||
|
|
||||||
return &responses.GoogleSignInResponse{
|
return &responses.GoogleSignInResponse{
|
||||||
Token: token.Key,
|
Token: token.Key,
|
||||||
@@ -688,8 +719,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
|||||||
// Create notification preferences with all options enabled
|
// Create notification preferences with all options enabled
|
||||||
if s.notificationRepo != nil {
|
if s.notificationRepo != nil {
|
||||||
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
if _, err := s.notificationRepo.GetOrCreatePreferences(user.ID); err != nil {
|
||||||
// Log error but don't fail registration
|
log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user")
|
||||||
fmt.Printf("Failed to create notification preferences: %v\n", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -711,8 +741,11 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
|||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload user with profile
|
// B-08: Check error from FindByIDWithProfile
|
||||||
user, _ = s.userRepo.FindByIDWithProfile(user.ID)
|
user, err = s.userRepo.FindByIDWithProfile(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, apperrors.Internal(err)
|
||||||
|
}
|
||||||
|
|
||||||
return &responses.GoogleSignInResponse{
|
return &responses.GoogleSignInResponse{
|
||||||
Token: token.Key,
|
Token: token.Key,
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ package services
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/md5"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
@@ -18,38 +19,55 @@ type CacheService struct {
|
|||||||
client *redis.Client
|
client *redis.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
var cacheInstance *CacheService
|
var (
|
||||||
|
cacheInstance *CacheService
|
||||||
|
cacheOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
// NewCacheService creates a new cache service
|
// NewCacheService creates a new cache service (thread-safe via sync.Once)
|
||||||
func NewCacheService(cfg *config.RedisConfig) (*CacheService, error) {
|
func NewCacheService(cfg *config.RedisConfig) (*CacheService, error) {
|
||||||
opt, err := redis.ParseURL(cfg.URL)
|
var initErr error
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse Redis URL: %w", err)
|
cacheOnce.Do(func() {
|
||||||
|
opt, err := redis.ParseURL(cfg.URL)
|
||||||
|
if err != nil {
|
||||||
|
initErr = fmt.Errorf("failed to parse Redis URL: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Password != "" {
|
||||||
|
opt.Password = cfg.Password
|
||||||
|
}
|
||||||
|
if cfg.DB != 0 {
|
||||||
|
opt.DB = cfg.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
client := redis.NewClient(opt)
|
||||||
|
|
||||||
|
// Test connection
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := client.Ping(ctx).Err(); err != nil {
|
||||||
|
initErr = fmt.Errorf("failed to connect to Redis: %w", err)
|
||||||
|
// Reset Once so a retry is possible after transient failures
|
||||||
|
cacheOnce = sync.Once{}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// S-14: Mask credentials in Redis URL before logging
|
||||||
|
log.Info().
|
||||||
|
Str("url", config.MaskURLCredentials(cfg.URL)).
|
||||||
|
Int("db", opt.DB).
|
||||||
|
Msg("Connected to Redis")
|
||||||
|
|
||||||
|
cacheInstance = &CacheService{client: client}
|
||||||
|
})
|
||||||
|
|
||||||
|
if initErr != nil {
|
||||||
|
return nil, initErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Password != "" {
|
|
||||||
opt.Password = cfg.Password
|
|
||||||
}
|
|
||||||
if cfg.DB != 0 {
|
|
||||||
opt.DB = cfg.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
client := redis.NewClient(opt)
|
|
||||||
|
|
||||||
// Test connection
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := client.Ping(ctx).Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().
|
|
||||||
Str("url", cfg.URL).
|
|
||||||
Int("db", opt.DB).
|
|
||||||
Msg("Connected to Redis")
|
|
||||||
|
|
||||||
cacheInstance = &CacheService{client: client}
|
|
||||||
return cacheInstance, nil
|
return cacheInstance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -311,9 +329,10 @@ func (c *CacheService) CacheSeededData(ctx context.Context, data interface{}) (s
|
|||||||
return "", fmt.Errorf("failed to marshal seeded data: %w", err)
|
return "", fmt.Errorf("failed to marshal seeded data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate MD5 ETag from the JSON data
|
// Generate FNV-64a ETag from the JSON data (faster than MD5, non-cryptographic)
|
||||||
hash := md5.Sum(jsonData)
|
h := fnv.New64a()
|
||||||
etag := fmt.Sprintf("\"%x\"", hash)
|
h.Write(jsonData)
|
||||||
|
etag := fmt.Sprintf("\"%x\"", h.Sum64())
|
||||||
|
|
||||||
// Store both the data and the ETag
|
// Store both the data and the ETag
|
||||||
if err := c.client.Set(ctx, SeededDataKey, jsonData, SeededDataTTL).Err(); err != nil {
|
if err := c.client.Set(ctx, SeededDataKey, jsonData, SeededDataTTL).Err(); err != nil {
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
mail "github.com/wneessen/go-mail"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gopkg.in/gomail.v2"
|
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/config"
|
"github.com/treytartt/honeydue-api/internal/config"
|
||||||
)
|
)
|
||||||
@@ -16,17 +16,31 @@ import (
|
|||||||
// EmailService handles sending emails
|
// EmailService handles sending emails
|
||||||
type EmailService struct {
|
type EmailService struct {
|
||||||
cfg *config.EmailConfig
|
cfg *config.EmailConfig
|
||||||
dialer *gomail.Dialer
|
client *mail.Client
|
||||||
enabled bool
|
enabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewEmailService creates a new email service
|
// NewEmailService creates a new email service
|
||||||
func NewEmailService(cfg *config.EmailConfig, enabled bool) *EmailService {
|
func NewEmailService(cfg *config.EmailConfig, enabled bool) *EmailService {
|
||||||
dialer := gomail.NewDialer(cfg.Host, cfg.Port, cfg.User, cfg.Password)
|
client, err := mail.NewClient(cfg.Host,
|
||||||
|
mail.WithPort(cfg.Port),
|
||||||
|
mail.WithSMTPAuth(mail.SMTPAuthPlain),
|
||||||
|
mail.WithUsername(cfg.User),
|
||||||
|
mail.WithPassword(cfg.Password),
|
||||||
|
mail.WithTLSPortPolicy(mail.TLSOpportunistic),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("Failed to create mail client - emails will not be sent")
|
||||||
|
return &EmailService{
|
||||||
|
cfg: cfg,
|
||||||
|
client: nil,
|
||||||
|
enabled: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &EmailService{
|
return &EmailService{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
dialer: dialer,
|
client: client,
|
||||||
enabled: enabled,
|
enabled: enabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -37,14 +51,18 @@ func (s *EmailService) SendEmail(to, subject, htmlBody, textBody string) error {
|
|||||||
log.Debug().Msg("Email sending disabled by feature flag")
|
log.Debug().Msg("Email sending disabled by feature flag")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
m := gomail.NewMessage()
|
m := mail.NewMsg()
|
||||||
m.SetHeader("From", s.cfg.From)
|
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
|
||||||
m.SetHeader("To", to)
|
return fmt.Errorf("failed to set from address: %w", err)
|
||||||
m.SetHeader("Subject", subject)
|
}
|
||||||
m.SetBody("text/plain", textBody)
|
if err := m.AddTo(to); err != nil {
|
||||||
m.AddAlternative("text/html", htmlBody)
|
return fmt.Errorf("failed to set to address: %w", err)
|
||||||
|
}
|
||||||
|
m.Subject(subject)
|
||||||
|
m.SetBodyString(mail.TypeTextPlain, textBody)
|
||||||
|
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
|
||||||
|
|
||||||
if err := s.dialer.DialAndSend(m); err != nil {
|
if err := s.client.DialAndSend(m); err != nil {
|
||||||
log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email")
|
log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email")
|
||||||
return fmt.Errorf("failed to send email: %w", err)
|
return fmt.Errorf("failed to send email: %w", err)
|
||||||
}
|
}
|
||||||
@@ -74,26 +92,25 @@ func (s *EmailService) SendEmailWithAttachment(to, subject, htmlBody, textBody s
|
|||||||
log.Debug().Msg("Email sending disabled by feature flag")
|
log.Debug().Msg("Email sending disabled by feature flag")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
m := gomail.NewMessage()
|
m := mail.NewMsg()
|
||||||
m.SetHeader("From", s.cfg.From)
|
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
|
||||||
m.SetHeader("To", to)
|
return fmt.Errorf("failed to set from address: %w", err)
|
||||||
m.SetHeader("Subject", subject)
|
}
|
||||||
m.SetBody("text/plain", textBody)
|
if err := m.AddTo(to); err != nil {
|
||||||
m.AddAlternative("text/html", htmlBody)
|
return fmt.Errorf("failed to set to address: %w", err)
|
||||||
|
}
|
||||||
|
m.Subject(subject)
|
||||||
|
m.SetBodyString(mail.TypeTextPlain, textBody)
|
||||||
|
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
|
||||||
|
|
||||||
if attachment != nil {
|
if attachment != nil {
|
||||||
m.Attach(attachment.Filename,
|
m.AttachReader(attachment.Filename,
|
||||||
gomail.SetCopyFunc(func(w io.Writer) error {
|
bytes.NewReader(attachment.Data),
|
||||||
_, err := w.Write(attachment.Data)
|
mail.WithFileContentType(mail.ContentType(attachment.ContentType)),
|
||||||
return err
|
|
||||||
}),
|
|
||||||
gomail.SetHeader(map[string][]string{
|
|
||||||
"Content-Type": {attachment.ContentType},
|
|
||||||
}),
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.dialer.DialAndSend(m); err != nil {
|
if err := s.client.DialAndSend(m); err != nil {
|
||||||
log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email with attachment")
|
log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email with attachment")
|
||||||
return fmt.Errorf("failed to send email: %w", err)
|
return fmt.Errorf("failed to send email: %w", err)
|
||||||
}
|
}
|
||||||
@@ -108,29 +125,28 @@ func (s *EmailService) SendEmailWithEmbeddedImages(to, subject, htmlBody, textBo
|
|||||||
log.Debug().Msg("Email sending disabled by feature flag")
|
log.Debug().Msg("Email sending disabled by feature flag")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
m := gomail.NewMessage()
|
m := mail.NewMsg()
|
||||||
m.SetHeader("From", s.cfg.From)
|
if err := m.FromFormat("honeyDue", s.cfg.From); err != nil {
|
||||||
m.SetHeader("To", to)
|
return fmt.Errorf("failed to set from address: %w", err)
|
||||||
m.SetHeader("Subject", subject)
|
}
|
||||||
m.SetBody("text/plain", textBody)
|
if err := m.AddTo(to); err != nil {
|
||||||
m.AddAlternative("text/html", htmlBody)
|
return fmt.Errorf("failed to set to address: %w", err)
|
||||||
|
}
|
||||||
|
m.Subject(subject)
|
||||||
|
m.SetBodyString(mail.TypeTextPlain, textBody)
|
||||||
|
m.AddAlternativeString(mail.TypeTextHTML, htmlBody)
|
||||||
|
|
||||||
// Embed each image with Content-ID for inline display
|
// Embed each image with Content-ID for inline display
|
||||||
for _, img := range images {
|
for _, img := range images {
|
||||||
m.Embed(img.Filename,
|
img := img // capture range variable for closure
|
||||||
gomail.SetCopyFunc(func(w io.Writer) error {
|
m.EmbedReader(img.Filename,
|
||||||
_, err := w.Write(img.Data)
|
bytes.NewReader(img.Data),
|
||||||
return err
|
mail.WithFileContentType(mail.ContentType(img.ContentType)),
|
||||||
}),
|
mail.WithFileContentID(img.ContentID),
|
||||||
gomail.SetHeader(map[string][]string{
|
|
||||||
"Content-Type": {img.ContentType},
|
|
||||||
"Content-ID": {"<" + img.ContentID + ">"},
|
|
||||||
"Content-Disposition": {"inline; filename=\"" + img.Filename + "\""},
|
|
||||||
}),
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.dialer.DialAndSend(m); err != nil {
|
if err := s.client.DialAndSend(m); err != nil {
|
||||||
log.Error().Err(err).Str("to", to).Str("subject", subject).Int("images", len(images)).Msg("Failed to send email with embedded images")
|
log.Error().Err(err).Str("to", to).Str("subject", subject).Int("images", len(images)).Msg("Failed to send email with embedded images")
|
||||||
return fmt.Errorf("failed to send email: %w", err)
|
return fmt.Errorf("failed to send email: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
66
internal/services/file_ownership_service.go
Normal file
66
internal/services/file_ownership_service.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/treytartt/honeydue-api/internal/models"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileOwnershipService checks whether a user owns a file referenced by URL.
|
||||||
|
// It queries task completion images, document files, and document images
|
||||||
|
// to determine ownership through residence access.
|
||||||
|
type FileOwnershipService struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileOwnershipService creates a new FileOwnershipService
|
||||||
|
func NewFileOwnershipService(db *gorm.DB) *FileOwnershipService {
|
||||||
|
return &FileOwnershipService{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFileOwnedByUser checks if the given file URL belongs to a record
|
||||||
|
// that the user has access to (via residence membership).
|
||||||
|
func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (bool, error) {
|
||||||
|
// Check task completion images: image_url -> completion -> task -> residence -> user access
|
||||||
|
var completionImageCount int64
|
||||||
|
err := s.db.Model(&models.TaskCompletionImage{}).
|
||||||
|
Joins("JOIN task_taskcompletion ON task_taskcompletion.id = task_taskcompletionimage.completion_id").
|
||||||
|
Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
|
||||||
|
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_task.residence_id").
|
||||||
|
Where("task_taskcompletionimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||||
|
Count(&completionImageCount).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if completionImageCount > 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check document files: file_url -> document -> residence -> user access
|
||||||
|
var documentCount int64
|
||||||
|
err = s.db.Model(&models.Document{}).
|
||||||
|
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
|
||||||
|
Where("task_document.file_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||||
|
Count(&documentCount).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if documentCount > 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check document images: image_url -> document_image -> document -> residence -> user access
|
||||||
|
var documentImageCount int64
|
||||||
|
err = s.db.Model(&models.DocumentImage{}).
|
||||||
|
Joins("JOIN task_document ON task_document.id = task_documentimage.document_id").
|
||||||
|
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
|
||||||
|
Where("task_documentimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||||
|
Count(&documentImageCount).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if documentImageCount > 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
@@ -36,11 +36,12 @@ var (
|
|||||||
|
|
||||||
// AppleIAPClient handles Apple App Store Server API validation
|
// AppleIAPClient handles Apple App Store Server API validation
|
||||||
type AppleIAPClient struct {
|
type AppleIAPClient struct {
|
||||||
keyID string
|
keyID string
|
||||||
issuerID string
|
issuerID string
|
||||||
bundleID string
|
bundleID string
|
||||||
privateKey *ecdsa.PrivateKey
|
privateKey *ecdsa.PrivateKey
|
||||||
sandbox bool
|
sandbox bool
|
||||||
|
httpClient *http.Client // P-07: Reused across requests
|
||||||
}
|
}
|
||||||
|
|
||||||
// GoogleIAPClient handles Google Play Developer API validation
|
// GoogleIAPClient handles Google Play Developer API validation
|
||||||
@@ -122,6 +123,7 @@ func NewAppleIAPClient(cfg config.AppleIAPConfig) (*AppleIAPClient, error) {
|
|||||||
bundleID: cfg.BundleID,
|
bundleID: cfg.BundleID,
|
||||||
privateKey: ecdsaKey,
|
privateKey: ecdsaKey,
|
||||||
sandbox: cfg.Sandbox,
|
sandbox: cfg.Sandbox,
|
||||||
|
httpClient: &http.Client{Timeout: 30 * time.Second}, // P-07: Single client reused across requests
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,8 +170,8 @@ func (c *AppleIAPClient) ValidateTransaction(ctx context.Context, transactionID
|
|||||||
req.Header.Set("Authorization", "Bearer "+token)
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
client := &http.Client{Timeout: 30 * time.Second}
|
// P-07: Reuse the single http.Client instead of creating one per request
|
||||||
resp, err := client.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to call Apple API: %w", err)
|
return nil, fmt.Errorf("failed to call Apple API: %w", err)
|
||||||
}
|
}
|
||||||
@@ -276,8 +278,8 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
|
|||||||
req.Header.Set("Authorization", "Bearer "+token)
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
client := &http.Client{Timeout: 30 * time.Second}
|
// P-07: Reuse the single http.Client
|
||||||
resp, err := client.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to call Apple API: %w", err)
|
return nil, fmt.Errorf("failed to call Apple API: %w", err)
|
||||||
}
|
}
|
||||||
@@ -357,8 +359,8 @@ func (c *AppleIAPClient) validateLegacyReceiptWithSandbox(ctx context.Context, r
|
|||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
client := &http.Client{Timeout: 30 * time.Second}
|
// P-07: Reuse the single http.Client
|
||||||
resp, err := client.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to call Apple verifyReceipt: %w", err)
|
return nil, fmt.Errorf("failed to call Apple verifyReceipt: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||||
@@ -184,8 +186,33 @@ func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferen
|
|||||||
return NewNotificationPreferencesResponse(prefs), nil
|
return NewNotificationPreferencesResponse(prefs), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateHourField checks that an optional hour value is in the valid range 0-23.
|
||||||
|
func validateHourField(val *int, fieldName string) error {
|
||||||
|
if val != nil && (*val < 0 || *val > 23) {
|
||||||
|
return apperrors.BadRequest("error.invalid_hour").
|
||||||
|
WithMessage(fmt.Sprintf("%s must be between 0 and 23", fieldName))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// UpdatePreferences updates notification preferences
|
// UpdatePreferences updates notification preferences
|
||||||
func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
|
func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
|
||||||
|
// B-12: Validate hour fields are in range 0-23
|
||||||
|
hourFields := []struct {
|
||||||
|
value *int
|
||||||
|
name string
|
||||||
|
}{
|
||||||
|
{req.TaskDueSoonHour, "task_due_soon_hour"},
|
||||||
|
{req.TaskOverdueHour, "task_overdue_hour"},
|
||||||
|
{req.WarrantyExpiringHour, "warranty_expiring_hour"},
|
||||||
|
{req.DailyDigestHour, "daily_digest_hour"},
|
||||||
|
}
|
||||||
|
for _, hf := range hourFields {
|
||||||
|
if err := validateHourField(hf.value, hf.name); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
|
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
@@ -256,7 +283,10 @@ func (s *NotificationService) UpdateUserTimezone(userID uint, timezone string) {
|
|||||||
// Only update if timezone changed (avoid unnecessary DB writes)
|
// Only update if timezone changed (avoid unnecessary DB writes)
|
||||||
if prefs.Timezone == nil || *prefs.Timezone != timezone {
|
if prefs.Timezone == nil || *prefs.Timezone != timezone {
|
||||||
prefs.Timezone = &timezone
|
prefs.Timezone = &timezone
|
||||||
_ = s.notificationRepo.UpdatePreferences(prefs)
|
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil {
|
||||||
|
log.Error().Err(err).Uint("user_id", userID).Str("timezone", timezone).
|
||||||
|
Msg("Failed to update user timezone in notification preferences")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -430,6 +460,7 @@ func (s *NotificationService) UnregisterDevice(registrationID, platform string,
|
|||||||
|
|
||||||
// === Response/Request Types ===
|
// === Response/Request Types ===
|
||||||
|
|
||||||
|
// TODO(hardening): Move to internal/dto/responses/notification.go
|
||||||
// NotificationResponse represents a notification in API response
|
// NotificationResponse represents a notification in API response
|
||||||
type NotificationResponse struct {
|
type NotificationResponse struct {
|
||||||
ID uint `json:"id"`
|
ID uint `json:"id"`
|
||||||
@@ -473,6 +504,7 @@ func NewNotificationResponse(n *models.Notification) NotificationResponse {
|
|||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(hardening): Move to internal/dto/responses/notification.go
|
||||||
// NotificationPreferencesResponse represents notification preferences
|
// NotificationPreferencesResponse represents notification preferences
|
||||||
type NotificationPreferencesResponse struct {
|
type NotificationPreferencesResponse struct {
|
||||||
TaskDueSoon bool `json:"task_due_soon"`
|
TaskDueSoon bool `json:"task_due_soon"`
|
||||||
@@ -511,6 +543,7 @@ func NewNotificationPreferencesResponse(p *models.NotificationPreference) *Notif
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(hardening): Move to internal/dto/requests/notification.go
|
||||||
// UpdatePreferencesRequest represents preferences update request
|
// UpdatePreferencesRequest represents preferences update request
|
||||||
type UpdatePreferencesRequest struct {
|
type UpdatePreferencesRequest struct {
|
||||||
TaskDueSoon *bool `json:"task_due_soon"`
|
TaskDueSoon *bool `json:"task_due_soon"`
|
||||||
@@ -532,6 +565,7 @@ type UpdatePreferencesRequest struct {
|
|||||||
DailyDigestHour *int `json:"daily_digest_hour"`
|
DailyDigestHour *int `json:"daily_digest_hour"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(hardening): Move to internal/dto/responses/notification.go
|
||||||
// DeviceResponse represents a device in API response
|
// DeviceResponse represents a device in API response
|
||||||
type DeviceResponse struct {
|
type DeviceResponse struct {
|
||||||
ID uint `json:"id"`
|
ID uint `json:"id"`
|
||||||
@@ -569,6 +603,7 @@ func NewGCMDeviceResponse(d *models.GCMDevice) *DeviceResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(hardening): Move to internal/dto/requests/notification.go
|
||||||
// RegisterDeviceRequest represents device registration request
|
// RegisterDeviceRequest represents device registration request
|
||||||
type RegisterDeviceRequest struct {
|
type RegisterDeviceRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ func generateTrackingID() string {
|
|||||||
|
|
||||||
// HasSentEmail checks if a specific email type has already been sent to a user
|
// HasSentEmail checks if a specific email type has already been sent to a user
|
||||||
func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.OnboardingEmailType) bool {
|
func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.OnboardingEmailType) bool {
|
||||||
|
// TODO(hardening): Replace with OnboardingEmailRepository.HasSentEmail()
|
||||||
var count int64
|
var count int64
|
||||||
if err := s.db.Model(&models.OnboardingEmail{}).
|
if err := s.db.Model(&models.OnboardingEmail{}).
|
||||||
Where("user_id = ? AND email_type = ?", userID, emailType).
|
Where("user_id = ? AND email_type = ?", userID, emailType).
|
||||||
@@ -51,6 +52,7 @@ func (s *OnboardingEmailService) HasSentEmail(userID uint, emailType models.Onbo
|
|||||||
|
|
||||||
// RecordEmailSent records that an email was sent to a user
|
// RecordEmailSent records that an email was sent to a user
|
||||||
func (s *OnboardingEmailService) RecordEmailSent(userID uint, emailType models.OnboardingEmailType, trackingID string) error {
|
func (s *OnboardingEmailService) RecordEmailSent(userID uint, emailType models.OnboardingEmailType, trackingID string) error {
|
||||||
|
// TODO(hardening): Replace with OnboardingEmailRepository.Create()
|
||||||
email := &models.OnboardingEmail{
|
email := &models.OnboardingEmail{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
EmailType: emailType,
|
EmailType: emailType,
|
||||||
@@ -66,6 +68,7 @@ func (s *OnboardingEmailService) RecordEmailSent(userID uint, emailType models.O
|
|||||||
|
|
||||||
// RecordEmailOpened records that an email was opened based on tracking ID
|
// RecordEmailOpened records that an email was opened based on tracking ID
|
||||||
func (s *OnboardingEmailService) RecordEmailOpened(trackingID string) error {
|
func (s *OnboardingEmailService) RecordEmailOpened(trackingID string) error {
|
||||||
|
// TODO(hardening): Replace with OnboardingEmailRepository.MarkOpened()
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
result := s.db.Model(&models.OnboardingEmail{}).
|
result := s.db.Model(&models.OnboardingEmail{}).
|
||||||
Where("tracking_id = ? AND opened_at IS NULL", trackingID).
|
Where("tracking_id = ? AND opened_at IS NULL", trackingID).
|
||||||
@@ -84,6 +87,7 @@ func (s *OnboardingEmailService) RecordEmailOpened(trackingID string) error {
|
|||||||
|
|
||||||
// GetEmailHistory gets all onboarding emails for a specific user
|
// GetEmailHistory gets all onboarding emails for a specific user
|
||||||
func (s *OnboardingEmailService) GetEmailHistory(userID uint) ([]models.OnboardingEmail, error) {
|
func (s *OnboardingEmailService) GetEmailHistory(userID uint) ([]models.OnboardingEmail, error) {
|
||||||
|
// TODO(hardening): Replace with OnboardingEmailRepository.FindByUserID()
|
||||||
var emails []models.OnboardingEmail
|
var emails []models.OnboardingEmail
|
||||||
if err := s.db.Where("user_id = ?", userID).Order("sent_at DESC").Find(&emails).Error; err != nil {
|
if err := s.db.Where("user_id = ?", userID).Order("sent_at DESC").Find(&emails).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -105,11 +109,13 @@ func (s *OnboardingEmailService) GetAllEmailHistory(page, pageSize int) ([]model
|
|||||||
|
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
|
|
||||||
|
// TODO(hardening): Replace with OnboardingEmailRepository.CountAll()
|
||||||
// Count total
|
// Count total
|
||||||
if err := s.db.Model(&models.OnboardingEmail{}).Count(&total).Error; err != nil {
|
if err := s.db.Model(&models.OnboardingEmail{}).Count(&total).Error; err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(hardening): Replace with OnboardingEmailRepository.FindAllPaginated()
|
||||||
// Get paginated results with user info
|
// Get paginated results with user info
|
||||||
if err := s.db.Preload("User").
|
if err := s.db.Preload("User").
|
||||||
Order("sent_at DESC").
|
Order("sent_at DESC").
|
||||||
@@ -126,6 +132,7 @@ func (s *OnboardingEmailService) GetAllEmailHistory(page, pageSize int) ([]model
|
|||||||
func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error) {
|
func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error) {
|
||||||
stats := &OnboardingEmailStats{}
|
stats := &OnboardingEmailStats{}
|
||||||
|
|
||||||
|
// TODO(hardening): Replace with OnboardingEmailRepository.GetStats()
|
||||||
// No residence email stats
|
// No residence email stats
|
||||||
var noResTotal, noResOpened int64
|
var noResTotal, noResOpened int64
|
||||||
if err := s.db.Model(&models.OnboardingEmail{}).
|
if err := s.db.Model(&models.OnboardingEmail{}).
|
||||||
@@ -159,6 +166,7 @@ func (s *OnboardingEmailService) GetEmailStats() (*OnboardingEmailStats, error)
|
|||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(hardening): Move to internal/dto/responses/onboarding_email.go
|
||||||
// OnboardingEmailStats represents statistics about onboarding emails
|
// OnboardingEmailStats represents statistics about onboarding emails
|
||||||
type OnboardingEmailStats struct {
|
type OnboardingEmailStats struct {
|
||||||
NoResidenceTotal int64 `json:"no_residence_total"`
|
NoResidenceTotal int64 `json:"no_residence_total"`
|
||||||
@@ -173,6 +181,7 @@ func (s *OnboardingEmailService) UsersNeedingNoResidenceEmail() ([]models.User,
|
|||||||
|
|
||||||
twoDaysAgo := time.Now().UTC().AddDate(0, 0, -2)
|
twoDaysAgo := time.Now().UTC().AddDate(0, 0, -2)
|
||||||
|
|
||||||
|
// TODO(hardening): Replace with OnboardingEmailRepository.FindUsersWithoutResidence()
|
||||||
// Find users who:
|
// Find users who:
|
||||||
// 1. Are verified
|
// 1. Are verified
|
||||||
// 2. Registered 2+ days ago
|
// 2. Registered 2+ days ago
|
||||||
@@ -201,6 +210,7 @@ func (s *OnboardingEmailService) UsersNeedingNoTasksEmail() ([]models.User, erro
|
|||||||
|
|
||||||
fiveDaysAgo := time.Now().UTC().AddDate(0, 0, -5)
|
fiveDaysAgo := time.Now().UTC().AddDate(0, 0, -5)
|
||||||
|
|
||||||
|
// TODO(hardening): Replace with OnboardingEmailRepository.FindUsersWithoutTasks()
|
||||||
// Find users who:
|
// Find users who:
|
||||||
// 1. Are verified
|
// 1. Are verified
|
||||||
// 2. Have at least one residence
|
// 2. Have at least one residence
|
||||||
@@ -325,6 +335,7 @@ func (s *OnboardingEmailService) sendNoTasksEmail(user models.User) error {
|
|||||||
// SendOnboardingEmailToUser manually sends an onboarding email to a specific user
|
// SendOnboardingEmailToUser manually sends an onboarding email to a specific user
|
||||||
// This is used by admin to force-send emails regardless of eligibility criteria
|
// This is used by admin to force-send emails regardless of eligibility criteria
|
||||||
func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailType models.OnboardingEmailType) error {
|
func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailType models.OnboardingEmailType) error {
|
||||||
|
// TODO(hardening): Replace with UserRepository.FindByID() (inject UserRepository)
|
||||||
// Load the user
|
// Load the user
|
||||||
var user models.User
|
var user models.User
|
||||||
if err := s.db.First(&user, userID).Error; err != nil {
|
if err := s.db.First(&user, userID).Error; err != nil {
|
||||||
@@ -362,6 +373,7 @@ func (s *OnboardingEmailService) SendOnboardingEmailToUser(userID uint, emailTyp
|
|||||||
// If already sent before, delete the old record first to allow re-recording
|
// If already sent before, delete the old record first to allow re-recording
|
||||||
// This allows admins to "resend" emails while still tracking them
|
// This allows admins to "resend" emails while still tracking them
|
||||||
if alreadySent {
|
if alreadySent {
|
||||||
|
// TODO(hardening): Replace with OnboardingEmailRepository.DeleteByUserAndType()
|
||||||
if err := s.db.Where("user_id = ? AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{}).Error; err != nil {
|
if err := s.db.Where("user_id = ? AND email_type = ?", userID, emailType).Delete(&models.OnboardingEmail{}).Error; err != nil {
|
||||||
log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to delete old onboarding email record before resend")
|
log.Error().Err(err).Uint("user_id", userID).Str("email_type", string(emailType)).Msg("Failed to delete old onboarding email record before resend")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jung-kurt/gofpdf"
|
"github.com/go-pdf/fpdf"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PDFService handles PDF generation
|
// PDFService handles PDF generation
|
||||||
@@ -18,7 +18,7 @@ func NewPDFService() *PDFService {
|
|||||||
|
|
||||||
// GenerateTasksReportPDF generates a PDF report from task report data
|
// GenerateTasksReportPDF generates a PDF report from task report data
|
||||||
func (s *PDFService) GenerateTasksReportPDF(report *TasksReportResponse) ([]byte, error) {
|
func (s *PDFService) GenerateTasksReportPDF(report *TasksReportResponse) ([]byte, error) {
|
||||||
pdf := gofpdf.New("P", "mm", "A4", "")
|
pdf := fpdf.New("P", "mm", "A4", "")
|
||||||
pdf.SetMargins(15, 15, 15)
|
pdf.SetMargins(15, 15, 15)
|
||||||
pdf.AddPage()
|
pdf.AddPage()
|
||||||
|
|
||||||
|
|||||||
@@ -133,14 +133,16 @@ func (s *ResidenceService) GetMyResidences(userID uint, now time.Time) (*respons
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attach completion summaries (honeycomb grid data)
|
// P-01: Batch fetch completion summaries in 2 queries total instead of 2*N
|
||||||
for i := range residenceResponses {
|
summaries, err := s.taskRepo.GetBatchCompletionSummaries(residenceIDs, now, 10)
|
||||||
summary, err := s.taskRepo.GetCompletionSummary(residenceResponses[i].ID, now, 10)
|
if err != nil {
|
||||||
if err != nil {
|
log.Warn().Err(err).Msg("Failed to fetch batch completion summaries")
|
||||||
log.Warn().Err(err).Uint("residence_id", residenceResponses[i].ID).Msg("Failed to fetch completion summary")
|
} else {
|
||||||
continue
|
for i := range residenceResponses {
|
||||||
|
if summary, ok := summaries[residenceResponses[i].ID]; ok {
|
||||||
|
residenceResponses[i].CompletionSummary = summary
|
||||||
|
}
|
||||||
}
|
}
|
||||||
residenceResponses[i].CompletionSummary = summary
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -17,7 +18,8 @@ import (
|
|||||||
|
|
||||||
// StorageService handles file uploads to local filesystem
|
// StorageService handles file uploads to local filesystem
|
||||||
type StorageService struct {
|
type StorageService struct {
|
||||||
cfg *config.StorageConfig
|
cfg *config.StorageConfig
|
||||||
|
allowedTypes map[string]struct{} // P-12: Parsed once at init for O(1) lookups
|
||||||
}
|
}
|
||||||
|
|
||||||
// UploadResult contains information about an uploaded file
|
// UploadResult contains information about an uploaded file
|
||||||
@@ -44,9 +46,18 @@ func NewStorageService(cfg *config.StorageConfig) (*StorageService, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Str("upload_dir", cfg.UploadDir).Msg("Storage service initialized")
|
// P-12: Parse AllowedTypes once at initialization for O(1) lookups
|
||||||
|
allowedTypes := make(map[string]struct{})
|
||||||
|
for _, t := range strings.Split(cfg.AllowedTypes, ",") {
|
||||||
|
trimmed := strings.TrimSpace(t)
|
||||||
|
if trimmed != "" {
|
||||||
|
allowedTypes[trimmed] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &StorageService{cfg: cfg}, nil
|
log.Info().Str("upload_dir", cfg.UploadDir).Int("allowed_types", len(allowedTypes)).Msg("Storage service initialized")
|
||||||
|
|
||||||
|
return &StorageService{cfg: cfg, allowedTypes: allowedTypes}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upload saves a file to the local filesystem
|
// Upload saves a file to the local filesystem
|
||||||
@@ -56,17 +67,47 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
|
|||||||
return nil, fmt.Errorf("file size %d exceeds maximum allowed %d bytes", file.Size, s.cfg.MaxFileSize)
|
return nil, fmt.Errorf("file size %d exceeds maximum allowed %d bytes", file.Size, s.cfg.MaxFileSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get MIME type
|
// Get claimed MIME type from header
|
||||||
mimeType := file.Header.Get("Content-Type")
|
claimedMimeType := file.Header.Get("Content-Type")
|
||||||
if mimeType == "" {
|
if claimedMimeType == "" {
|
||||||
mimeType = "application/octet-stream"
|
claimedMimeType = "application/octet-stream"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate MIME type
|
// S-09: Detect actual content type from file bytes to prevent disguised uploads
|
||||||
|
src, err := file.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
|
||||||
|
}
|
||||||
|
defer src.Close()
|
||||||
|
|
||||||
|
// Read the first 512 bytes for content type detection
|
||||||
|
sniffBuf := make([]byte, 512)
|
||||||
|
n, err := src.Read(sniffBuf)
|
||||||
|
if err != nil && n == 0 {
|
||||||
|
return nil, fmt.Errorf("failed to read file for content type detection: %w", err)
|
||||||
|
}
|
||||||
|
detectedMimeType := http.DetectContentType(sniffBuf[:n])
|
||||||
|
|
||||||
|
// Validate that the detected type matches the claimed type (at the category level)
|
||||||
|
// Allow application/octet-stream from detection since DetectContentType may not
|
||||||
|
// recognize all valid types, but the claimed type must still be in our allowed list
|
||||||
|
if detectedMimeType != "application/octet-stream" && !s.mimeTypesCompatible(claimedMimeType, detectedMimeType) {
|
||||||
|
return nil, fmt.Errorf("file content type mismatch: claimed %s but detected %s", claimedMimeType, detectedMimeType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the claimed MIME type (which is more specific) if it's allowed
|
||||||
|
mimeType := claimedMimeType
|
||||||
|
|
||||||
|
// Validate MIME type against allowed list
|
||||||
if !s.isAllowedType(mimeType) {
|
if !s.isAllowedType(mimeType) {
|
||||||
return nil, fmt.Errorf("file type %s is not allowed", mimeType)
|
return nil, fmt.Errorf("file type %s is not allowed", mimeType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Seek back to beginning after sniffing
|
||||||
|
if _, err := src.Seek(0, io.SeekStart); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to seek file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Generate unique filename
|
// Generate unique filename
|
||||||
ext := filepath.Ext(file.Filename)
|
ext := filepath.Ext(file.Filename)
|
||||||
if ext == "" {
|
if ext == "" {
|
||||||
@@ -83,15 +124,11 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
|
|||||||
subdir = "completions"
|
subdir = "completions"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Full path
|
// S-18: Sanitize path to prevent traversal attacks
|
||||||
destPath := filepath.Join(s.cfg.UploadDir, subdir, newFilename)
|
destPath, err := SafeResolvePath(s.cfg.UploadDir, filepath.Join(subdir, newFilename))
|
||||||
|
|
||||||
// Open source file
|
|
||||||
src, err := file.Open()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open uploaded file: %w", err)
|
return nil, fmt.Errorf("invalid upload path: %w", err)
|
||||||
}
|
}
|
||||||
defer src.Close()
|
|
||||||
|
|
||||||
// Create destination file
|
// Create destination file
|
||||||
dst, err := os.Create(destPath)
|
dst, err := os.Create(destPath)
|
||||||
@@ -131,19 +168,11 @@ func (s *StorageService) Delete(fileURL string) error {
|
|||||||
// Convert URL to file path
|
// Convert URL to file path
|
||||||
relativePath := strings.TrimPrefix(fileURL, s.cfg.BaseURL)
|
relativePath := strings.TrimPrefix(fileURL, s.cfg.BaseURL)
|
||||||
relativePath = strings.TrimPrefix(relativePath, "/")
|
relativePath = strings.TrimPrefix(relativePath, "/")
|
||||||
fullPath := filepath.Join(s.cfg.UploadDir, relativePath)
|
|
||||||
|
|
||||||
// Security check: ensure path is within upload directory
|
// S-18: Use SafeResolvePath to prevent path traversal
|
||||||
absUploadDir, err := filepath.Abs(s.cfg.UploadDir)
|
fullPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to resolve upload directory: %w", err)
|
return fmt.Errorf("invalid file path: %w", err)
|
||||||
}
|
|
||||||
absFilePath, err := filepath.Abs(fullPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to resolve file path: %w", err)
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(absFilePath, absUploadDir+string(filepath.Separator)) && absFilePath != absUploadDir {
|
|
||||||
return fmt.Errorf("invalid file path")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.Remove(fullPath); err != nil {
|
if err := os.Remove(fullPath); err != nil {
|
||||||
@@ -157,15 +186,23 @@ func (s *StorageService) Delete(fileURL string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isAllowedType checks if the MIME type is in the allowed list
|
// isAllowedType checks if the MIME type is in the allowed list.
|
||||||
|
// P-12: Uses the pre-parsed allowedTypes map for O(1) lookups instead of
|
||||||
|
// splitting the config string on every call.
|
||||||
func (s *StorageService) isAllowedType(mimeType string) bool {
|
func (s *StorageService) isAllowedType(mimeType string) bool {
|
||||||
allowed := strings.Split(s.cfg.AllowedTypes, ",")
|
_, ok := s.allowedTypes[mimeType]
|
||||||
for _, t := range allowed {
|
return ok
|
||||||
if strings.TrimSpace(t) == mimeType {
|
}
|
||||||
return true
|
|
||||||
}
|
// mimeTypesCompatible checks if the claimed and detected MIME types are compatible.
|
||||||
|
// Two MIME types are compatible if they share the same primary type (e.g., both "image/*").
|
||||||
|
func (s *StorageService) mimeTypesCompatible(claimed, detected string) bool {
|
||||||
|
claimedParts := strings.SplitN(claimed, "/", 2)
|
||||||
|
detectedParts := strings.SplitN(detected, "/", 2)
|
||||||
|
if len(claimedParts) < 1 || len(detectedParts) < 1 {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
return false
|
return claimedParts[0] == detectedParts[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
// getExtensionFromMimeType returns a file extension for common MIME types
|
// getExtensionFromMimeType returns a file extension for common MIME types
|
||||||
@@ -191,5 +228,12 @@ func (s *StorageService) GetUploadDir() string {
|
|||||||
// NewStorageServiceForTest creates a StorageService without creating directories.
|
// NewStorageServiceForTest creates a StorageService without creating directories.
|
||||||
// This is intended only for unit tests that need a StorageService with a known config.
|
// This is intended only for unit tests that need a StorageService with a known config.
|
||||||
func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService {
|
func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService {
|
||||||
return &StorageService{cfg: cfg}
|
allowedTypes := make(map[string]struct{})
|
||||||
|
for _, t := range strings.Split(cfg.AllowedTypes, ",") {
|
||||||
|
trimmed := strings.TrimSpace(t)
|
||||||
|
if trimmed != "" {
|
||||||
|
allowedTypes[trimmed] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &StorageService{cfg: cfg, allowedTypes: allowedTypes}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ package services
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/spf13/viper"
|
||||||
"github.com/stripe/stripe-go/v81"
|
"github.com/stripe/stripe-go/v81"
|
||||||
portalsession "github.com/stripe/stripe-go/v81/billingportal/session"
|
portalsession "github.com/stripe/stripe-go/v81/billingportal/session"
|
||||||
checkoutsession "github.com/stripe/stripe-go/v81/checkout/session"
|
checkoutsession "github.com/stripe/stripe-go/v81/checkout/session"
|
||||||
@@ -34,7 +34,8 @@ func NewStripeService(
|
|||||||
subscriptionRepo *repositories.SubscriptionRepository,
|
subscriptionRepo *repositories.SubscriptionRepository,
|
||||||
userRepo *repositories.UserRepository,
|
userRepo *repositories.UserRepository,
|
||||||
) *StripeService {
|
) *StripeService {
|
||||||
key := os.Getenv("STRIPE_SECRET_KEY")
|
// S-21: Use Viper config instead of os.Getenv for consistent configuration management
|
||||||
|
key := viper.GetString("STRIPE_SECRET_KEY")
|
||||||
if key == "" {
|
if key == "" {
|
||||||
log.Warn().Msg("STRIPE_SECRET_KEY not set, Stripe integration will not work")
|
log.Warn().Msg("STRIPE_SECRET_KEY not set, Stripe integration will not work")
|
||||||
} else {
|
} else {
|
||||||
@@ -42,7 +43,7 @@ func NewStripeService(
|
|||||||
log.Info().Msg("Stripe API key configured")
|
log.Info().Msg("Stripe API key configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET")
|
webhookSecret := viper.GetString("STRIPE_WEBHOOK_SECRET")
|
||||||
if webhookSecret == "" {
|
if webhookSecret == "" {
|
||||||
log.Warn().Msg("STRIPE_WEBHOOK_SECRET not set, webhook verification will fail")
|
log.Warn().Msg("STRIPE_WEBHOOK_SECRET not set, webhook verification will fail")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -202,18 +202,19 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getUserUsage calculates current usage for a user.
|
// getUserUsage calculates current usage for a user.
|
||||||
|
// P-10: Uses CountByOwner for properties count instead of loading all owned residences.
|
||||||
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)).
|
// Uses batch COUNT queries (O(1) queries) instead of per-residence queries (O(N)).
|
||||||
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) {
|
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) {
|
||||||
residences, err := s.residenceRepo.FindOwnedByUser(userID)
|
// P-10: Use CountByOwner for an efficient COUNT query instead of loading all records
|
||||||
|
propertiesCount, err := s.residenceRepo.CountByOwner(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
propertiesCount := int64(len(residences))
|
|
||||||
|
|
||||||
// Collect residence IDs for batch queries
|
// Still need residence IDs for batch counting tasks/contractors/documents
|
||||||
residenceIDs := make([]uint, len(residences))
|
residenceIDs, err := s.residenceRepo.FindResidenceIDsByOwner(userID)
|
||||||
for i, r := range residences {
|
if err != nil {
|
||||||
residenceIDs[i] = r.ID
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count tasks, contractors, and documents across all residences with single queries each
|
// Count tasks, contractors, and documents across all residences with single queries each
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ func (s *TaskService) ListTasks(userID uint, daysThreshold int, now time.Time) (
|
|||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := responses.NewKanbanBoardResponseForAll(board)
|
resp := responses.NewKanbanBoardResponseForAll(board, now)
|
||||||
// NOTE: Summary statistics are calculated client-side from kanban data
|
// NOTE: Summary statistics are calculated client-side from kanban data
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
@@ -157,7 +157,7 @@ func (s *TaskService) GetTasksByResidence(residenceID, userID uint, daysThreshol
|
|||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := responses.NewKanbanBoardResponse(board, residenceID)
|
resp := responses.NewKanbanBoardResponse(board, residenceID, now)
|
||||||
// NOTE: Summary statistics are calculated client-side from kanban data
|
// NOTE: Summary statistics are calculated client-side from kanban data
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
@@ -601,8 +601,8 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
|
|||||||
task.InProgress = false
|
task.InProgress = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// P1-5: Wrap completion creation and task update in a transaction.
|
// P1-5 + B-07: Wrap completion creation, task update, and image creation
|
||||||
// If either operation fails, both are rolled back to prevent orphaned completions.
|
// in a single transaction for atomicity. If any operation fails, all are rolled back.
|
||||||
txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error {
|
txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error {
|
||||||
if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil {
|
if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -610,6 +610,18 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
|
|||||||
if err := s.taskRepo.UpdateTx(tx, task); err != nil {
|
if err := s.taskRepo.UpdateTx(tx, task); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// B-07: Create images inside the same transaction as completion
|
||||||
|
for _, imageURL := range req.ImageURLs {
|
||||||
|
if imageURL != "" {
|
||||||
|
img := &models.TaskCompletionImage{
|
||||||
|
CompletionID: completion.ID,
|
||||||
|
ImageURL: imageURL,
|
||||||
|
}
|
||||||
|
if err := tx.Create(img).Error; err != nil {
|
||||||
|
return fmt.Errorf("failed to create completion image: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if txErr != nil {
|
if txErr != nil {
|
||||||
@@ -621,19 +633,6 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
|
|||||||
return nil, apperrors.Internal(txErr)
|
return nil, apperrors.Internal(txErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create images if provided
|
|
||||||
for _, imageURL := range req.ImageURLs {
|
|
||||||
if imageURL != "" {
|
|
||||||
img := &models.TaskCompletionImage{
|
|
||||||
CompletionID: completion.ID,
|
|
||||||
ImageURL: imageURL,
|
|
||||||
}
|
|
||||||
if err := s.taskRepo.CreateCompletionImage(img); err != nil {
|
|
||||||
log.Error().Err(err).Uint("completion_id", completion.ID).Msg("Failed to create completion image")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reload completion with user info and images
|
// Reload completion with user info and images
|
||||||
completion, err = s.taskRepo.FindCompletionByID(completion.ID)
|
completion, err = s.taskRepo.FindCompletionByID(completion.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -663,8 +662,10 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuickComplete creates a minimal task completion (for widget use)
|
// QuickComplete creates a minimal task completion (for widget use).
|
||||||
// Returns only success/error, no response body
|
// LE-01: The entire operation (completion creation + task update) is wrapped in a
|
||||||
|
// transaction for atomicity.
|
||||||
|
// Returns only success/error, no response body.
|
||||||
func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
|
func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
|
||||||
// Get the task
|
// Get the task
|
||||||
task, err := s.taskRepo.FindByID(taskID)
|
task, err := s.taskRepo.FindByID(taskID)
|
||||||
@@ -697,10 +698,6 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
|
|||||||
CompletedFromColumn: completedFromColumn,
|
CompletedFromColumn: completedFromColumn,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.taskRepo.CreateCompletion(completion); err != nil {
|
|
||||||
return apperrors.Internal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update next_due_date and in_progress based on frequency
|
// Update next_due_date and in_progress based on frequency
|
||||||
// Determine interval days: Custom frequency uses task.CustomIntervalDays, otherwise use frequency.Days
|
// Determine interval days: Custom frequency uses task.CustomIntervalDays, otherwise use frequency.Days
|
||||||
// Note: Frequency is no longer preloaded for performance, so we load it separately if needed
|
// Note: Frequency is no longer preloaded for performance, so we load it separately if needed
|
||||||
@@ -729,7 +726,6 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
|
|||||||
} else {
|
} else {
|
||||||
// Recurring task - calculate next due date from completion date + interval
|
// Recurring task - calculate next due date from completion date + interval
|
||||||
nextDue := completedAt.AddDate(0, 0, *quickIntervalDays)
|
nextDue := completedAt.AddDate(0, 0, *quickIntervalDays)
|
||||||
// frequencyName was already set when loading frequency above
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Uint("task_id", task.ID).
|
Uint("task_id", task.ID).
|
||||||
Str("frequency_name", frequencyName).
|
Str("frequency_name", frequencyName).
|
||||||
@@ -742,12 +738,23 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
|
|||||||
// Reset in_progress to false
|
// Reset in_progress to false
|
||||||
task.InProgress = false
|
task.InProgress = false
|
||||||
}
|
}
|
||||||
if err := s.taskRepo.Update(task); err != nil {
|
|
||||||
if errors.Is(err, repositories.ErrVersionConflict) {
|
// LE-01: Wrap completion creation and task update in a transaction for atomicity
|
||||||
|
txErr := s.taskRepo.DB().Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := s.taskRepo.CreateCompletionTx(tx, completion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := s.taskRepo.UpdateTx(tx, task); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if txErr != nil {
|
||||||
|
if errors.Is(txErr, repositories.ErrVersionConflict) {
|
||||||
return apperrors.Conflict("error.version_conflict")
|
return apperrors.Conflict("error.version_conflict")
|
||||||
}
|
}
|
||||||
log.Error().Err(err).Uint("task_id", task.ID).Msg("Failed to update task after quick completion")
|
log.Error().Err(txErr).Uint("task_id", task.ID).Msg("Failed to create completion and update task in QuickComplete")
|
||||||
return apperrors.Internal(err) // Return error so caller knows the update failed
|
return apperrors.Internal(txErr)
|
||||||
}
|
}
|
||||||
log.Info().Uint("task_id", task.ID).Msg("QuickComplete: Task updated successfully")
|
log.Info().Uint("task_id", task.ID).Msg("QuickComplete: Task updated successfully")
|
||||||
|
|
||||||
@@ -813,8 +820,16 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio
|
|||||||
// Send email notification (to everyone INCLUDING the person who completed it)
|
// Send email notification (to everyone INCLUDING the person who completed it)
|
||||||
// Check user's email notification preferences first
|
// Check user's email notification preferences first
|
||||||
if s.emailService != nil && user.Email != "" && s.notificationService != nil {
|
if s.emailService != nil && user.Email != "" && s.notificationService != nil {
|
||||||
prefs, err := s.notificationService.GetPreferences(user.ID)
|
prefs, prefsErr := s.notificationService.GetPreferences(user.ID)
|
||||||
if err != nil || (prefs != nil && prefs.EmailTaskCompleted) {
|
// LE-06: Log fail-open behavior when preferences cannot be loaded
|
||||||
|
if prefsErr != nil {
|
||||||
|
log.Warn().
|
||||||
|
Err(prefsErr).
|
||||||
|
Uint("user_id", user.ID).
|
||||||
|
Uint("task_id", task.ID).
|
||||||
|
Msg("Failed to load notification preferences, falling back to sending email (fail-open)")
|
||||||
|
}
|
||||||
|
if prefsErr != nil || (prefs != nil && prefs.EmailTaskCompleted) {
|
||||||
// Send email if we couldn't get prefs (fail-open) or if email notifications are enabled
|
// Send email if we couldn't get prefs (fail-open) or if email notifications are enabled
|
||||||
if err := s.emailService.SendTaskCompletedEmail(
|
if err := s.emailService.SendTaskCompletedEmail(
|
||||||
user.Email,
|
user.Email,
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ func (s *UserService) ListUsersInSharedResidences(userID uint) ([]responses.User
|
|||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result []responses.UserSummary
|
// F-23: Initialize as empty slice so JSON serialization produces [] instead of null
|
||||||
|
result := make([]responses.UserSummary, 0, len(users))
|
||||||
for _, u := range users {
|
for _, u := range users {
|
||||||
result = append(result, responses.UserSummary{
|
result = append(result, responses.UserSummary{
|
||||||
ID: u.ID,
|
ID: u.ID,
|
||||||
@@ -72,7 +73,8 @@ func (s *UserService) ListProfilesInSharedResidences(userID uint) ([]responses.U
|
|||||||
return nil, apperrors.Internal(err)
|
return nil, apperrors.Internal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result []responses.UserProfileSummary
|
// F-23: Initialize as empty slice so JSON serialization produces [] instead of null
|
||||||
|
result := make([]responses.UserProfileSummary, 0, len(profiles))
|
||||||
for _, p := range profiles {
|
for _, p := range profiles {
|
||||||
result = append(result, responses.UserProfileSummary{
|
result = append(result, responses.UserProfileSummary{
|
||||||
ID: p.ID,
|
ID: p.ID,
|
||||||
|
|||||||
@@ -27,7 +27,10 @@ var testDB *gorm.DB
|
|||||||
// testUserID is a user ID that exists in the database for foreign key constraints
|
// testUserID is a user ID that exists in the database for foreign key constraints
|
||||||
var testUserID uint = 1
|
var testUserID uint = 1
|
||||||
|
|
||||||
// TestMain sets up the database connection for all tests in this package
|
// TestMain sets up the database connection for all tests in this package.
|
||||||
|
// If the database is not available, testDB remains nil and individual tests
|
||||||
|
// will call t.Skip() instead of using os.Exit(0), which preserves proper
|
||||||
|
// test reporting and coverage output.
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
dsn := os.Getenv("TEST_DATABASE_URL")
|
dsn := os.Getenv("TEST_DATABASE_URL")
|
||||||
if dsn == "" {
|
if dsn == "" {
|
||||||
@@ -39,15 +42,23 @@ func TestMain(m *testing.M) {
|
|||||||
Logger: logger.Default.LogMode(logger.Silent),
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
println("Skipping consistency integration tests: database not available")
|
// Explicitly nil out testDB; individual tests will t.Skip("Database not available")
|
||||||
|
testDB = nil
|
||||||
|
println("Consistency integration tests will be skipped: database not available")
|
||||||
println("Set TEST_DATABASE_URL to run these tests")
|
println("Set TEST_DATABASE_URL to run these tests")
|
||||||
os.Exit(0)
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlDB, err := testDB.DB()
|
sqlDB, err := testDB.DB()
|
||||||
if err != nil || sqlDB.Ping() != nil {
|
if err != nil {
|
||||||
println("Failed to connect to database")
|
println("Failed to get underlying DB:", err.Error())
|
||||||
os.Exit(0)
|
testDB = nil
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
if pingErr := sqlDB.Ping(); pingErr != nil {
|
||||||
|
println("Failed to ping database:", pingErr.Error())
|
||||||
|
testDB = nil
|
||||||
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
println("Database connected, running consistency tests...")
|
println("Database connected, running consistency tests...")
|
||||||
|
|||||||
@@ -17,7 +17,10 @@ import (
|
|||||||
// testDB holds the database connection for integration tests
|
// testDB holds the database connection for integration tests
|
||||||
var testDB *gorm.DB
|
var testDB *gorm.DB
|
||||||
|
|
||||||
// TestMain sets up the database connection for all tests in this package
|
// TestMain sets up the database connection for all tests in this package.
|
||||||
|
// If the database is not available, testDB remains nil and individual tests
|
||||||
|
// will call t.Skip() instead of using os.Exit(0), which preserves proper
|
||||||
|
// test reporting and coverage output.
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
// Get database URL from environment or use default
|
// Get database URL from environment or use default
|
||||||
dsn := os.Getenv("TEST_DATABASE_URL")
|
dsn := os.Getenv("TEST_DATABASE_URL")
|
||||||
@@ -30,22 +33,25 @@ func TestMain(m *testing.M) {
|
|||||||
Logger: logger.Default.LogMode(logger.Silent),
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Print message and skip tests if database is not available
|
// Explicitly nil out testDB; individual tests will t.Skip("Database not available")
|
||||||
println("Skipping scope integration tests: database not available")
|
testDB = nil
|
||||||
|
println("Scope integration tests will be skipped: database not available")
|
||||||
println("Set TEST_DATABASE_URL to run these tests")
|
println("Set TEST_DATABASE_URL to run these tests")
|
||||||
println("Error:", err.Error())
|
println("Error:", err.Error())
|
||||||
os.Exit(0)
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify connection works
|
// Verify connection works
|
||||||
sqlDB, err := testDB.DB()
|
sqlDB, err := testDB.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
println("Failed to get underlying DB:", err.Error())
|
println("Failed to get underlying DB:", err.Error())
|
||||||
os.Exit(0)
|
testDB = nil
|
||||||
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
if err := sqlDB.Ping(); err != nil {
|
if err := sqlDB.Ping(); err != nil {
|
||||||
println("Failed to ping database:", err.Error())
|
println("Failed to ping database:", err.Error())
|
||||||
os.Exit(0)
|
testDB = nil
|
||||||
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
println("Database connected successfully, running integration tests...")
|
println("Database connected successfully, running integration tests...")
|
||||||
@@ -57,7 +63,9 @@ func TestMain(m *testing.M) {
|
|||||||
&models.Residence{},
|
&models.Residence{},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
os.Exit(1)
|
println("Failed to run migrations:", err.Error())
|
||||||
|
testDB = nil
|
||||||
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run tests
|
// Run tests
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package testutil
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -71,7 +72,12 @@ func SetupTestDB(t *testing.T) *gorm.DB {
|
|||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupTestRouter creates a test Echo router with the custom error handler
|
// SetupTestRouter creates a test Echo router with the custom error handler.
|
||||||
|
// Uses apperrors.HTTPErrorHandler which is the base error handler shared with
|
||||||
|
// production (router.customHTTPErrorHandler). Both handle AppError, ValidationErrors,
|
||||||
|
// and echo.HTTPError identically. Production additionally maps legacy service sentinel
|
||||||
|
// errors (e.g., services.ErrTaskNotFound) which are being migrated to AppError types.
|
||||||
|
// Tests exercise handlers that return AppError, so this handler covers all test scenarios.
|
||||||
func SetupTestRouter() *echo.Echo {
|
func SetupTestRouter() *echo.Echo {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
e.Validator = validator.NewCustomValidator()
|
e.Validator = validator.NewCustomValidator()
|
||||||
@@ -79,17 +85,52 @@ func SetupTestRouter() *echo.Echo {
|
|||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
// MakeRequest makes a test HTTP request and returns the response
|
// MakeRequest makes a test HTTP request and returns the response.
|
||||||
|
// Errors from JSON marshaling and HTTP request construction are checked and
|
||||||
|
// will panic if they occur, since these indicate programming errors in tests.
|
||||||
|
// Prefer MakeRequestT for new tests, which uses t.Fatal for better reporting.
|
||||||
func MakeRequest(router *echo.Echo, method, path string, body interface{}, token string) *httptest.ResponseRecorder {
|
func MakeRequest(router *echo.Echo, method, path string, body interface{}, token string) *httptest.ResponseRecorder {
|
||||||
var reqBody *bytes.Buffer
|
var reqBody *bytes.Buffer
|
||||||
if body != nil {
|
if body != nil {
|
||||||
jsonBody, _ := json.Marshal(body)
|
jsonBody, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("testutil.MakeRequest: failed to marshal request body: %v", err))
|
||||||
|
}
|
||||||
reqBody = bytes.NewBuffer(jsonBody)
|
reqBody = bytes.NewBuffer(jsonBody)
|
||||||
} else {
|
} else {
|
||||||
reqBody = bytes.NewBuffer(nil)
|
reqBody = bytes.NewBuffer(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, _ := http.NewRequest(method, path, reqBody)
|
req, err := http.NewRequest(method, path, reqBody)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("testutil.MakeRequest: failed to create HTTP request: %v", err))
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
if token != "" {
|
||||||
|
req.Header.Set("Authorization", "Token "+token)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
return rec
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeRequestT is like MakeRequest but accepts a *testing.T for proper test
|
||||||
|
// failure reporting. Prefer this over MakeRequest in new tests.
|
||||||
|
func MakeRequestT(t *testing.T, router *echo.Echo, method, path string, body interface{}, token string) *httptest.ResponseRecorder {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var reqBody *bytes.Buffer
|
||||||
|
if body != nil {
|
||||||
|
jsonBody, err := json.Marshal(body)
|
||||||
|
require.NoError(t, err, "failed to marshal request body")
|
||||||
|
reqBody = bytes.NewBuffer(jsonBody)
|
||||||
|
} else {
|
||||||
|
reqBody = bytes.NewBuffer(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(method, path, reqBody)
|
||||||
|
require.NoError(t, err, "failed to create HTTP request")
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
if token != "" {
|
if token != "" {
|
||||||
req.Header.Set("Authorization", "Token "+token)
|
req.Header.Set("Authorization", "Token "+token)
|
||||||
@@ -215,8 +256,12 @@ func CreateTestTask(t *testing.T, db *gorm.DB, residenceID, createdByID uint, ti
|
|||||||
return task
|
return task
|
||||||
}
|
}
|
||||||
|
|
||||||
// SeedLookupData seeds all lookup tables with test data
|
// SeedLookupData seeds all lookup tables with test data.
|
||||||
|
// All GORM create operations are checked for errors to prevent silent failures
|
||||||
|
// that could cause misleading test results.
|
||||||
func SeedLookupData(t *testing.T, db *gorm.DB) {
|
func SeedLookupData(t *testing.T, db *gorm.DB) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
// Residence types
|
// Residence types
|
||||||
residenceTypes := []models.ResidenceType{
|
residenceTypes := []models.ResidenceType{
|
||||||
{Name: "House"},
|
{Name: "House"},
|
||||||
@@ -224,8 +269,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) {
|
|||||||
{Name: "Condo"},
|
{Name: "Condo"},
|
||||||
{Name: "Townhouse"},
|
{Name: "Townhouse"},
|
||||||
}
|
}
|
||||||
for _, rt := range residenceTypes {
|
for i := range residenceTypes {
|
||||||
db.Create(&rt)
|
err := db.Create(&residenceTypes[i]).Error
|
||||||
|
require.NoError(t, err, "failed to seed residence type: %s", residenceTypes[i].Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Task categories
|
// Task categories
|
||||||
@@ -235,8 +281,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) {
|
|||||||
{Name: "HVAC", DisplayOrder: 3},
|
{Name: "HVAC", DisplayOrder: 3},
|
||||||
{Name: "General", DisplayOrder: 99},
|
{Name: "General", DisplayOrder: 99},
|
||||||
}
|
}
|
||||||
for _, c := range categories {
|
for i := range categories {
|
||||||
db.Create(&c)
|
err := db.Create(&categories[i]).Error
|
||||||
|
require.NoError(t, err, "failed to seed task category: %s", categories[i].Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Task priorities
|
// Task priorities
|
||||||
@@ -246,8 +293,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) {
|
|||||||
{Name: "High", Level: 3, DisplayOrder: 3},
|
{Name: "High", Level: 3, DisplayOrder: 3},
|
||||||
{Name: "Urgent", Level: 4, DisplayOrder: 4},
|
{Name: "Urgent", Level: 4, DisplayOrder: 4},
|
||||||
}
|
}
|
||||||
for _, p := range priorities {
|
for i := range priorities {
|
||||||
db.Create(&p)
|
err := db.Create(&priorities[i]).Error
|
||||||
|
require.NoError(t, err, "failed to seed task priority: %s", priorities[i].Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Task frequencies
|
// Task frequencies
|
||||||
@@ -258,8 +306,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) {
|
|||||||
{Name: "Weekly", Days: &days7, DisplayOrder: 2},
|
{Name: "Weekly", Days: &days7, DisplayOrder: 2},
|
||||||
{Name: "Monthly", Days: &days30, DisplayOrder: 3},
|
{Name: "Monthly", Days: &days30, DisplayOrder: 3},
|
||||||
}
|
}
|
||||||
for _, f := range frequencies {
|
for i := range frequencies {
|
||||||
db.Create(&f)
|
err := db.Create(&frequencies[i]).Error
|
||||||
|
require.NoError(t, err, "failed to seed task frequency: %s", frequencies[i].Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Contractor specialties
|
// Contractor specialties
|
||||||
@@ -269,8 +318,9 @@ func SeedLookupData(t *testing.T, db *gorm.DB) {
|
|||||||
{Name: "HVAC Technician"},
|
{Name: "HVAC Technician"},
|
||||||
{Name: "Handyman"},
|
{Name: "Handyman"},
|
||||||
}
|
}
|
||||||
for _, s := range specialties {
|
for i := range specialties {
|
||||||
db.Create(&s)
|
err := db.Create(&specialties[i]).Error
|
||||||
|
require.NoError(t, err, "failed to seed contractor specialty: %s", specialties[i].Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,8 +20,6 @@ import (
|
|||||||
|
|
||||||
// Task types
|
// Task types
|
||||||
const (
|
const (
|
||||||
TypeTaskReminder = "notification:task_reminder"
|
|
||||||
TypeOverdueReminder = "notification:overdue_reminder"
|
|
||||||
TypeSmartReminder = "notification:smart_reminder" // Frequency-aware reminders
|
TypeSmartReminder = "notification:smart_reminder" // Frequency-aware reminders
|
||||||
TypeDailyDigest = "notification:daily_digest"
|
TypeDailyDigest = "notification:daily_digest"
|
||||||
TypeSendEmail = "email:send"
|
TypeSendEmail = "email:send"
|
||||||
@@ -36,6 +34,7 @@ type Handler struct {
|
|||||||
taskRepo *repositories.TaskRepository
|
taskRepo *repositories.TaskRepository
|
||||||
residenceRepo *repositories.ResidenceRepository
|
residenceRepo *repositories.ResidenceRepository
|
||||||
reminderRepo *repositories.ReminderRepository
|
reminderRepo *repositories.ReminderRepository
|
||||||
|
notificationRepo *repositories.NotificationRepository
|
||||||
pushClient *push.Client
|
pushClient *push.Client
|
||||||
emailService *services.EmailService
|
emailService *services.EmailService
|
||||||
notificationService *services.NotificationService
|
notificationService *services.NotificationService
|
||||||
@@ -56,6 +55,7 @@ func NewHandler(db *gorm.DB, pushClient *push.Client, emailService *services.Ema
|
|||||||
taskRepo: repositories.NewTaskRepository(db),
|
taskRepo: repositories.NewTaskRepository(db),
|
||||||
residenceRepo: repositories.NewResidenceRepository(db),
|
residenceRepo: repositories.NewResidenceRepository(db),
|
||||||
reminderRepo: repositories.NewReminderRepository(db),
|
reminderRepo: repositories.NewReminderRepository(db),
|
||||||
|
notificationRepo: repositories.NewNotificationRepository(db),
|
||||||
pushClient: pushClient,
|
pushClient: pushClient,
|
||||||
emailService: emailService,
|
emailService: emailService,
|
||||||
notificationService: notificationService,
|
notificationService: notificationService,
|
||||||
@@ -64,218 +64,6 @@ func NewHandler(db *gorm.DB, pushClient *push.Client, emailService *services.Ema
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TaskReminderData represents a task due soon for reminder notifications
|
|
||||||
type TaskReminderData struct {
|
|
||||||
TaskID uint
|
|
||||||
TaskTitle string
|
|
||||||
DueDate time.Time
|
|
||||||
UserID uint
|
|
||||||
UserEmail string
|
|
||||||
UserName string
|
|
||||||
ResidenceName string
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleTaskReminder processes task reminder notifications for tasks due today or tomorrow with actionable buttons
|
|
||||||
func (h *Handler) HandleTaskReminder(ctx context.Context, task *asynq.Task) error {
|
|
||||||
log.Info().Msg("Processing task reminder notifications...")
|
|
||||||
|
|
||||||
now := time.Now().UTC()
|
|
||||||
currentHour := now.Hour()
|
|
||||||
systemDefaultHour := h.config.Worker.TaskReminderHour
|
|
||||||
|
|
||||||
log.Info().Int("current_hour", currentHour).Int("system_default_hour", systemDefaultHour).Msg("Task reminder check")
|
|
||||||
|
|
||||||
// Step 1: Find users who should receive notifications THIS hour
|
|
||||||
// Logic: Each user gets notified ONCE per day at exactly ONE hour:
|
|
||||||
// - If user has custom hour set: notify ONLY at that custom hour
|
|
||||||
// - If user has NO custom hour (NULL): notify ONLY at system default hour
|
|
||||||
// This prevents duplicates: a user with custom hour is NEVER notified at default hour
|
|
||||||
var eligibleUserIDs []uint
|
|
||||||
|
|
||||||
query := h.db.Model(&models.NotificationPreference{}).
|
|
||||||
Select("user_id").
|
|
||||||
Where("task_due_soon = true")
|
|
||||||
|
|
||||||
if currentHour == systemDefaultHour {
|
|
||||||
// At system default hour: notify users who have NO custom hour (NULL) OR whose custom hour equals default
|
|
||||||
query = query.Where("task_due_soon_hour IS NULL OR task_due_soon_hour = ?", currentHour)
|
|
||||||
} else {
|
|
||||||
// At non-default hour: only notify users who have this specific custom hour set
|
|
||||||
// Exclude users with NULL (they get notified at default hour only)
|
|
||||||
query = query.Where("task_due_soon_hour = ?", currentHour)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := query.Pluck("user_id", &eligibleUserIDs).Error
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Failed to query eligible users for task reminders")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Early exit if no users need notifications this hour
|
|
||||||
if len(eligibleUserIDs) == 0 {
|
|
||||||
log.Debug().Int("hour", currentHour).Msg("No users scheduled for task reminder notifications this hour")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().Int("eligible_users", len(eligibleUserIDs)).Msg("Found users eligible for task reminders this hour")
|
|
||||||
|
|
||||||
// Step 2: Query tasks due today or tomorrow using the single-purpose repository function
|
|
||||||
// Uses the same scopes as kanban for consistency, with IncludeInProgress=true
|
|
||||||
// so users still get notified about in-progress tasks that are due soon.
|
|
||||||
opts := repositories.TaskFilterOptions{
|
|
||||||
UserIDs: eligibleUserIDs,
|
|
||||||
IncludeInProgress: true, // Notifications should include in-progress tasks
|
|
||||||
PreloadResidence: true,
|
|
||||||
PreloadCompletions: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Due soon = due within 2 days (today and tomorrow)
|
|
||||||
dueSoonTasks, err := h.taskRepo.GetDueSoonTasks(now, 2, opts)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Failed to query tasks due soon")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().Int("count", len(dueSoonTasks)).Msg("Found tasks due today/tomorrow for eligible users")
|
|
||||||
|
|
||||||
// Build set for O(1) eligibility lookups instead of O(N) linear scan
|
|
||||||
eligibleSet := make(map[uint]bool, len(eligibleUserIDs))
|
|
||||||
for _, id := range eligibleUserIDs {
|
|
||||||
eligibleSet[id] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Group tasks by user (assigned_to or residence owner)
|
|
||||||
userTasks := make(map[uint][]models.Task)
|
|
||||||
for _, t := range dueSoonTasks {
|
|
||||||
var userID uint
|
|
||||||
if t.AssignedToID != nil {
|
|
||||||
userID = *t.AssignedToID
|
|
||||||
} else if t.Residence.ID != 0 {
|
|
||||||
userID = t.Residence.OwnerID
|
|
||||||
} else {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Only include if user is in eligible set (O(1) lookup)
|
|
||||||
if eligibleSet[userID] {
|
|
||||||
userTasks[userID] = append(userTasks[userID], t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 3: Send notifications (no need to check preferences again - already filtered)
|
|
||||||
// Send individual task-specific notification for each task (all tasks, no limit)
|
|
||||||
for userID, taskList := range userTasks {
|
|
||||||
for _, t := range taskList {
|
|
||||||
if err := h.notificationService.CreateAndSendTaskNotification(ctx, userID, models.NotificationTaskDueSoon, &t); err != nil {
|
|
||||||
log.Error().Err(err).Uint("user_id", userID).Uint("task_id", t.ID).Msg("Failed to send task reminder notification")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().Int("users_notified", len(userTasks)).Msg("Task reminder notifications completed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleOverdueReminder processes overdue task notifications with actionable buttons
|
|
||||||
func (h *Handler) HandleOverdueReminder(ctx context.Context, task *asynq.Task) error {
|
|
||||||
log.Info().Msg("Processing overdue task notifications...")
|
|
||||||
|
|
||||||
now := time.Now().UTC()
|
|
||||||
currentHour := now.Hour()
|
|
||||||
systemDefaultHour := h.config.Worker.OverdueReminderHour
|
|
||||||
|
|
||||||
log.Info().Int("current_hour", currentHour).Int("system_default_hour", systemDefaultHour).Msg("Overdue reminder check")
|
|
||||||
|
|
||||||
// Step 1: Find users who should receive notifications THIS hour
|
|
||||||
// Logic: Each user gets notified ONCE per day at exactly ONE hour:
|
|
||||||
// - If user has custom hour set: notify ONLY at that custom hour
|
|
||||||
// - If user has NO custom hour (NULL): notify ONLY at system default hour
|
|
||||||
// This prevents duplicates: a user with custom hour is NEVER notified at default hour
|
|
||||||
var eligibleUserIDs []uint
|
|
||||||
|
|
||||||
query := h.db.Model(&models.NotificationPreference{}).
|
|
||||||
Select("user_id").
|
|
||||||
Where("task_overdue = true")
|
|
||||||
|
|
||||||
if currentHour == systemDefaultHour {
|
|
||||||
// At system default hour: notify users who have NO custom hour (NULL) OR whose custom hour equals default
|
|
||||||
query = query.Where("task_overdue_hour IS NULL OR task_overdue_hour = ?", currentHour)
|
|
||||||
} else {
|
|
||||||
// At non-default hour: only notify users who have this specific custom hour set
|
|
||||||
// Exclude users with NULL (they get notified at default hour only)
|
|
||||||
query = query.Where("task_overdue_hour = ?", currentHour)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := query.Pluck("user_id", &eligibleUserIDs).Error
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Failed to query eligible users for overdue reminders")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Early exit if no users need notifications this hour
|
|
||||||
if len(eligibleUserIDs) == 0 {
|
|
||||||
log.Debug().Int("hour", currentHour).Msg("No users scheduled for overdue notifications this hour")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().Int("eligible_users", len(eligibleUserIDs)).Msg("Found users eligible for overdue reminders this hour")
|
|
||||||
|
|
||||||
// Step 2: Query overdue tasks using the single-purpose repository function
|
|
||||||
// Uses the same scopes as kanban for consistency, with IncludeInProgress=true
|
|
||||||
// so users still get notified about in-progress tasks that are overdue.
|
|
||||||
opts := repositories.TaskFilterOptions{
|
|
||||||
UserIDs: eligibleUserIDs,
|
|
||||||
IncludeInProgress: true, // Notifications should include in-progress tasks
|
|
||||||
PreloadResidence: true,
|
|
||||||
PreloadCompletions: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
overdueTasks, err := h.taskRepo.GetOverdueTasks(now, opts)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Failed to query overdue tasks")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().Int("count", len(overdueTasks)).Msg("Found overdue tasks for eligible users")
|
|
||||||
|
|
||||||
// Build set for O(1) eligibility lookups instead of O(N) linear scan
|
|
||||||
eligibleSet := make(map[uint]bool, len(eligibleUserIDs))
|
|
||||||
for _, id := range eligibleUserIDs {
|
|
||||||
eligibleSet[id] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Group tasks by user (assigned_to or residence owner)
|
|
||||||
userTasks := make(map[uint][]models.Task)
|
|
||||||
for _, t := range overdueTasks {
|
|
||||||
var userID uint
|
|
||||||
if t.AssignedToID != nil {
|
|
||||||
userID = *t.AssignedToID
|
|
||||||
} else if t.Residence.ID != 0 {
|
|
||||||
userID = t.Residence.OwnerID
|
|
||||||
} else {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Only include if user is in eligible set (O(1) lookup)
|
|
||||||
if eligibleSet[userID] {
|
|
||||||
userTasks[userID] = append(userTasks[userID], t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 3: Send notifications (no need to check preferences again - already filtered)
|
|
||||||
// Send individual task-specific notification for each task (all tasks, no limit)
|
|
||||||
for userID, taskList := range userTasks {
|
|
||||||
for _, t := range taskList {
|
|
||||||
if err := h.notificationService.CreateAndSendTaskNotification(ctx, userID, models.NotificationTaskOverdue, &t); err != nil {
|
|
||||||
log.Error().Err(err).Uint("user_id", userID).Uint("task_id", t.ID).Msg("Failed to send overdue notification")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().Int("users_notified", len(userTasks)).Msg("Overdue task notifications completed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleDailyDigest processes daily digest notifications with task statistics
|
// HandleDailyDigest processes daily digest notifications with task statistics
|
||||||
func (h *Handler) HandleDailyDigest(ctx context.Context, task *asynq.Task) error {
|
func (h *Handler) HandleDailyDigest(ctx context.Context, task *asynq.Task) error {
|
||||||
log.Info().Msg("Processing daily digest notifications...")
|
log.Info().Msg("Processing daily digest notifications...")
|
||||||
@@ -328,8 +116,7 @@ func (h *Handler) HandleDailyDigest(ctx context.Context, task *asynq.Task) error
|
|||||||
// Get user's timezone from notification preferences for accurate overdue calculation
|
// Get user's timezone from notification preferences for accurate overdue calculation
|
||||||
// This ensures the daily digest matches what the user sees in the kanban UI
|
// This ensures the daily digest matches what the user sees in the kanban UI
|
||||||
var userNow time.Time
|
var userNow time.Time
|
||||||
var prefs models.NotificationPreference
|
if prefs, err := h.notificationRepo.FindPreferencesByUser(userID); err == nil && prefs.Timezone != nil {
|
||||||
if err := h.db.Where("user_id = ?", userID).First(&prefs).Error; err == nil && prefs.Timezone != nil {
|
|
||||||
if loc, err := time.LoadLocation(*prefs.Timezone); err == nil {
|
if loc, err := time.LoadLocation(*prefs.Timezone); err == nil {
|
||||||
// Use start of day in user's timezone (matches kanban behavior)
|
// Use start of day in user's timezone (matches kanban behavior)
|
||||||
userNowInTz := time.Now().In(loc)
|
userNowInTz := time.Now().In(loc)
|
||||||
@@ -481,22 +268,11 @@ func (h *Handler) sendPushToUser(ctx context.Context, userID uint, title, messag
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get iOS device tokens
|
// Get active device tokens via repository
|
||||||
var iosTokens []string
|
iosTokens, androidTokens, err := h.notificationRepo.GetActiveTokensForUser(userID)
|
||||||
err := h.db.Model(&models.APNSDevice{}).
|
|
||||||
Where("user_id = ? AND active = ?", userID, true).
|
|
||||||
Pluck("registration_id", &iosTokens).Error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to get iOS tokens")
|
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to get device tokens")
|
||||||
}
|
return err
|
||||||
|
|
||||||
// Get Android device tokens
|
|
||||||
var androidTokens []string
|
|
||||||
err = h.db.Model(&models.GCMDevice{}).
|
|
||||||
Where("user_id = ? AND active = ?", userID, true).
|
|
||||||
Pluck("registration_id", &androidTokens).Error
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Uint("user_id", userID).Msg("Failed to get Android tokens")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(iosTokens) == 0 && len(androidTokens) == 0 {
|
if len(iosTokens) == 0 && len(androidTokens) == 0 {
|
||||||
@@ -561,18 +337,17 @@ func (h *Handler) HandleOnboardingEmails(ctx context.Context, task *asynq.Task)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send no-residence emails (users without any residences after 2 days)
|
// Send no-residence emails (users without any residences after 2 days)
|
||||||
noResCount, err := h.onboardingService.CheckAndSendNoResidenceEmails()
|
noResCount, noResErr := h.onboardingService.CheckAndSendNoResidenceEmails()
|
||||||
if err != nil {
|
if noResErr != nil {
|
||||||
log.Error().Err(err).Msg("Failed to process no-residence onboarding emails")
|
log.Error().Err(noResErr).Msg("Failed to process no-residence onboarding emails")
|
||||||
// Continue to next type, don't return error
|
|
||||||
} else {
|
} else {
|
||||||
log.Info().Int("count", noResCount).Msg("Sent no-residence onboarding emails")
|
log.Info().Int("count", noResCount).Msg("Sent no-residence onboarding emails")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send no-tasks emails (users with residence but no tasks after 5 days)
|
// Send no-tasks emails (users with residence but no tasks after 5 days)
|
||||||
noTasksCount, err := h.onboardingService.CheckAndSendNoTasksEmails()
|
noTasksCount, noTasksErr := h.onboardingService.CheckAndSendNoTasksEmails()
|
||||||
if err != nil {
|
if noTasksErr != nil {
|
||||||
log.Error().Err(err).Msg("Failed to process no-tasks onboarding emails")
|
log.Error().Err(noTasksErr).Msg("Failed to process no-tasks onboarding emails")
|
||||||
} else {
|
} else {
|
||||||
log.Info().Int("count", noTasksCount).Msg("Sent no-tasks onboarding emails")
|
log.Info().Int("count", noTasksCount).Msg("Sent no-tasks onboarding emails")
|
||||||
}
|
}
|
||||||
@@ -582,6 +357,11 @@ func (h *Handler) HandleOnboardingEmails(ctx context.Context, task *asynq.Task)
|
|||||||
Int("no_tasks_sent", noTasksCount).
|
Int("no_tasks_sent", noTasksCount).
|
||||||
Msg("Onboarding email processing completed")
|
Msg("Onboarding email processing completed")
|
||||||
|
|
||||||
|
// If all sub-tasks failed, return an error so Asynq retries
|
||||||
|
if noResErr != nil && noTasksErr != nil {
|
||||||
|
return fmt.Errorf("all onboarding email sub-tasks failed: no-residence: %w, no-tasks: %v", noResErr, noTasksErr)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -603,7 +383,6 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err
|
|||||||
log.Info().Msg("Processing smart task reminders...")
|
log.Info().Msg("Processing smart task reminders...")
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
|
|
||||||
currentHour := now.Hour()
|
currentHour := now.Hour()
|
||||||
|
|
||||||
dueSoonDefault := h.config.Worker.TaskReminderHour
|
dueSoonDefault := h.config.Worker.TaskReminderHour
|
||||||
@@ -673,6 +452,22 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err
|
|||||||
Int("want_overdue", len(userWantsOverdue)).
|
Int("want_overdue", len(userWantsOverdue)).
|
||||||
Msg("Found users eligible for reminders")
|
Msg("Found users eligible for reminders")
|
||||||
|
|
||||||
|
// Build per-user "today" using their timezone preference
|
||||||
|
// This ensures reminder stage calculations (overdue, due-soon) match the user's local date
|
||||||
|
userToday := make(map[uint]time.Time, len(allUserIDs))
|
||||||
|
utcToday := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
|
||||||
|
for _, uid := range allUserIDs {
|
||||||
|
prefs, err := h.notificationRepo.FindPreferencesByUser(uid)
|
||||||
|
if err == nil && prefs.Timezone != nil {
|
||||||
|
if loc, locErr := time.LoadLocation(*prefs.Timezone); locErr == nil {
|
||||||
|
userNowInTz := time.Now().In(loc)
|
||||||
|
userToday[uid] = time.Date(userNowInTz.Year(), userNowInTz.Month(), userNowInTz.Day(), 0, 0, 0, 0, loc)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userToday[uid] = utcToday // Fallback to UTC
|
||||||
|
}
|
||||||
|
|
||||||
// Step 2: Single query to get ALL active tasks (both due-soon and overdue) for these users
|
// Step 2: Single query to get ALL active tasks (both due-soon and overdue) for these users
|
||||||
opts := repositories.TaskFilterOptions{
|
opts := repositories.TaskFilterOptions{
|
||||||
UserIDs: allUserIDs,
|
UserIDs: allUserIDs,
|
||||||
@@ -734,7 +529,8 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err
|
|||||||
frequencyDays = &days
|
frequencyDays = &days
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine which reminder stage applies today
|
// Determine which reminder stage applies today using the user's local date
|
||||||
|
today := userToday[userID]
|
||||||
stage := notifications.GetReminderStageForToday(effectiveDate, frequencyDays, today)
|
stage := notifications.GetReminderStageForToday(effectiveDate, frequencyDays, today)
|
||||||
if stage == "" {
|
if stage == "" {
|
||||||
continue
|
continue
|
||||||
@@ -800,6 +596,11 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err
|
|||||||
notificationType = models.NotificationTaskDueSoon
|
notificationType = models.NotificationTaskDueSoon
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Log the reminder BEFORE sending so we have a record even if the send crashes
|
||||||
|
if _, err := h.reminderRepo.LogReminder(t.ID, c.userID, c.effectiveDate, c.reminderStage, nil); err != nil {
|
||||||
|
log.Error().Err(err).Uint("task_id", t.ID).Str("stage", c.stage).Msg("Failed to log reminder")
|
||||||
|
}
|
||||||
|
|
||||||
// Send notification
|
// Send notification
|
||||||
if err := h.notificationService.CreateAndSendTaskNotification(ctx, c.userID, notificationType, &t); err != nil {
|
if err := h.notificationService.CreateAndSendTaskNotification(ctx, c.userID, notificationType, &t); err != nil {
|
||||||
log.Error().Err(err).
|
log.Error().Err(err).
|
||||||
@@ -810,11 +611,6 @@ func (h *Handler) HandleSmartReminder(ctx context.Context, task *asynq.Task) err
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log the reminder
|
|
||||||
if _, err := h.reminderRepo.LogReminder(t.ID, c.userID, c.effectiveDate, c.reminderStage, nil); err != nil {
|
|
||||||
log.Error().Err(err).Uint("task_id", t.ID).Str("stage", c.stage).Msg("Failed to log reminder")
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.isOverdue {
|
if c.isOverdue {
|
||||||
overdueSent++
|
overdueSent++
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -1,27 +1,18 @@
|
|||||||
package worker
|
package worker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/hibiken/asynq"
|
"github.com/hibiken/asynq"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Task types
|
// Task types for email jobs
|
||||||
const (
|
const (
|
||||||
TypeWelcomeEmail = "email:welcome"
|
TypeWelcomeEmail = "email:welcome"
|
||||||
TypeVerificationEmail = "email:verification"
|
TypeVerificationEmail = "email:verification"
|
||||||
TypePasswordResetEmail = "email:password_reset"
|
TypePasswordResetEmail = "email:password_reset"
|
||||||
TypePasswordChangedEmail = "email:password_changed"
|
TypePasswordChangedEmail = "email:password_changed"
|
||||||
TypeTaskCompletionEmail = "email:task_completion"
|
|
||||||
TypeGeneratePDFReport = "pdf:generate_report"
|
|
||||||
TypeUpdateContractorRating = "contractor:update_rating"
|
|
||||||
TypeDailyNotifications = "notifications:daily"
|
|
||||||
TypeTaskReminders = "notifications:task_reminders"
|
|
||||||
TypeOverdueReminders = "notifications:overdue_reminders"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// EmailPayload is the base payload for email tasks
|
// EmailPayload is the base payload for email tasks
|
||||||
@@ -146,94 +137,3 @@ func (c *TaskClient) EnqueuePasswordChangedEmail(to, firstName string) error {
|
|||||||
log.Debug().Str("to", to).Msg("Password changed email task enqueued")
|
log.Debug().Str("to", to).Msg("Password changed email task enqueued")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WorkerServer manages the asynq worker server
|
|
||||||
type WorkerServer struct {
|
|
||||||
server *asynq.Server
|
|
||||||
scheduler *asynq.Scheduler
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWorkerServer creates a new worker server
|
|
||||||
func NewWorkerServer(redisAddr string, concurrency int) *WorkerServer {
|
|
||||||
srv := asynq.NewServer(
|
|
||||||
asynq.RedisClientOpt{Addr: redisAddr},
|
|
||||||
asynq.Config{
|
|
||||||
Concurrency: concurrency,
|
|
||||||
Queues: map[string]int{
|
|
||||||
"critical": 6,
|
|
||||||
"default": 3,
|
|
||||||
"low": 1,
|
|
||||||
},
|
|
||||||
ErrorHandler: asynq.ErrorHandlerFunc(func(ctx context.Context, task *asynq.Task, err error) {
|
|
||||||
log.Error().
|
|
||||||
Err(err).
|
|
||||||
Str("type", task.Type()).
|
|
||||||
Bytes("payload", task.Payload()).
|
|
||||||
Msg("Task failed")
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
// Create scheduler for periodic tasks
|
|
||||||
loc, _ := time.LoadLocation("UTC")
|
|
||||||
scheduler := asynq.NewScheduler(
|
|
||||||
asynq.RedisClientOpt{Addr: redisAddr},
|
|
||||||
&asynq.SchedulerOpts{
|
|
||||||
Location: loc,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return &WorkerServer{
|
|
||||||
server: srv,
|
|
||||||
scheduler: scheduler,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterHandlers registers task handlers
|
|
||||||
func (w *WorkerServer) RegisterHandlers(mux *asynq.ServeMux) {
|
|
||||||
// Handlers will be registered by the main worker process
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterScheduledTasks registers periodic tasks
|
|
||||||
func (w *WorkerServer) RegisterScheduledTasks() error {
|
|
||||||
// Task reminders - 8:00 PM UTC daily
|
|
||||||
_, err := w.scheduler.Register("0 20 * * *", asynq.NewTask(TypeTaskReminders, nil))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to register task reminders: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Overdue reminders - 9:00 AM UTC daily
|
|
||||||
_, err = w.scheduler.Register("0 9 * * *", asynq.NewTask(TypeOverdueReminders, nil))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to register overdue reminders: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Daily notifications - 11:00 AM UTC daily
|
|
||||||
_, err = w.scheduler.Register("0 11 * * *", asynq.NewTask(TypeDailyNotifications, nil))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to register daily notifications: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start starts the worker server and scheduler
|
|
||||||
func (w *WorkerServer) Start(mux *asynq.ServeMux) error {
|
|
||||||
// Start scheduler
|
|
||||||
if err := w.scheduler.Start(); err != nil {
|
|
||||||
return fmt.Errorf("failed to start scheduler: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start server
|
|
||||||
if err := w.server.Start(mux); err != nil {
|
|
||||||
return fmt.Errorf("failed to start worker server: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown gracefully shuts down the worker server
|
|
||||||
func (w *WorkerServer) Shutdown() {
|
|
||||||
w.scheduler.Shutdown()
|
|
||||||
w.server.Shutdown()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,12 +2,16 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime/debug"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
|
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||||
)
|
)
|
||||||
|
|
||||||
// InitLogger initializes the zerolog logger
|
// InitLogger initializes the zerolog logger
|
||||||
@@ -113,14 +117,17 @@ func EchoRecovery() echo.MiddlewareFunc {
|
|||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
|
// F-14: Include full stack trace for debugging
|
||||||
log.Error().
|
log.Error().
|
||||||
Interface("error", err).
|
Interface("error", err).
|
||||||
Str("path", c.Request().URL.Path).
|
Str("path", c.Request().URL.Path).
|
||||||
Str("method", c.Request().Method).
|
Str("method", c.Request().Method).
|
||||||
|
Str("stack", string(debug.Stack())).
|
||||||
Msg("Panic recovered")
|
Msg("Panic recovered")
|
||||||
|
|
||||||
c.JSON(500, map[string]interface{}{
|
// F-15: Use the project's standard ErrorResponse struct
|
||||||
"error": "Internal server error",
|
c.JSON(http.StatusInternalServerError, responses.ErrorResponse{
|
||||||
|
Error: "Internal server error",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
Reference in New Issue
Block a user