From 4abc57535eb6a5cc69be4cf324718e4624b8789a Mon Sep 17 00:00:00 2001 From: Trey T Date: Thu, 26 Mar 2026 10:41:01 -0500 Subject: [PATCH] Add delete account endpoint and file encryption at rest Delete Account (Plan #2): - DELETE /api/auth/account/ with password or "DELETE" confirmation - Cascade delete across 15+ tables in correct FK order - Auth provider detection (email/apple/google) for /auth/me/ - File cleanup after account deletion - Handler + repository tests (12 tests) Encryption at Rest (Plan #3): - AES-256-GCM envelope encryption (per-file DEK wrapped by KEK) - Encrypt on upload, auto-decrypt on serve via StorageService.ReadFile() - MediaHandler serves decrypted files via c.Blob() - TaskService email image loading uses ReadFile() - cmd/migrate-encrypt CLI tool with --dry-run for existing files - Encryption service + storage service tests (18 tests) --- .env.example | 2 + Makefile | 17 +- cmd/api/main.go | 11 + cmd/migrate-encrypt/main.go | 190 +++++++++++++++ docker-compose.dev.yml | 3 + docker-compose.yml | 3 + internal/config/config.go | 29 ++- internal/dto/requests/auth.go | 6 + internal/dto/responses/auth.go | 40 ++-- internal/handlers/auth_handler.go | 51 ++++ internal/handlers/auth_handler_delete_test.go | 217 +++++++++++++++++ internal/handlers/media_handler.go | 33 +-- internal/repositories/user_repo.go | 195 ++++++++++++++++ internal/repositories/user_repo_test.go | 166 +++++++++++++ internal/router/router.go | 2 + internal/services/auth_service.go | 69 +++++- internal/services/encryption_service.go | 179 ++++++++++++++ internal/services/encryption_service_test.go | 218 ++++++++++++++++++ internal/services/storage_service.go | 128 ++++++++-- internal/services/storage_service_test.go | 164 +++++++++++++ internal/services/task_service.go | 31 +-- internal/testutil/testutil.go | 3 + 22 files changed, 1675 insertions(+), 82 deletions(-) create mode 100644 cmd/migrate-encrypt/main.go create mode 100644 internal/handlers/auth_handler_delete_test.go create mode 100644 internal/services/encryption_service.go create mode 100644 internal/services/encryption_service_test.go create mode 100644 internal/services/storage_service_test.go diff --git a/.env.example b/.env.example index ed05fdc..28ab78e 100644 --- a/.env.example +++ b/.env.example @@ -50,6 +50,8 @@ STORAGE_UPLOAD_DIR=./uploads STORAGE_BASE_URL=/uploads STORAGE_MAX_FILE_SIZE=10485760 STORAGE_ALLOWED_TYPES=image/jpeg,image/png,image/gif,image/webp,application/pdf +# 64-char hex key for file encryption at rest. Generate with: openssl rand -hex 32 +STORAGE_ENCRYPTION_KEY= # Feature Flags (Kill Switches) # Set to false to disable. All default to true (enabled). diff --git a/Makefile b/Makefile index 3dfbd08..066add4 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build run test contract-test clean deps lint docker-build docker-up docker-down migrate +.PHONY: build run test contract-test clean deps lint docker-build docker-up docker-down migrate migrate-encrypt migrate-encrypt-dry # Binary names API_BINARY=honeydue-api @@ -99,6 +99,13 @@ migrate-down: migrate-create: migrate create -ext sql -dir migrations -seq $(name) +# Encrypt existing uploads at rest (run after setting STORAGE_ENCRYPTION_KEY) +migrate-encrypt: + go run ./cmd/migrate-encrypt + +migrate-encrypt-dry: + go run ./cmd/migrate-encrypt --dry-run + # Development helpers dev: deps run @@ -139,5 +146,9 @@ help: @echo " docker-build-prod - Build production images (api, worker, admin)" @echo "" @echo "Database:" - @echo " migrate-up - Run database migrations" - @echo " migrate-down - Rollback database migrations" + @echo " migrate-up - Run database migrations" + @echo " migrate-down - Rollback database migrations" + @echo "" + @echo "Encryption:" + @echo " migrate-encrypt - Encrypt existing uploads at rest" + @echo " migrate-encrypt-dry - Preview encryption migration (dry run)" diff --git a/cmd/api/main.go b/cmd/api/main.go index 3cb0dd6..84895ee 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -134,6 +134,17 @@ func main() { Str("base_url", cfg.Storage.BaseURL). Int64("max_file_size", cfg.Storage.MaxFileSize). Msg("Storage service initialized") + + // Initialize file encryption at rest if configured + if cfg.Storage.EncryptionKey != "" { + encSvc, encErr := services.NewEncryptionService(cfg.Storage.EncryptionKey) + if encErr != nil { + log.Error().Err(encErr).Msg("Failed to initialize encryption service - files will NOT be encrypted") + } else { + storageService.SetEncryptionService(encSvc) + log.Info().Msg("File encryption at rest enabled") + } + } } } diff --git a/cmd/migrate-encrypt/main.go b/cmd/migrate-encrypt/main.go new file mode 100644 index 0000000..fb399c3 --- /dev/null +++ b/cmd/migrate-encrypt/main.go @@ -0,0 +1,190 @@ +// migrate-encrypt is a standalone CLI tool that encrypts existing uploaded files at rest. +// +// It walks the uploads directory, encrypts each unencrypted file, updates the +// corresponding database record, and removes the original plaintext file. +// +// Usage: +// +// go run ./cmd/migrate-encrypt --dry-run # Preview changes +// go run ./cmd/migrate-encrypt # Apply changes +package main + +import ( + "flag" + "os" + "path/filepath" + "strings" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/treytartt/honeydue-api/internal/config" + "github.com/treytartt/honeydue-api/internal/services" +) + +// dbTable represents a table with a URL column that may reference uploaded files. +type dbTable struct { + table string + column string +} + +// tables lists all database tables and columns that store file URLs. +var tables = []dbTable{ + {table: "task_document", column: "file_url"}, + {table: "task_documentimage", column: "image_url"}, + {table: "task_taskcompletionimage", column: "image_url"}, +} + +func main() { + dryRun := flag.Bool("dry-run", false, "Preview changes without modifying files or database") + flag.Parse() + + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}) + + cfg, err := config.Load() + if err != nil { + log.Fatal().Err(err).Msg("Failed to load config") + } + + if cfg.Storage.EncryptionKey == "" { + log.Fatal().Msg("STORAGE_ENCRYPTION_KEY is not set — cannot encrypt files") + } + + encSvc, err := services.NewEncryptionService(cfg.Storage.EncryptionKey) + if err != nil { + log.Fatal().Err(err).Msg("Failed to create encryption service") + } + + dsn := cfg.Database.DSN() + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + log.Fatal().Err(err).Msg("Failed to connect to database") + } + + uploadDir := cfg.Storage.UploadDir + absUploadDir, err := filepath.Abs(uploadDir) + if err != nil { + log.Fatal().Err(err).Str("upload_dir", uploadDir).Msg("Failed to resolve upload directory") + } + + log.Info(). + Bool("dry_run", *dryRun). + Str("upload_dir", absUploadDir). + Msg("Starting file encryption migration") + + var totalFiles, encrypted, skipped, errCount int + + // Walk the uploads directory + err = filepath.Walk(absUploadDir, func(path string, info os.FileInfo, walkErr error) error { + if walkErr != nil { + log.Warn().Err(walkErr).Str("path", path).Msg("Error accessing path") + return nil + } + + // Skip directories + if info.IsDir() { + return nil + } + + // Skip files already encrypted + if strings.HasSuffix(path, ".enc") { + skipped++ + return nil + } + + totalFiles++ + + // Compute the relative path from upload dir + relPath, err := filepath.Rel(absUploadDir, path) + if err != nil { + log.Warn().Err(err).Str("path", path).Msg("Failed to compute relative path") + errCount++ + return nil + } + + if *dryRun { + log.Info().Str("file", relPath).Msg("[DRY RUN] Would encrypt") + return nil + } + + // Process within a transaction + txErr := db.Transaction(func(tx *gorm.DB) error { + // Read plaintext file + plaintext, readErr := os.ReadFile(path) + if readErr != nil { + return readErr + } + + // Encrypt + ciphertext, encErr := encSvc.Encrypt(plaintext) + if encErr != nil { + return encErr + } + + // Write encrypted file + encPath := path + ".enc" + if writeErr := os.WriteFile(encPath, ciphertext, 0644); writeErr != nil { + return writeErr + } + + // Update database records that reference this file + // The stored URL uses the BaseURL prefix + relative path + // We need to match against the relative path portion + for _, t := range tables { + // Match URLs ending with the relative path (handles both /uploads/... and bare paths) + result := tx.Table(t.table). + Where(t.column+" LIKE ?", "%"+relPath). + Where(t.column+" NOT LIKE ?", "%.enc"). + Update(t.column, gorm.Expr(t.column+" || '.enc'")) + if result.Error != nil { + // Roll back: remove the encrypted file + os.Remove(encPath) + return result.Error + } + if result.RowsAffected > 0 { + log.Info(). + Str("table", t.table). + Str("column", t.column). + Int64("rows", result.RowsAffected). + Str("file", relPath). + Msg("Updated database records") + } + } + + // Remove the original plaintext file + if removeErr := os.Remove(path); removeErr != nil { + log.Warn().Err(removeErr).Str("path", path).Msg("Failed to remove original file (encrypted copy exists)") + } + + return nil + }) + + if txErr != nil { + log.Error().Err(txErr).Str("file", relPath).Msg("Failed to encrypt file") + errCount++ + } else { + encrypted++ + log.Info().Str("file", relPath).Msg("Encrypted successfully") + } + + return nil + }) + + if err != nil { + log.Fatal().Err(err).Msg("Failed to walk upload directory") + } + + log.Info(). + Bool("dry_run", *dryRun). + Int("total_files", totalFiles). + Int("encrypted", encrypted). + Int("skipped_already_enc", skipped). + Int("errors", errCount). + Msg("Encryption migration complete") +} diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 320a373..35e8c78 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -88,6 +88,9 @@ services: APNS_TOPIC: ${APNS_TOPIC:-com.tt.honeyDue} APNS_USE_SANDBOX: "true" FCM_SERVER_KEY: ${FCM_SERVER_KEY} + + # Storage encryption + STORAGE_ENCRYPTION_KEY: ${STORAGE_ENCRYPTION_KEY} volumes: - ./push_certs:/certs:ro - ./uploads:/app/uploads diff --git a/docker-compose.yml b/docker-compose.yml index 275e59c..5ffed63 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -93,6 +93,9 @@ services: APNS_TOPIC: ${APNS_TOPIC} APNS_USE_SANDBOX: "${APNS_USE_SANDBOX:-false}" FCM_SERVER_KEY: ${FCM_SERVER_KEY} + + # Storage encryption + STORAGE_ENCRYPTION_KEY: ${STORAGE_ENCRYPTION_KEY} volumes: - push_certs:/certs:ro - uploads:/app/uploads diff --git a/internal/config/config.go b/internal/config/config.go index 66a6342..45b72f8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "encoding/hex" "fmt" "net/url" "os" @@ -137,10 +138,11 @@ type SecurityConfig struct { // StorageConfig holds file storage settings type StorageConfig struct { - UploadDir string // Directory to store uploaded files - BaseURL string // Public URL prefix for serving files (e.g., "/uploads") - MaxFileSize int64 // Max file size in bytes (default 10MB) - AllowedTypes string // Comma-separated MIME types + UploadDir string // Directory to store uploaded files + BaseURL string // Public URL prefix for serving files (e.g., "/uploads") + MaxFileSize int64 // Max file size in bytes (default 10MB) + AllowedTypes string // Comma-separated MIME types + EncryptionKey string // 64-char hex key for file encryption at rest (optional) } // FeatureFlags holds kill switches for major subsystems. @@ -262,10 +264,11 @@ func Load() (*Config, error) { MaxPasswordResetRate: 3, }, Storage: StorageConfig{ - UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"), - BaseURL: viper.GetString("STORAGE_BASE_URL"), - MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"), - AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"), + UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"), + BaseURL: viper.GetString("STORAGE_BASE_URL"), + MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"), + AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"), + EncryptionKey: viper.GetString("STORAGE_ENCRYPTION_KEY"), }, AppleAuth: AppleAuthConfig{ ClientID: viper.GetString("APPLE_CLIENT_ID"), @@ -414,6 +417,16 @@ func validate(cfg *Config) error { // Database password might come from DATABASE_URL, don't require it separately // The actual connection will fail if credentials are wrong + // Validate STORAGE_ENCRYPTION_KEY if set: must be exactly 64 hex characters + if cfg.Storage.EncryptionKey != "" { + if len(cfg.Storage.EncryptionKey) != 64 { + return fmt.Errorf("STORAGE_ENCRYPTION_KEY must be exactly 64 hex characters (got %d)", len(cfg.Storage.EncryptionKey)) + } + if _, err := hex.DecodeString(cfg.Storage.EncryptionKey); err != nil { + return fmt.Errorf("STORAGE_ENCRYPTION_KEY contains invalid hex: %w", err) + } + } + return nil } diff --git a/internal/dto/requests/auth.go b/internal/dto/requests/auth.go index fd83ad7..92767e2 100644 --- a/internal/dto/requests/auth.go +++ b/internal/dto/requests/auth.go @@ -63,3 +63,9 @@ type AppleSignInRequest struct { type GoogleSignInRequest struct { IDToken string `json:"id_token" validate:"required"` // Google ID token from Credential Manager } + +// DeleteAccountRequest represents the delete account request body +type DeleteAccountRequest struct { + Password *string `json:"password"` + Confirmation *string `json:"confirmation"` +} diff --git a/internal/dto/responses/auth.go b/internal/dto/responses/auth.go index 4c2fa46..2a4dae2 100644 --- a/internal/dto/responses/auth.go +++ b/internal/dto/responses/auth.go @@ -45,15 +45,16 @@ type RegisterResponse struct { // CurrentUserResponse represents the /auth/me/ response type CurrentUserResponse struct { - ID uint `json:"id"` - Username string `json:"username"` - Email string `json:"email"` - FirstName string `json:"first_name"` - LastName string `json:"last_name"` - IsActive bool `json:"is_active"` - DateJoined time.Time `json:"date_joined"` - LastLogin *time.Time `json:"last_login,omitempty"` - Profile *UserProfileResponse `json:"profile,omitempty"` + ID uint `json:"id"` + Username string `json:"username"` + Email string `json:"email"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + IsActive bool `json:"is_active"` + DateJoined time.Time `json:"date_joined"` + LastLogin *time.Time `json:"last_login,omitempty"` + Profile *UserProfileResponse `json:"profile,omitempty"` + AuthProvider string `json:"auth_provider"` } // VerifyEmailResponse represents the email verification response @@ -125,17 +126,18 @@ func NewUserProfileResponse(profile *models.UserProfile) *UserProfileResponse { } // NewCurrentUserResponse creates a CurrentUserResponse from a User model -func NewCurrentUserResponse(user *models.User) CurrentUserResponse { +func NewCurrentUserResponse(user *models.User, authProvider string) CurrentUserResponse { return CurrentUserResponse{ - ID: user.ID, - Username: user.Username, - Email: user.Email, - FirstName: user.FirstName, - LastName: user.LastName, - IsActive: user.IsActive, - DateJoined: user.DateJoined, - LastLogin: user.LastLogin, - Profile: NewUserProfileResponse(user.Profile), + ID: user.ID, + Username: user.Username, + Email: user.Email, + FirstName: user.FirstName, + LastName: user.LastName, + IsActive: user.IsActive, + DateJoined: user.DateJoined, + LastLogin: user.LastLogin, + Profile: NewUserProfileResponse(user.Profile), + AuthProvider: authProvider, } } diff --git a/internal/handlers/auth_handler.go b/internal/handlers/auth_handler.go index 2d6b367..03df15b 100644 --- a/internal/handlers/auth_handler.go +++ b/internal/handlers/auth_handler.go @@ -22,6 +22,7 @@ type AuthHandler struct { cache *services.CacheService appleAuthService *services.AppleAuthService googleAuthService *services.GoogleAuthService + storageService *services.StorageService } // NewAuthHandler creates a new auth handler @@ -43,6 +44,11 @@ func (h *AuthHandler) SetGoogleAuthService(googleAuth *services.GoogleAuthServic h.googleAuthService = googleAuth } +// SetStorageService sets the storage service for file deletion during account deletion +func (h *AuthHandler) SetStorageService(storageService *services.StorageService) { + h.storageService = storageService +} + // Login handles POST /api/auth/login/ func (h *AuthHandler) Login(c echo.Context) error { var req requests.LoginRequest @@ -406,3 +412,48 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error { return c.JSON(http.StatusOK, response) } + +// DeleteAccount handles DELETE /api/auth/account/ +func (h *AuthHandler) DeleteAccount(c echo.Context) error { + user, err := middleware.MustGetAuthUser(c) + if err != nil { + return err + } + + var req requests.DeleteAccountRequest + if err := c.Bind(&req); err != nil { + return apperrors.BadRequest("error.invalid_request") + } + + fileURLs, err := h.authService.DeleteAccount(user.ID, req.Password, req.Confirmation) + if err != nil { + log.Debug().Err(err).Uint("user_id", user.ID).Msg("Account deletion failed") + return err + } + + // Delete files from disk (best effort, don't fail the request) + if h.storageService != nil && len(fileURLs) > 0 { + go func() { + defer func() { + if r := recover(); r != nil { + log.Error().Interface("panic", r).Uint("user_id", user.ID).Msg("Panic in file cleanup goroutine") + } + }() + for _, fileURL := range fileURLs { + if err := h.storageService.Delete(fileURL); err != nil { + log.Warn().Err(err).Str("file_url", fileURL).Msg("Failed to delete file during account cleanup") + } + } + }() + } + + // Invalidate auth token from cache + token := middleware.GetAuthToken(c) + if h.cache != nil && token != "" { + if err := h.cache.InvalidateAuthToken(c.Request().Context(), token); err != nil { + log.Warn().Err(err).Msg("Failed to invalidate token in cache after account deletion") + } + } + + return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Account deleted successfully"}) +} diff --git a/internal/handlers/auth_handler_delete_test.go b/internal/handlers/auth_handler_delete_test.go new file mode 100644 index 0000000..d96bdfd --- /dev/null +++ b/internal/handlers/auth_handler_delete_test.go @@ -0,0 +1,217 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + + "github.com/treytartt/honeydue-api/internal/config" + "github.com/treytartt/honeydue-api/internal/models" + "github.com/treytartt/honeydue-api/internal/repositories" + "github.com/treytartt/honeydue-api/internal/services" + "github.com/treytartt/honeydue-api/internal/testutil" +) + +func setupDeleteAccountHandler(t *testing.T) (*AuthHandler, *echo.Echo, *gorm.DB) { + db := testutil.SetupTestDB(t) + userRepo := repositories.NewUserRepository(db) + cfg := &config.Config{ + Security: config.SecurityConfig{ + SecretKey: "test-secret-key", + PasswordResetExpiry: 15 * time.Minute, + ConfirmationExpiry: 24 * time.Hour, + MaxPasswordResetRate: 3, + }, + } + authService := services.NewAuthService(userRepo, cfg) + handler := NewAuthHandler(authService, nil, nil) + e := testutil.SetupTestRouter() + return handler, e, db +} + +func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) { + handler, e, db := setupDeleteAccountHandler(t) + + user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "password123") + + // Create profile for the user + profile := &models.UserProfile{UserID: user.ID, Verified: true} + require.NoError(t, db.Create(profile).Error) + + // Create auth token + testutil.CreateTestToken(t, db, user.ID) + + authGroup := e.Group("/api/auth") + authGroup.Use(testutil.MockAuthMiddleware(user)) + authGroup.DELETE("/account/", handler.DeleteAccount) + + t.Run("successful deletion with correct password", func(t *testing.T) { + password := "password123" + req := map[string]interface{}{ + "password": password, + } + + w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") + + testutil.AssertStatusCode(t, w, http.StatusOK) + + var response map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + assert.Contains(t, response["message"], "Account deleted successfully") + + // Verify user is actually deleted + var count int64 + db.Model(&models.User{}).Where("id = ?", user.ID).Count(&count) + assert.Equal(t, int64(0), count) + + // Verify profile is deleted + db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count) + assert.Equal(t, int64(0), count) + + // Verify auth token is deleted + db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count) + assert.Equal(t, int64(0), count) + }) +} + +func TestAuthHandler_DeleteAccount_WrongPassword(t *testing.T) { + handler, e, db := setupDeleteAccountHandler(t) + + user := testutil.CreateTestUser(t, db, "wrongpw", "wrongpw@test.com", "password123") + + authGroup := e.Group("/api/auth") + authGroup.Use(testutil.MockAuthMiddleware(user)) + authGroup.DELETE("/account/", handler.DeleteAccount) + + t.Run("wrong password returns 401", func(t *testing.T) { + wrongPw := "wrongpassword" + req := map[string]interface{}{ + "password": wrongPw, + } + + w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") + + testutil.AssertStatusCode(t, w, http.StatusUnauthorized) + }) +} + +func TestAuthHandler_DeleteAccount_MissingPassword(t *testing.T) { + handler, e, db := setupDeleteAccountHandler(t) + + user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "password123") + + authGroup := e.Group("/api/auth") + authGroup.Use(testutil.MockAuthMiddleware(user)) + authGroup.DELETE("/account/", handler.DeleteAccount) + + t.Run("missing password returns 400", func(t *testing.T) { + req := map[string]interface{}{} + + w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") + + testutil.AssertStatusCode(t, w, http.StatusBadRequest) + }) +} + +func TestAuthHandler_DeleteAccount_SocialAuthUser(t *testing.T) { + handler, e, db := setupDeleteAccountHandler(t) + + user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "randompassword") + + // Create Apple social auth record + appleAuth := &models.AppleSocialAuth{ + UserID: user.ID, + AppleID: "apple_sub_123", + Email: "apple@test.com", + } + require.NoError(t, db.Create(appleAuth).Error) + + // Create profile + profile := &models.UserProfile{UserID: user.ID, Verified: true} + require.NoError(t, db.Create(profile).Error) + + authGroup := e.Group("/api/auth") + authGroup.Use(testutil.MockAuthMiddleware(user)) + authGroup.DELETE("/account/", handler.DeleteAccount) + + t.Run("successful deletion with DELETE confirmation", func(t *testing.T) { + confirmation := "DELETE" + req := map[string]interface{}{ + "confirmation": confirmation, + } + + w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") + + testutil.AssertStatusCode(t, w, http.StatusOK) + + // Verify user is deleted + var count int64 + db.Model(&models.User{}).Where("id = ?", user.ID).Count(&count) + assert.Equal(t, int64(0), count) + + // Verify apple auth is deleted + db.Model(&models.AppleSocialAuth{}).Where("user_id = ?", user.ID).Count(&count) + assert.Equal(t, int64(0), count) + }) +} + +func TestAuthHandler_DeleteAccount_SocialAuthMissingConfirmation(t *testing.T) { + handler, e, db := setupDeleteAccountHandler(t) + + user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "randompassword") + + // Create Google social auth record + googleAuth := &models.GoogleSocialAuth{ + UserID: user.ID, + GoogleID: "google_sub_456", + Email: "google@test.com", + } + require.NoError(t, db.Create(googleAuth).Error) + + authGroup := e.Group("/api/auth") + authGroup.Use(testutil.MockAuthMiddleware(user)) + authGroup.DELETE("/account/", handler.DeleteAccount) + + t.Run("missing confirmation returns 400", func(t *testing.T) { + req := map[string]interface{}{} + + w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") + + testutil.AssertStatusCode(t, w, http.StatusBadRequest) + }) + + t.Run("wrong confirmation returns 400", func(t *testing.T) { + wrongConfirmation := "delete" + req := map[string]interface{}{ + "confirmation": wrongConfirmation, + } + + w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") + + testutil.AssertStatusCode(t, w, http.StatusBadRequest) + }) +} + +func TestAuthHandler_DeleteAccount_Unauthenticated(t *testing.T) { + handler, e, _ := setupDeleteAccountHandler(t) + + // No auth middleware - unauthenticated request + e.DELETE("/api/auth/account/", handler.DeleteAccount) + + t.Run("unauthenticated request returns 401", func(t *testing.T) { + req := map[string]interface{}{ + "password": "password123", + } + + w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "") + + testutil.AssertStatusCode(t, w, http.StatusUnauthorized) + }) +} diff --git a/internal/handlers/media_handler.go b/internal/handlers/media_handler.go index 9c44b5e..464f576 100644 --- a/internal/handlers/media_handler.go +++ b/internal/handlers/media_handler.go @@ -1,6 +1,8 @@ package handlers import ( + "net/http" + "path/filepath" "strconv" "strings" @@ -60,15 +62,18 @@ func (h *MediaHandler) ServeDocument(c echo.Context) error { return apperrors.Forbidden("error.access_denied") } - // Serve the file - filePath := h.resolveFilePath(doc.FileURL) - if filePath == "" { + // Serve the file (supports encrypted files transparently) + data, mimeType, err := h.storageSvc.ReadFile(doc.FileURL) + if err != nil { return apperrors.NotFound("error.file_not_found") } - // Set caching headers (private, 1 hour) + // Set caching and disposition headers c.Response().Header().Set("Cache-Control", "private, max-age=3600") - return c.File(filePath) + if doc.FileName != "" { + c.Response().Header().Set("Content-Disposition", "inline; filename=\""+doc.FileName+"\"") + } + return c.Blob(http.StatusOK, mimeType, data) } // ServeDocumentImage serves a document image with access control @@ -102,14 +107,15 @@ func (h *MediaHandler) ServeDocumentImage(c echo.Context) error { return apperrors.Forbidden("error.access_denied") } - // Serve the file - filePath := h.resolveFilePath(img.ImageURL) - if filePath == "" { + // Serve the file (supports encrypted files transparently) + data, mimeType, err := h.storageSvc.ReadFile(img.ImageURL) + if err != nil { return apperrors.NotFound("error.file_not_found") } c.Response().Header().Set("Cache-Control", "private, max-age=3600") - return c.File(filePath) + c.Response().Header().Set("Content-Disposition", "inline; filename=\""+filepath.Base(img.ImageURL)+"\"") + return c.Blob(http.StatusOK, mimeType, data) } // ServeCompletionImage serves a task completion image with access control @@ -149,14 +155,15 @@ func (h *MediaHandler) ServeCompletionImage(c echo.Context) error { return apperrors.Forbidden("error.access_denied") } - // Serve the file - filePath := h.resolveFilePath(img.ImageURL) - if filePath == "" { + // Serve the file (supports encrypted files transparently) + data, mimeType, err := h.storageSvc.ReadFile(img.ImageURL) + if err != nil { return apperrors.NotFound("error.file_not_found") } c.Response().Header().Set("Cache-Control", "private, max-age=3600") - return c.File(filePath) + c.Response().Header().Set("Content-Disposition", "inline; filename=\""+filepath.Base(img.ImageURL)+"\"") + return c.Blob(http.StatusOK, mimeType, data) } // resolveFilePath converts a stored URL to an actual file path. diff --git a/internal/repositories/user_repo.go b/internal/repositories/user_repo.go index 251cc17..1670005 100644 --- a/internal/repositories/user_repo.go +++ b/internal/repositories/user_repo.go @@ -34,6 +34,12 @@ func NewUserRepository(db *gorm.DB) *UserRepository { return &UserRepository{db: db} } +// DB returns the underlying *gorm.DB connection. This is useful when callers +// need to pass the connection (e.g., a transaction) to methods that accept *gorm.DB. +func (r *UserRepository) DB() *gorm.DB { + return r.db +} + // Transaction runs fn inside a database transaction. The callback receives a // new UserRepository backed by the transaction so all operations within fn // share the same transactional connection. @@ -509,6 +515,195 @@ func (r *UserRepository) FindProfilesInSharedResidences(userID uint) ([]models.U return profiles, err } +// --- Auth Provider Detection --- + +// FindAuthProvider determines the auth provider for a user. +// Returns "apple", "google", or "email". +func (r *UserRepository) FindAuthProvider(userID uint) (string, error) { + var count int64 + if err := r.db.Model(&models.AppleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil { + return "", err + } + if count > 0 { + return "apple", nil + } + + if err := r.db.Model(&models.GoogleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil { + return "", err + } + if count > 0 { + return "google", nil + } + + return "email", nil +} + +// --- Account Deletion --- + +// DeleteUserCascade deletes a user and all related records in dependency order. +// Should be called on a repository backed by a transaction (via Transaction callback). +// Returns a list of file URLs that need to be deleted from disk after the transaction commits. +func (r *UserRepository) DeleteUserCascade(userID uint) ([]string, error) { + var fileURLs []string + db := r.db + + // 1. Push notification devices + if err := db.Where("user_id = ?", userID).Delete(&models.APNSDevice{}).Error; err != nil { + return nil, err + } + if err := db.Where("user_id = ?", userID).Delete(&models.GCMDevice{}).Error; err != nil { + return nil, err + } + + // 2. Notifications + if err := db.Where("user_id = ?", userID).Delete(&models.Notification{}).Error; err != nil { + return nil, err + } + + // 3. Notification preferences + if err := db.Where("user_id = ?", userID).Delete(&models.NotificationPreference{}).Error; err != nil { + return nil, err + } + + // 4. Task reminder logs + if err := db.Where("user_id = ?", userID).Delete(&models.TaskReminderLog{}).Error; err != nil { + return nil, err + } + + // 5. Find residences owned by user + var ownedResidences []models.Residence + if err := db.Where("owner_id = ?", userID).Find(&ownedResidences).Error; err != nil { + return nil, err + } + + for _, residence := range ownedResidences { + // Collect file URLs before deleting + + // Task completion images (via completion_id -> task_id -> residence_id) + var completionImageURLs []string + db.Model(&models.TaskCompletionImage{}). + Joins("JOIN task_taskcompletion ON task_taskcompletion.id = task_taskcompletionimage.completion_id"). + Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id"). + Where("task_task.residence_id = ?", residence.ID). + Pluck("task_taskcompletionimage.image_url", &completionImageURLs) + fileURLs = append(fileURLs, completionImageURLs...) + + // Delete task completion images + db.Exec(`DELETE FROM task_taskcompletionimage WHERE completion_id IN ( + SELECT tc.id FROM task_taskcompletion tc + JOIN task_task t ON t.id = tc.task_id + WHERE t.residence_id = ? + )`, residence.ID) + + // Delete task completions + db.Exec(`DELETE FROM task_taskcompletion WHERE task_id IN ( + SELECT id FROM task_task WHERE residence_id = ? + )`, residence.ID) + + // Document images (via document_id -> residence_id) + var docImageURLs []string + db.Model(&models.DocumentImage{}). + Joins("JOIN task_document ON task_document.id = task_documentimage.document_id"). + Where("task_document.residence_id = ?", residence.ID). + Pluck("task_documentimage.image_url", &docImageURLs) + fileURLs = append(fileURLs, docImageURLs...) + + // Delete document images + db.Exec(`DELETE FROM task_documentimage WHERE document_id IN ( + SELECT id FROM task_document WHERE residence_id = ? + )`, residence.ID) + + // Document file URLs + var docFileURLs []string + db.Model(&models.Document{}).Where("residence_id = ?", residence.ID).Pluck("file_url", &docFileURLs) + fileURLs = append(fileURLs, docFileURLs...) + + // Delete documents + if err := db.Where("residence_id = ?", residence.ID).Delete(&models.Document{}).Error; err != nil { + return nil, err + } + + // Delete tasks + if err := db.Where("residence_id = ?", residence.ID).Delete(&models.Task{}).Error; err != nil { + return nil, err + } + + // Delete contractor specialties (many-to-many join table) + db.Exec(`DELETE FROM task_contractor_specialties WHERE contractor_id IN ( + SELECT id FROM task_contractor WHERE residence_id = ? + )`, residence.ID) + + // Delete contractors + if err := db.Where("residence_id = ?", residence.ID).Delete(&models.Contractor{}).Error; err != nil { + return nil, err + } + + // Delete share codes + if err := db.Where("residence_id = ?", residence.ID).Delete(&models.ResidenceShareCode{}).Error; err != nil { + return nil, err + } + + // Remove residence membership records (many-to-many join table) + db.Exec("DELETE FROM residence_residence_users WHERE residence_id = ?", residence.ID) + + // Delete the residence itself + if err := db.Delete(&residence).Error; err != nil { + return nil, err + } + } + + // 6. Remove user from shared residences they don't own (membership only) + db.Exec("DELETE FROM residence_residence_users WHERE user_id = ?", userID) + + // 7. Subscription + if err := db.Where("user_id = ?", userID).Delete(&models.UserSubscription{}).Error; err != nil { + return nil, err + } + + // 8. Social auth records + if err := db.Where("user_id = ?", userID).Delete(&models.AppleSocialAuth{}).Error; err != nil { + return nil, err + } + if err := db.Where("user_id = ?", userID).Delete(&models.GoogleSocialAuth{}).Error; err != nil { + return nil, err + } + + // 9. Confirmation codes + if err := db.Where("user_id = ?", userID).Delete(&models.ConfirmationCode{}).Error; err != nil { + return nil, err + } + + // 10. Password reset codes + if err := db.Where("user_id = ?", userID).Delete(&models.PasswordResetCode{}).Error; err != nil { + return nil, err + } + + // 11. Auth tokens + if err := db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil { + return nil, err + } + + // 12. User profile + if err := db.Where("user_id = ?", userID).Delete(&models.UserProfile{}).Error; err != nil { + return nil, err + } + + // 13. User + if err := db.Where("id = ?", userID).Delete(&models.User{}).Error; err != nil { + return nil, err + } + + // Filter out empty URLs + var cleanURLs []string + for _, url := range fileURLs { + if url != "" { + cleanURLs = append(cleanURLs, url) + } + } + + return cleanURLs, nil +} + // --- Apple Social Auth Methods --- // FindByAppleID finds an Apple social auth by Apple ID diff --git a/internal/repositories/user_repo_test.go b/internal/repositories/user_repo_test.go index a77fe57..0f0fa6c 100644 --- a/internal/repositories/user_repo_test.go +++ b/internal/repositories/user_repo_test.go @@ -187,3 +187,169 @@ func TestUserRepository_GetOrCreateProfile(t *testing.T) { require.NoError(t, err) assert.Equal(t, profile1.ID, profile2.ID) } + +func TestUserRepository_FindAuthProvider(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewUserRepository(db) + + t.Run("email user", func(t *testing.T) { + user := testutil.CreateTestUser(t, db, "emailuser", "email@test.com", "password123") + provider, err := repo.FindAuthProvider(user.ID) + require.NoError(t, err) + assert.Equal(t, "email", provider) + }) + + t.Run("apple user", func(t *testing.T) { + user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "password123") + appleAuth := &models.AppleSocialAuth{ + UserID: user.ID, + AppleID: "apple_sub_test", + Email: "apple@test.com", + } + require.NoError(t, db.Create(appleAuth).Error) + + provider, err := repo.FindAuthProvider(user.ID) + require.NoError(t, err) + assert.Equal(t, "apple", provider) + }) + + t.Run("google user", func(t *testing.T) { + user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "password123") + googleAuth := &models.GoogleSocialAuth{ + UserID: user.ID, + GoogleID: "google_sub_test", + Email: "google@test.com", + } + require.NoError(t, db.Create(googleAuth).Error) + + provider, err := repo.FindAuthProvider(user.ID) + require.NoError(t, err) + assert.Equal(t, "google", provider) + }) +} + +func TestUserRepository_DeleteUserCascade(t *testing.T) { + t.Run("deletes user with no residences", func(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewUserRepository(db) + + user := testutil.CreateTestUser(t, db, "deletebare", "deletebare@test.com", "password123") + + // Create profile and token + profile := &models.UserProfile{UserID: user.ID, Verified: true} + require.NoError(t, db.Create(profile).Error) + _, err := models.GetOrCreateToken(db, user.ID) + require.NoError(t, err) + + var fileURLs []string + txErr := repo.Transaction(func(txRepo *UserRepository) error { + urls, err := txRepo.DeleteUserCascade(user.ID) + if err != nil { + return err + } + fileURLs = urls + return nil + }) + require.NoError(t, txErr) + assert.Empty(t, fileURLs) + + // Verify user is gone + var count int64 + db.Model(&models.User{}).Where("id = ?", user.ID).Count(&count) + assert.Equal(t, int64(0), count) + + // Verify profile is gone + db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count) + assert.Equal(t, int64(0), count) + + // Verify token is gone + db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count) + assert.Equal(t, int64(0), count) + }) + + t.Run("returns file URLs for cleanup", func(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewUserRepository(db) + + user := testutil.CreateTestUser(t, db, "deletefiles", "deletefiles@test.com", "password123") + residence := testutil.CreateTestResidence(t, db, user.ID, "Test Home") + + // Create document with file + doc := &models.Document{ + ResidenceID: residence.ID, + CreatedByID: user.ID, + Title: "Test Doc", + FileURL: "/uploads/documents/test.pdf", + } + require.NoError(t, db.Create(doc).Error) + + // Create document image + docImage := &models.DocumentImage{ + DocumentID: doc.ID, + ImageURL: "/uploads/images/docimg.jpg", + } + require.NoError(t, db.Create(docImage).Error) + + var fileURLs []string + txErr := repo.Transaction(func(txRepo *UserRepository) error { + urls, err := txRepo.DeleteUserCascade(user.ID) + if err != nil { + return err + } + fileURLs = urls + return nil + }) + require.NoError(t, txErr) + + // Should return the file URLs + assert.Contains(t, fileURLs, "/uploads/documents/test.pdf") + assert.Contains(t, fileURLs, "/uploads/images/docimg.jpg") + + // Verify everything deleted + var count int64 + db.Model(&models.User{}).Where("id = ?", user.ID).Count(&count) + assert.Equal(t, int64(0), count) + db.Model(&models.Residence{}).Where("id = ?", residence.ID).Count(&count) + assert.Equal(t, int64(0), count) + db.Model(&models.Document{}).Where("id = ?", doc.ID).Count(&count) + assert.Equal(t, int64(0), count) + }) + + t.Run("handles user with owned and shared residences", func(t *testing.T) { + db := testutil.SetupTestDB(t) + repo := NewUserRepository(db) + + owner := testutil.CreateTestUser(t, db, "deleteowner", "deleteowner@test.com", "password123") + otherUser := testutil.CreateTestUser(t, db, "otheruser", "other@test.com", "password123") + otherResidence := testutil.CreateTestResidence(t, db, otherUser.ID, "Other Home") + + // Owner's residence + ownedResidence := testutil.CreateTestResidence(t, db, owner.ID, "Owner Home") + + // Add owner as member of other user's residence + db.Exec("INSERT INTO residence_residence_users (residence_id, user_id) VALUES (?, ?)", otherResidence.ID, owner.ID) + + txErr := repo.Transaction(func(txRepo *UserRepository) error { + _, err := txRepo.DeleteUserCascade(owner.ID) + return err + }) + require.NoError(t, txErr) + + // Owner's residence should be deleted + var count int64 + db.Model(&models.Residence{}).Where("id = ?", ownedResidence.ID).Count(&count) + assert.Equal(t, int64(0), count) + + // Other user's residence should still exist + db.Model(&models.Residence{}).Where("id = ?", otherResidence.ID).Count(&count) + assert.Equal(t, int64(1), count) + + // Owner should no longer be a member of other's residence + db.Raw("SELECT COUNT(*) FROM residence_residence_users WHERE user_id = ?", owner.ID).Count(&count) + assert.Equal(t, int64(0), count) + + // Other user should still exist + db.Model(&models.User{}).Where("id = ?", otherUser.ID).Count(&count) + assert.Equal(t, int64(1), count) + }) +} diff --git a/internal/router/router.go b/internal/router/router.go index b595e19..32399fb 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -182,6 +182,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo { authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache) authHandler.SetAppleAuthService(appleAuthService) authHandler.SetGoogleAuthService(googleAuthService) + authHandler.SetStorageService(deps.StorageService) userHandler := handlers.NewUserHandler(userService) residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService, cfg.Features.PDFReportsEnabled) taskHandler := handlers.NewTaskHandler(taskService, deps.StorageService) @@ -347,6 +348,7 @@ func setupProtectedAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler auth.POST("/verify/", authHandler.VerifyEmail) // Alias for mobile app compatibility auth.POST("/verify-email/", authHandler.VerifyEmail) // Original route auth.POST("/resend-verification/", authHandler.ResendVerification) + auth.DELETE("/account/", authHandler.DeleteAccount) } } diff --git a/internal/services/auth_service.go b/internal/services/auth_service.go index bda3d8f..edd2f2d 100644 --- a/internal/services/auth_service.go +++ b/internal/services/auth_service.go @@ -200,10 +200,69 @@ func (s *AuthService) GetCurrentUser(userID uint) (*responses.CurrentUserRespons return nil, err } - response := responses.NewCurrentUserResponse(user) + authProvider, err := s.userRepo.FindAuthProvider(userID) + if err != nil { + // Log but don't fail - default to "email" + log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider") + authProvider = "email" + } + + response := responses.NewCurrentUserResponse(user, authProvider) return &response, nil } +// DeleteAccount deletes a user's account and all associated data. +// For email auth users, password verification is required. +// For social auth users, confirmation string "DELETE" is required. +// Returns a list of file URLs that need to be deleted from disk. +func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string) ([]string, error) { + // Fetch user + user, err := s.userRepo.FindByID(userID) + if err != nil { + if errors.Is(err, repositories.ErrUserNotFound) { + return nil, apperrors.NotFound("error.user_not_found") + } + return nil, apperrors.Internal(err) + } + + // Determine auth provider + authProvider, err := s.userRepo.FindAuthProvider(userID) + if err != nil { + return nil, apperrors.Internal(err) + } + + // Validate credentials based on auth provider + if authProvider == "email" { + if password == nil || *password == "" { + return nil, apperrors.BadRequest("error.password_required") + } + if !user.CheckPassword(*password) { + return nil, apperrors.Unauthorized("error.invalid_credentials") + } + } else { + // Social auth (apple or google) - require confirmation + if confirmation == nil || *confirmation != "DELETE" { + return nil, apperrors.BadRequest("error.confirmation_required") + } + } + + // Start transaction and cascade delete + var fileURLs []string + txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error { + urls, err := txRepo.DeleteUserCascade(userID) + if err != nil { + return err + } + fileURLs = urls + return nil + }) + if txErr != nil { + return nil, apperrors.Internal(txErr) + } + + return fileURLs, nil +} + // UpdateProfile updates a user's profile func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) { user, err := s.userRepo.FindByID(userID) @@ -240,7 +299,13 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ return nil, err } - response := responses.NewCurrentUserResponse(user) + authProvider, err := s.userRepo.FindAuthProvider(userID) + if err != nil { + log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider") + authProvider = "email" + } + + response := responses.NewCurrentUserResponse(user, authProvider) return &response, nil } diff --git a/internal/services/encryption_service.go b/internal/services/encryption_service.go new file mode 100644 index 0000000..105fc36 --- /dev/null +++ b/internal/services/encryption_service.go @@ -0,0 +1,179 @@ +package services + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/hex" + "fmt" + "io" +) + +const ( + // encryptionVersion is the current file format version byte. + encryptionVersion byte = 0x01 + + // aes256KeyLen is the required key length in bytes for AES-256. + aes256KeyLen = 32 + + // gcmNonceSize is the standard GCM nonce size (12 bytes). + gcmNonceSize = 12 + + // gcmTagSize is the standard GCM authentication tag size (16 bytes). + gcmTagSize = 16 + + // encryptedDEKSize is the encrypted DEK length: 32-byte DEK + 16-byte GCM tag. + encryptedDEKSize = aes256KeyLen + gcmTagSize + + // headerSize is the fixed header: version(1) + KEK nonce(12) + encrypted DEK(48) + DEK nonce(12). + headerSize = 1 + gcmNonceSize + encryptedDEKSize + gcmNonceSize +) + +// EncryptionService provides AES-256-GCM envelope encryption for files at rest. +type EncryptionService struct { + kek []byte // Key Encryption Key (32 bytes) +} + +// NewEncryptionService creates an EncryptionService from a 64-character hex-encoded KEK. +func NewEncryptionService(hexKey string) (*EncryptionService, error) { + if len(hexKey) != 64 { + return nil, fmt.Errorf("encryption key must be exactly 64 hex characters (got %d)", len(hexKey)) + } + + kek, err := hex.DecodeString(hexKey) + if err != nil { + return nil, fmt.Errorf("invalid hex in encryption key: %w", err) + } + + if len(kek) != aes256KeyLen { + return nil, fmt.Errorf("decoded key must be %d bytes", aes256KeyLen) + } + + return &EncryptionService{kek: kek}, nil +} + +// IsEnabled returns true if the encryption service is configured and ready. +func (s *EncryptionService) IsEnabled() bool { + return s != nil && len(s.kek) == aes256KeyLen +} + +// Encrypt encrypts plaintext using envelope encryption (random DEK encrypted with KEK). +// +// File format: +// +// [1-byte version 0x01] +// [12-byte KEK nonce] +// [48-byte encrypted DEK (32-byte DEK + 16-byte GCM tag)] +// [12-byte DEK nonce] +// [ciphertext + 16-byte GCM tag] +func (s *EncryptionService) Encrypt(plaintext []byte) ([]byte, error) { + // Generate a random Data Encryption Key (DEK) + dek := make([]byte, aes256KeyLen) + if _, err := io.ReadFull(rand.Reader, dek); err != nil { + return nil, fmt.Errorf("failed to generate DEK: %w", err) + } + + // Encrypt the DEK with the KEK + kekBlock, err := aes.NewCipher(s.kek) + if err != nil { + return nil, fmt.Errorf("failed to create KEK cipher: %w", err) + } + kekGCM, err := cipher.NewGCM(kekBlock) + if err != nil { + return nil, fmt.Errorf("failed to create KEK GCM: %w", err) + } + + kekNonce := make([]byte, gcmNonceSize) + if _, err := io.ReadFull(rand.Reader, kekNonce); err != nil { + return nil, fmt.Errorf("failed to generate KEK nonce: %w", err) + } + encryptedDEK := kekGCM.Seal(nil, kekNonce, dek, nil) + + // Encrypt the plaintext with the DEK + dekBlock, err := aes.NewCipher(dek) + if err != nil { + return nil, fmt.Errorf("failed to create DEK cipher: %w", err) + } + dekGCM, err := cipher.NewGCM(dekBlock) + if err != nil { + return nil, fmt.Errorf("failed to create DEK GCM: %w", err) + } + + dekNonce := make([]byte, gcmNonceSize) + if _, err := io.ReadFull(rand.Reader, dekNonce); err != nil { + return nil, fmt.Errorf("failed to generate DEK nonce: %w", err) + } + ciphertext := dekGCM.Seal(nil, dekNonce, plaintext, nil) + + // Pack the output: version + kekNonce + encryptedDEK + dekNonce + ciphertext + out := make([]byte, 0, headerSize+len(ciphertext)) + out = append(out, encryptionVersion) + out = append(out, kekNonce...) + out = append(out, encryptedDEK...) + out = append(out, dekNonce...) + out = append(out, ciphertext...) + + return out, nil +} + +// Decrypt reverses the Encrypt operation, recovering the original plaintext. +func (s *EncryptionService) Decrypt(blob []byte) ([]byte, error) { + if len(blob) < headerSize { + return nil, fmt.Errorf("ciphertext too short (%d bytes, minimum %d)", len(blob), headerSize) + } + + // Parse version + version := blob[0] + if version != encryptionVersion { + return nil, fmt.Errorf("unsupported encryption version: 0x%02x", version) + } + + offset := 1 + + // Parse KEK nonce + kekNonce := blob[offset : offset+gcmNonceSize] + offset += gcmNonceSize + + // Parse encrypted DEK + encryptedDEK := blob[offset : offset+encryptedDEKSize] + offset += encryptedDEKSize + + // Parse DEK nonce + dekNonce := blob[offset : offset+gcmNonceSize] + offset += gcmNonceSize + + // Remaining bytes are the ciphertext + GCM tag + ciphertext := blob[offset:] + + // Decrypt the DEK with the KEK + kekBlock, err := aes.NewCipher(s.kek) + if err != nil { + return nil, fmt.Errorf("failed to create KEK cipher: %w", err) + } + kekGCM, err := cipher.NewGCM(kekBlock) + if err != nil { + return nil, fmt.Errorf("failed to create KEK GCM: %w", err) + } + + dek, err := kekGCM.Open(nil, kekNonce, encryptedDEK, nil) + if err != nil { + return nil, fmt.Errorf("failed to decrypt DEK (wrong key?): %w", err) + } + + // Decrypt the plaintext with the DEK + dekBlock, err := aes.NewCipher(dek) + if err != nil { + return nil, fmt.Errorf("failed to create DEK cipher: %w", err) + } + dekGCM, err := cipher.NewGCM(dekBlock) + if err != nil { + return nil, fmt.Errorf("failed to create DEK GCM: %w", err) + } + + plaintext, err := dekGCM.Open(nil, dekNonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("failed to decrypt data (tampered?): %w", err) + } + + return plaintext, nil +} diff --git a/internal/services/encryption_service_test.go b/internal/services/encryption_service_test.go new file mode 100644 index 0000000..3dbb4d3 --- /dev/null +++ b/internal/services/encryption_service_test.go @@ -0,0 +1,218 @@ +package services + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "testing" +) + +// validTestKey returns a deterministic 64-char hex key for tests. +func validTestKey() string { + return "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" +} + +// randomHexKey generates a random 64-char hex key. +func randomHexKey(t *testing.T) string { + t.Helper() + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + t.Fatal(err) + } + return hex.EncodeToString(b) +} + +func TestNewEncryptionService_Valid(t *testing.T) { + svc, err := NewEncryptionService(validTestKey()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if svc == nil { + t.Fatal("expected non-nil service") + } +} + +func TestNewEncryptionService_InvalidHex(t *testing.T) { + // 64 chars but not valid hex + _, err := NewEncryptionService("zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz") + if err == nil { + t.Fatal("expected error for invalid hex") + } +} + +func TestNewEncryptionService_WrongLength(t *testing.T) { + tests := []struct { + name string + key string + }{ + {"too short", "0123456789abcdef"}, + {"too long", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef00"}, + {"empty", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewEncryptionService(tt.key) + if err == nil { + t.Fatal("expected error for wrong length key") + } + }) + } +} + +func TestEncryptDecrypt_RoundTrip(t *testing.T) { + svc, err := NewEncryptionService(validTestKey()) + if err != nil { + t.Fatal(err) + } + + plaintext := []byte("Hello, encryption at rest!") + + ciphertext, err := svc.Encrypt(plaintext) + if err != nil { + t.Fatalf("encrypt failed: %v", err) + } + + decrypted, err := svc.Decrypt(ciphertext) + if err != nil { + t.Fatalf("decrypt failed: %v", err) + } + + if !bytes.Equal(plaintext, decrypted) { + t.Fatalf("round-trip mismatch: got %q, want %q", decrypted, plaintext) + } +} + +func TestEncryptDecrypt_EmptyPlaintext(t *testing.T) { + svc, err := NewEncryptionService(validTestKey()) + if err != nil { + t.Fatal(err) + } + + ciphertext, err := svc.Encrypt([]byte{}) + if err != nil { + t.Fatalf("encrypt failed: %v", err) + } + + decrypted, err := svc.Decrypt(ciphertext) + if err != nil { + t.Fatalf("decrypt failed: %v", err) + } + + if len(decrypted) != 0 { + t.Fatalf("expected empty plaintext, got %d bytes", len(decrypted)) + } +} + +func TestEncrypt_DifferentCiphertexts(t *testing.T) { + svc, err := NewEncryptionService(validTestKey()) + if err != nil { + t.Fatal(err) + } + + plaintext := []byte("same input") + ct1, err := svc.Encrypt(plaintext) + if err != nil { + t.Fatal(err) + } + ct2, err := svc.Encrypt(plaintext) + if err != nil { + t.Fatal(err) + } + + if bytes.Equal(ct1, ct2) { + t.Fatal("encrypting the same plaintext twice should produce different ciphertexts") + } +} + +func TestDecrypt_TamperDetection(t *testing.T) { + svc, err := NewEncryptionService(validTestKey()) + if err != nil { + t.Fatal(err) + } + + ciphertext, err := svc.Encrypt([]byte("sensitive data")) + if err != nil { + t.Fatal(err) + } + + // Flip a byte near the end (in the ciphertext portion) + tampered := make([]byte, len(ciphertext)) + copy(tampered, ciphertext) + tampered[len(tampered)-1] ^= 0xFF + + _, err = svc.Decrypt(tampered) + if err == nil { + t.Fatal("expected error when decrypting tampered ciphertext") + } +} + +func TestDecrypt_WrongKEK(t *testing.T) { + svc1, err := NewEncryptionService(validTestKey()) + if err != nil { + t.Fatal(err) + } + + ciphertext, err := svc1.Encrypt([]byte("secret")) + if err != nil { + t.Fatal(err) + } + + // Create a second service with a different key + svc2, err := NewEncryptionService(randomHexKey(t)) + if err != nil { + t.Fatal(err) + } + + _, err = svc2.Decrypt(ciphertext) + if err == nil { + t.Fatal("expected error when decrypting with wrong KEK") + } +} + +func TestDecrypt_TooShort(t *testing.T) { + svc, err := NewEncryptionService(validTestKey()) + if err != nil { + t.Fatal(err) + } + + _, err = svc.Decrypt([]byte("short")) + if err == nil { + t.Fatal("expected error for too-short ciphertext") + } +} + +func TestDecrypt_BadVersion(t *testing.T) { + svc, err := NewEncryptionService(validTestKey()) + if err != nil { + t.Fatal(err) + } + + ciphertext, err := svc.Encrypt([]byte("data")) + if err != nil { + t.Fatal(err) + } + + // Change version byte + ciphertext[0] = 0xFF + + _, err = svc.Decrypt(ciphertext) + if err == nil { + t.Fatal("expected error for bad version") + } +} + +func TestIsEnabled(t *testing.T) { + svc, err := NewEncryptionService(validTestKey()) + if err != nil { + t.Fatal(err) + } + + if !svc.IsEnabled() { + t.Fatal("expected IsEnabled() to return true") + } + + var nilSvc *EncryptionService + if nilSvc.IsEnabled() { + t.Fatal("expected IsEnabled() to return false for nil service") + } +} diff --git a/internal/services/storage_service.go b/internal/services/storage_service.go index 2d370d7..971acba 100644 --- a/internal/services/storage_service.go +++ b/internal/services/storage_service.go @@ -18,8 +18,9 @@ import ( // StorageService handles file uploads to local filesystem type StorageService struct { - cfg *config.StorageConfig - allowedTypes map[string]struct{} // P-12: Parsed once at init for O(1) lookups + cfg *config.StorageConfig + allowedTypes map[string]struct{} // P-12: Parsed once at init for O(1) lookups + encryptionSvc *EncryptionService } // UploadResult contains information about an uploaded file @@ -124,28 +125,39 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U subdir = "completions" } + // If encryption is enabled, append .enc suffix to the stored filename + storedFilename := newFilename + if s.encryptionSvc.IsEnabled() { + storedFilename = newFilename + ".enc" + } + // S-18: Sanitize path to prevent traversal attacks - destPath, err := SafeResolvePath(s.cfg.UploadDir, filepath.Join(subdir, newFilename)) + destPath, err := SafeResolvePath(s.cfg.UploadDir, filepath.Join(subdir, storedFilename)) if err != nil { return nil, fmt.Errorf("invalid upload path: %w", err) } - // Create destination file - dst, err := os.Create(destPath) + // Read all file content into memory for potential encryption + fileData, err := io.ReadAll(src) if err != nil { - return nil, fmt.Errorf("failed to create destination file: %w", err) + return nil, fmt.Errorf("failed to read file content: %w", err) } - defer dst.Close() - // Copy file content - written, err := io.Copy(dst, src) - if err != nil { - // Clean up on error - os.Remove(destPath) + // Encrypt if encryption is enabled + if s.encryptionSvc.IsEnabled() { + fileData, err = s.encryptionSvc.Encrypt(fileData) + if err != nil { + return nil, fmt.Errorf("failed to encrypt file: %w", err) + } + } + + // Write file content to disk + if err := os.WriteFile(destPath, fileData, 0644); err != nil { return nil, fmt.Errorf("failed to save file: %w", err) } + written := int64(len(fileData)) - // Generate URL + // Generate URL (always uses the original filename without .enc suffix for the public URL) url := fmt.Sprintf("%s/%s/%s", s.cfg.BaseURL, subdir, newFilename) log.Info(). @@ -163,7 +175,61 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U }, nil } -// Delete removes a file from storage +// ReadFile reads and optionally decrypts a stored file. It returns the plaintext +// bytes and the detected MIME type. If the file is stored with an .enc suffix, +// it is decrypted automatically. +func (s *StorageService) ReadFile(storedURL string) ([]byte, string, error) { + if storedURL == "" { + return nil, "", fmt.Errorf("empty file URL") + } + + // Strip base URL prefix to get relative path + relativePath := strings.TrimPrefix(storedURL, s.cfg.BaseURL) + relativePath = strings.TrimPrefix(relativePath, "/") + + // Try .enc variant first, then plain file + var fullPath string + var encrypted bool + + encPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath+".enc") + if err == nil { + if _, statErr := os.Stat(encPath); statErr == nil { + fullPath = encPath + encrypted = true + } + } + + if fullPath == "" { + plainPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath) + if err != nil { + return nil, "", fmt.Errorf("invalid file path: %w", err) + } + fullPath = plainPath + } + + data, err := os.ReadFile(fullPath) + if err != nil { + return nil, "", fmt.Errorf("failed to read file: %w", err) + } + + // Decrypt if this is an encrypted file + if encrypted { + if s.encryptionSvc == nil || !s.encryptionSvc.IsEnabled() { + return nil, "", fmt.Errorf("encrypted file found but encryption service is not configured") + } + data, err = s.encryptionSvc.Decrypt(data) + if err != nil { + return nil, "", fmt.Errorf("failed to decrypt file: %w", err) + } + } + + // Detect MIME type from decrypted content + mimeType := http.DetectContentType(data) + + return data, mimeType, nil +} + +// Delete removes a file from storage, handling both plain and .enc variants func (s *StorageService) Delete(fileURL string) error { // Convert URL to file path relativePath := strings.TrimPrefix(fileURL, s.cfg.BaseURL) @@ -175,14 +241,35 @@ func (s *StorageService) Delete(fileURL string) error { return fmt.Errorf("invalid file path: %w", err) } + // Try to delete the plain file + plainDeleted := false if err := os.Remove(fullPath); err != nil { - if os.IsNotExist(err) { - return nil // File already doesn't exist + if !os.IsNotExist(err) { + return fmt.Errorf("failed to delete file: %w", err) } - return fmt.Errorf("failed to delete file: %w", err) + } else { + plainDeleted = true + log.Info().Str("path", fullPath).Msg("File deleted") + } + + // Also try to delete the .enc variant + encPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath+".enc") + if err == nil { + if err := os.Remove(encPath); err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("failed to delete encrypted file: %w", err) + } + } else { + log.Info().Str("path", encPath).Msg("Encrypted file deleted") + return nil + } + } + + if !plainDeleted { + // Neither file existed — that's OK + return nil } - log.Info().Str("path", fullPath).Msg("File deleted") return nil } @@ -225,6 +312,11 @@ func (s *StorageService) GetUploadDir() string { return s.cfg.UploadDir } +// SetEncryptionService sets the encryption service for encrypting files at rest +func (s *StorageService) SetEncryptionService(svc *EncryptionService) { + s.encryptionSvc = svc +} + // NewStorageServiceForTest creates a StorageService without creating directories. // This is intended only for unit tests that need a StorageService with a known config. func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService { diff --git a/internal/services/storage_service_test.go b/internal/services/storage_service_test.go new file mode 100644 index 0000000..1c8fa03 --- /dev/null +++ b/internal/services/storage_service_test.go @@ -0,0 +1,164 @@ +package services + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/treytartt/honeydue-api/internal/config" +) + +func setupTestStorage(t *testing.T, encrypt bool) (*StorageService, string) { + t.Helper() + + tmpDir := t.TempDir() + + cfg := &config.StorageConfig{ + UploadDir: tmpDir, + BaseURL: "/uploads", + MaxFileSize: 10 * 1024 * 1024, + AllowedTypes: "image/jpeg,image/png,application/pdf", + } + + svc, err := NewStorageService(cfg) + if err != nil { + t.Fatalf("failed to create storage service: %v", err) + } + + if encrypt { + encSvc, err := NewEncryptionService(validTestKey()) + if err != nil { + t.Fatalf("failed to create encryption service: %v", err) + } + svc.SetEncryptionService(encSvc) + } + + return svc, tmpDir +} + +func TestReadFile_PlainFile(t *testing.T) { + svc, tmpDir := setupTestStorage(t, false) + + // Write a plain file + content := []byte("hello world") + dir := filepath.Join(tmpDir, "images") + if err := os.WriteFile(filepath.Join(dir, "test.jpg"), content, 0644); err != nil { + t.Fatal(err) + } + + data, mimeType, err := svc.ReadFile("/uploads/images/test.jpg") + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + + if !bytes.Equal(data, content) { + t.Fatalf("data mismatch: got %q, want %q", data, content) + } + + if mimeType == "" { + t.Fatal("expected non-empty MIME type") + } +} + +func TestReadFile_EncryptedFile(t *testing.T) { + svc, tmpDir := setupTestStorage(t, true) + + // Encrypt and write a file manually (simulating what Upload does) + originalContent := []byte("sensitive document content here - must be long enough for detection") + encSvc, _ := NewEncryptionService(validTestKey()) + encrypted, err := encSvc.Encrypt(originalContent) + if err != nil { + t.Fatal(err) + } + + dir := filepath.Join(tmpDir, "documents") + if err := os.WriteFile(filepath.Join(dir, "test.pdf.enc"), encrypted, 0644); err != nil { + t.Fatal(err) + } + + // ReadFile should find the .enc file and decrypt it + data, _, err := svc.ReadFile("/uploads/documents/test.pdf") + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + + if !bytes.Equal(data, originalContent) { + t.Fatalf("decrypted data mismatch: got %q, want %q", data, originalContent) + } +} + +func TestReadFile_EncFilePreferredOverPlain(t *testing.T) { + svc, tmpDir := setupTestStorage(t, true) + + encSvc, _ := NewEncryptionService(validTestKey()) + + plainContent := []byte("plain version") + encContent := []byte("encrypted version - the correct one") + encrypted, _ := encSvc.Encrypt(encContent) + + dir := filepath.Join(tmpDir, "images") + os.WriteFile(filepath.Join(dir, "photo.jpg"), plainContent, 0644) + os.WriteFile(filepath.Join(dir, "photo.jpg.enc"), encrypted, 0644) + + // Should prefer the .enc file + data, _, err := svc.ReadFile("/uploads/images/photo.jpg") + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + + if !bytes.Equal(data, encContent) { + t.Fatalf("expected encrypted version content, got %q", data) + } +} + +func TestReadFile_EmptyURL(t *testing.T) { + svc, _ := setupTestStorage(t, false) + + _, _, err := svc.ReadFile("") + if err == nil { + t.Fatal("expected error for empty URL") + } +} + +func TestReadFile_MissingFile(t *testing.T) { + svc, _ := setupTestStorage(t, false) + + _, _, err := svc.ReadFile("/uploads/images/nonexistent.jpg") + if err == nil { + t.Fatal("expected error for missing file") + } +} + +func TestDelete_HandlesEncAndPlain(t *testing.T) { + svc, tmpDir := setupTestStorage(t, false) + + dir := filepath.Join(tmpDir, "images") + + // Create both plain and .enc files + os.WriteFile(filepath.Join(dir, "photo.jpg"), []byte("plain"), 0644) + os.WriteFile(filepath.Join(dir, "photo.jpg.enc"), []byte("encrypted"), 0644) + + err := svc.Delete("/uploads/images/photo.jpg") + if err != nil { + t.Fatalf("Delete failed: %v", err) + } + + // Both should be gone + if _, err := os.Stat(filepath.Join(dir, "photo.jpg")); !os.IsNotExist(err) { + t.Fatal("plain file should be deleted") + } + if _, err := os.Stat(filepath.Join(dir, "photo.jpg.enc")); !os.IsNotExist(err) { + t.Fatal("encrypted file should be deleted") + } +} + +func TestDelete_NonexistentFile(t *testing.T) { + svc, _ := setupTestStorage(t, false) + + // Should not error for non-existent files + err := svc.Delete("/uploads/images/nope.jpg") + if err != nil { + t.Fatalf("Delete should not error for non-existent file: %v", err) + } +} diff --git a/internal/services/task_service.go b/internal/services/task_service.go index ac44c31..d8a6cff 100644 --- a/internal/services/task_service.go +++ b/internal/services/task_service.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "os" "path/filepath" "strings" "time" @@ -848,39 +847,33 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio } } -// loadCompletionImagesForEmail reads completion images from disk and prepares them for email embedding +// loadCompletionImagesForEmail reads completion images from disk and prepares them for email embedding. +// Uses StorageService.ReadFile to transparently handle encrypted files. func (s *TaskService) loadCompletionImagesForEmail(images []models.TaskCompletionImage) []EmbeddedImage { var emailImages []EmbeddedImage - uploadDir := s.storageService.GetUploadDir() - for i, img := range images { - // Resolve file path from stored URL - filePath := s.resolveImageFilePath(img.ImageURL, uploadDir) - if filePath == "" { - log.Warn().Str("image_url", img.ImageURL).Msg("Could not resolve image file path") - continue - } - - // Read file from disk - data, err := os.ReadFile(filePath) + // Read file via storage service (handles encryption transparently) + data, mimeType, err := s.storageService.ReadFile(img.ImageURL) if err != nil { - log.Warn().Err(err).Str("path", filePath).Msg("Failed to read completion image for email") + log.Warn().Err(err).Str("image_url", img.ImageURL).Msg("Failed to read completion image for email") continue } - // Determine content type from extension - contentType := s.getContentTypeFromPath(filePath) + // Use detected MIME type, fall back to extension-based detection + if mimeType == "application/octet-stream" { + mimeType = s.getContentTypeFromPath(img.ImageURL) + } // Create embedded image with unique Content-ID emailImages = append(emailImages, EmbeddedImage{ ContentID: fmt.Sprintf("completion-image-%d", i+1), - Filename: filepath.Base(filePath), - ContentType: contentType, + Filename: filepath.Base(img.ImageURL), + ContentType: mimeType, Data: data, }) - log.Debug().Str("path", filePath).Int("size", len(data)).Msg("Loaded completion image for email") + log.Debug().Str("image_url", img.ImageURL).Int("size", len(data)).Msg("Loaded completion image for email") } return emailImages diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index a9165cb..eb9e6f6 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -60,6 +60,9 @@ func SetupTestDB(t *testing.T) *gorm.DB { &models.NotificationPreference{}, &models.APNSDevice{}, &models.GCMDevice{}, + &models.AppleSocialAuth{}, + &models.GoogleSocialAuth{}, + &models.TaskReminderLog{}, &models.UserSubscription{}, &models.SubscriptionSettings{}, &models.TierLimits{},