Add delete account endpoint and file encryption at rest
Delete Account (Plan #2): - DELETE /api/auth/account/ with password or "DELETE" confirmation - Cascade delete across 15+ tables in correct FK order - Auth provider detection (email/apple/google) for /auth/me/ - File cleanup after account deletion - Handler + repository tests (12 tests) Encryption at Rest (Plan #3): - AES-256-GCM envelope encryption (per-file DEK wrapped by KEK) - Encrypt on upload, auto-decrypt on serve via StorageService.ReadFile() - MediaHandler serves decrypted files via c.Blob() - TaskService email image loading uses ReadFile() - cmd/migrate-encrypt CLI tool with --dry-run for existing files - Encryption service + storage service tests (18 tests)
This commit is contained in:
@@ -50,6 +50,8 @@ STORAGE_UPLOAD_DIR=./uploads
|
||||
STORAGE_BASE_URL=/uploads
|
||||
STORAGE_MAX_FILE_SIZE=10485760
|
||||
STORAGE_ALLOWED_TYPES=image/jpeg,image/png,image/gif,image/webp,application/pdf
|
||||
# 64-char hex key for file encryption at rest. Generate with: openssl rand -hex 32
|
||||
STORAGE_ENCRYPTION_KEY=
|
||||
|
||||
# Feature Flags (Kill Switches)
|
||||
# Set to false to disable. All default to true (enabled).
|
||||
|
||||
17
Makefile
17
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: build run test contract-test clean deps lint docker-build docker-up docker-down migrate
|
||||
.PHONY: build run test contract-test clean deps lint docker-build docker-up docker-down migrate migrate-encrypt migrate-encrypt-dry
|
||||
|
||||
# Binary names
|
||||
API_BINARY=honeydue-api
|
||||
@@ -99,6 +99,13 @@ migrate-down:
|
||||
migrate-create:
|
||||
migrate create -ext sql -dir migrations -seq $(name)
|
||||
|
||||
# Encrypt existing uploads at rest (run after setting STORAGE_ENCRYPTION_KEY)
|
||||
migrate-encrypt:
|
||||
go run ./cmd/migrate-encrypt
|
||||
|
||||
migrate-encrypt-dry:
|
||||
go run ./cmd/migrate-encrypt --dry-run
|
||||
|
||||
# Development helpers
|
||||
dev: deps run
|
||||
|
||||
@@ -139,5 +146,9 @@ help:
|
||||
@echo " docker-build-prod - Build production images (api, worker, admin)"
|
||||
@echo ""
|
||||
@echo "Database:"
|
||||
@echo " migrate-up - Run database migrations"
|
||||
@echo " migrate-down - Rollback database migrations"
|
||||
@echo " migrate-up - Run database migrations"
|
||||
@echo " migrate-down - Rollback database migrations"
|
||||
@echo ""
|
||||
@echo "Encryption:"
|
||||
@echo " migrate-encrypt - Encrypt existing uploads at rest"
|
||||
@echo " migrate-encrypt-dry - Preview encryption migration (dry run)"
|
||||
|
||||
@@ -134,6 +134,17 @@ func main() {
|
||||
Str("base_url", cfg.Storage.BaseURL).
|
||||
Int64("max_file_size", cfg.Storage.MaxFileSize).
|
||||
Msg("Storage service initialized")
|
||||
|
||||
// Initialize file encryption at rest if configured
|
||||
if cfg.Storage.EncryptionKey != "" {
|
||||
encSvc, encErr := services.NewEncryptionService(cfg.Storage.EncryptionKey)
|
||||
if encErr != nil {
|
||||
log.Error().Err(encErr).Msg("Failed to initialize encryption service - files will NOT be encrypted")
|
||||
} else {
|
||||
storageService.SetEncryptionService(encSvc)
|
||||
log.Info().Msg("File encryption at rest enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
190
cmd/migrate-encrypt/main.go
Normal file
190
cmd/migrate-encrypt/main.go
Normal file
@@ -0,0 +1,190 @@
|
||||
// migrate-encrypt is a standalone CLI tool that encrypts existing uploaded files at rest.
|
||||
//
|
||||
// It walks the uploads directory, encrypts each unencrypted file, updates the
|
||||
// corresponding database record, and removes the original plaintext file.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// go run ./cmd/migrate-encrypt --dry-run # Preview changes
|
||||
// go run ./cmd/migrate-encrypt # Apply changes
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/services"
|
||||
)
|
||||
|
||||
// dbTable represents a table with a URL column that may reference uploaded files.
|
||||
type dbTable struct {
|
||||
table string
|
||||
column string
|
||||
}
|
||||
|
||||
// tables lists all database tables and columns that store file URLs.
|
||||
var tables = []dbTable{
|
||||
{table: "task_document", column: "file_url"},
|
||||
{table: "task_documentimage", column: "image_url"},
|
||||
{table: "task_taskcompletionimage", column: "image_url"},
|
||||
}
|
||||
|
||||
func main() {
|
||||
dryRun := flag.Bool("dry-run", false, "Preview changes without modifying files or database")
|
||||
flag.Parse()
|
||||
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339})
|
||||
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to load config")
|
||||
}
|
||||
|
||||
if cfg.Storage.EncryptionKey == "" {
|
||||
log.Fatal().Msg("STORAGE_ENCRYPTION_KEY is not set — cannot encrypt files")
|
||||
}
|
||||
|
||||
encSvc, err := services.NewEncryptionService(cfg.Storage.EncryptionKey)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create encryption service")
|
||||
}
|
||||
|
||||
dsn := cfg.Database.DSN()
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to connect to database")
|
||||
}
|
||||
|
||||
uploadDir := cfg.Storage.UploadDir
|
||||
absUploadDir, err := filepath.Abs(uploadDir)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Str("upload_dir", uploadDir).Msg("Failed to resolve upload directory")
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Bool("dry_run", *dryRun).
|
||||
Str("upload_dir", absUploadDir).
|
||||
Msg("Starting file encryption migration")
|
||||
|
||||
var totalFiles, encrypted, skipped, errCount int
|
||||
|
||||
// Walk the uploads directory
|
||||
err = filepath.Walk(absUploadDir, func(path string, info os.FileInfo, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
log.Warn().Err(walkErr).Str("path", path).Msg("Error accessing path")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip directories
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip files already encrypted
|
||||
if strings.HasSuffix(path, ".enc") {
|
||||
skipped++
|
||||
return nil
|
||||
}
|
||||
|
||||
totalFiles++
|
||||
|
||||
// Compute the relative path from upload dir
|
||||
relPath, err := filepath.Rel(absUploadDir, path)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("path", path).Msg("Failed to compute relative path")
|
||||
errCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
if *dryRun {
|
||||
log.Info().Str("file", relPath).Msg("[DRY RUN] Would encrypt")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process within a transaction
|
||||
txErr := db.Transaction(func(tx *gorm.DB) error {
|
||||
// Read plaintext file
|
||||
plaintext, readErr := os.ReadFile(path)
|
||||
if readErr != nil {
|
||||
return readErr
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
ciphertext, encErr := encSvc.Encrypt(plaintext)
|
||||
if encErr != nil {
|
||||
return encErr
|
||||
}
|
||||
|
||||
// Write encrypted file
|
||||
encPath := path + ".enc"
|
||||
if writeErr := os.WriteFile(encPath, ciphertext, 0644); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
|
||||
// Update database records that reference this file
|
||||
// The stored URL uses the BaseURL prefix + relative path
|
||||
// We need to match against the relative path portion
|
||||
for _, t := range tables {
|
||||
// Match URLs ending with the relative path (handles both /uploads/... and bare paths)
|
||||
result := tx.Table(t.table).
|
||||
Where(t.column+" LIKE ?", "%"+relPath).
|
||||
Where(t.column+" NOT LIKE ?", "%.enc").
|
||||
Update(t.column, gorm.Expr(t.column+" || '.enc'"))
|
||||
if result.Error != nil {
|
||||
// Roll back: remove the encrypted file
|
||||
os.Remove(encPath)
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected > 0 {
|
||||
log.Info().
|
||||
Str("table", t.table).
|
||||
Str("column", t.column).
|
||||
Int64("rows", result.RowsAffected).
|
||||
Str("file", relPath).
|
||||
Msg("Updated database records")
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the original plaintext file
|
||||
if removeErr := os.Remove(path); removeErr != nil {
|
||||
log.Warn().Err(removeErr).Str("path", path).Msg("Failed to remove original file (encrypted copy exists)")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if txErr != nil {
|
||||
log.Error().Err(txErr).Str("file", relPath).Msg("Failed to encrypt file")
|
||||
errCount++
|
||||
} else {
|
||||
encrypted++
|
||||
log.Info().Str("file", relPath).Msg("Encrypted successfully")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to walk upload directory")
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Bool("dry_run", *dryRun).
|
||||
Int("total_files", totalFiles).
|
||||
Int("encrypted", encrypted).
|
||||
Int("skipped_already_enc", skipped).
|
||||
Int("errors", errCount).
|
||||
Msg("Encryption migration complete")
|
||||
}
|
||||
@@ -88,6 +88,9 @@ services:
|
||||
APNS_TOPIC: ${APNS_TOPIC:-com.tt.honeyDue}
|
||||
APNS_USE_SANDBOX: "true"
|
||||
FCM_SERVER_KEY: ${FCM_SERVER_KEY}
|
||||
|
||||
# Storage encryption
|
||||
STORAGE_ENCRYPTION_KEY: ${STORAGE_ENCRYPTION_KEY}
|
||||
volumes:
|
||||
- ./push_certs:/certs:ro
|
||||
- ./uploads:/app/uploads
|
||||
|
||||
@@ -93,6 +93,9 @@ services:
|
||||
APNS_TOPIC: ${APNS_TOPIC}
|
||||
APNS_USE_SANDBOX: "${APNS_USE_SANDBOX:-false}"
|
||||
FCM_SERVER_KEY: ${FCM_SERVER_KEY}
|
||||
|
||||
# Storage encryption
|
||||
STORAGE_ENCRYPTION_KEY: ${STORAGE_ENCRYPTION_KEY}
|
||||
volumes:
|
||||
- push_certs:/certs:ro
|
||||
- uploads:/app/uploads
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -137,10 +138,11 @@ type SecurityConfig struct {
|
||||
|
||||
// StorageConfig holds file storage settings
|
||||
type StorageConfig struct {
|
||||
UploadDir string // Directory to store uploaded files
|
||||
BaseURL string // Public URL prefix for serving files (e.g., "/uploads")
|
||||
MaxFileSize int64 // Max file size in bytes (default 10MB)
|
||||
AllowedTypes string // Comma-separated MIME types
|
||||
UploadDir string // Directory to store uploaded files
|
||||
BaseURL string // Public URL prefix for serving files (e.g., "/uploads")
|
||||
MaxFileSize int64 // Max file size in bytes (default 10MB)
|
||||
AllowedTypes string // Comma-separated MIME types
|
||||
EncryptionKey string // 64-char hex key for file encryption at rest (optional)
|
||||
}
|
||||
|
||||
// FeatureFlags holds kill switches for major subsystems.
|
||||
@@ -262,10 +264,11 @@ func Load() (*Config, error) {
|
||||
MaxPasswordResetRate: 3,
|
||||
},
|
||||
Storage: StorageConfig{
|
||||
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
|
||||
BaseURL: viper.GetString("STORAGE_BASE_URL"),
|
||||
MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"),
|
||||
AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"),
|
||||
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
|
||||
BaseURL: viper.GetString("STORAGE_BASE_URL"),
|
||||
MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"),
|
||||
AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"),
|
||||
EncryptionKey: viper.GetString("STORAGE_ENCRYPTION_KEY"),
|
||||
},
|
||||
AppleAuth: AppleAuthConfig{
|
||||
ClientID: viper.GetString("APPLE_CLIENT_ID"),
|
||||
@@ -414,6 +417,16 @@ func validate(cfg *Config) error {
|
||||
// Database password might come from DATABASE_URL, don't require it separately
|
||||
// The actual connection will fail if credentials are wrong
|
||||
|
||||
// Validate STORAGE_ENCRYPTION_KEY if set: must be exactly 64 hex characters
|
||||
if cfg.Storage.EncryptionKey != "" {
|
||||
if len(cfg.Storage.EncryptionKey) != 64 {
|
||||
return fmt.Errorf("STORAGE_ENCRYPTION_KEY must be exactly 64 hex characters (got %d)", len(cfg.Storage.EncryptionKey))
|
||||
}
|
||||
if _, err := hex.DecodeString(cfg.Storage.EncryptionKey); err != nil {
|
||||
return fmt.Errorf("STORAGE_ENCRYPTION_KEY contains invalid hex: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -63,3 +63,9 @@ type AppleSignInRequest struct {
|
||||
type GoogleSignInRequest struct {
|
||||
IDToken string `json:"id_token" validate:"required"` // Google ID token from Credential Manager
|
||||
}
|
||||
|
||||
// DeleteAccountRequest represents the delete account request body
|
||||
type DeleteAccountRequest struct {
|
||||
Password *string `json:"password"`
|
||||
Confirmation *string `json:"confirmation"`
|
||||
}
|
||||
|
||||
@@ -45,15 +45,16 @@ type RegisterResponse struct {
|
||||
|
||||
// CurrentUserResponse represents the /auth/me/ response
|
||||
type CurrentUserResponse struct {
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"first_name"`
|
||||
LastName string `json:"last_name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
DateJoined time.Time `json:"date_joined"`
|
||||
LastLogin *time.Time `json:"last_login,omitempty"`
|
||||
Profile *UserProfileResponse `json:"profile,omitempty"`
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"first_name"`
|
||||
LastName string `json:"last_name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
DateJoined time.Time `json:"date_joined"`
|
||||
LastLogin *time.Time `json:"last_login,omitempty"`
|
||||
Profile *UserProfileResponse `json:"profile,omitempty"`
|
||||
AuthProvider string `json:"auth_provider"`
|
||||
}
|
||||
|
||||
// VerifyEmailResponse represents the email verification response
|
||||
@@ -125,17 +126,18 @@ func NewUserProfileResponse(profile *models.UserProfile) *UserProfileResponse {
|
||||
}
|
||||
|
||||
// NewCurrentUserResponse creates a CurrentUserResponse from a User model
|
||||
func NewCurrentUserResponse(user *models.User) CurrentUserResponse {
|
||||
func NewCurrentUserResponse(user *models.User, authProvider string) CurrentUserResponse {
|
||||
return CurrentUserResponse{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
IsActive: user.IsActive,
|
||||
DateJoined: user.DateJoined,
|
||||
LastLogin: user.LastLogin,
|
||||
Profile: NewUserProfileResponse(user.Profile),
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
IsActive: user.IsActive,
|
||||
DateJoined: user.DateJoined,
|
||||
LastLogin: user.LastLogin,
|
||||
Profile: NewUserProfileResponse(user.Profile),
|
||||
AuthProvider: authProvider,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ type AuthHandler struct {
|
||||
cache *services.CacheService
|
||||
appleAuthService *services.AppleAuthService
|
||||
googleAuthService *services.GoogleAuthService
|
||||
storageService *services.StorageService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new auth handler
|
||||
@@ -43,6 +44,11 @@ func (h *AuthHandler) SetGoogleAuthService(googleAuth *services.GoogleAuthServic
|
||||
h.googleAuthService = googleAuth
|
||||
}
|
||||
|
||||
// SetStorageService sets the storage service for file deletion during account deletion
|
||||
func (h *AuthHandler) SetStorageService(storageService *services.StorageService) {
|
||||
h.storageService = storageService
|
||||
}
|
||||
|
||||
// Login handles POST /api/auth/login/
|
||||
func (h *AuthHandler) Login(c echo.Context) error {
|
||||
var req requests.LoginRequest
|
||||
@@ -406,3 +412,48 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// DeleteAccount handles DELETE /api/auth/account/
|
||||
func (h *AuthHandler) DeleteAccount(c echo.Context) error {
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req requests.DeleteAccountRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
|
||||
fileURLs, err := h.authService.DeleteAccount(user.ID, req.Password, req.Confirmation)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Account deletion failed")
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete files from disk (best effort, don't fail the request)
|
||||
if h.storageService != nil && len(fileURLs) > 0 {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Uint("user_id", user.ID).Msg("Panic in file cleanup goroutine")
|
||||
}
|
||||
}()
|
||||
for _, fileURL := range fileURLs {
|
||||
if err := h.storageService.Delete(fileURL); err != nil {
|
||||
log.Warn().Err(err).Str("file_url", fileURL).Msg("Failed to delete file during account cleanup")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Invalidate auth token from cache
|
||||
token := middleware.GetAuthToken(c)
|
||||
if h.cache != nil && token != "" {
|
||||
if err := h.cache.InvalidateAuthToken(c.Request().Context(), token); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to invalidate token in cache after account deletion")
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Account deleted successfully"})
|
||||
}
|
||||
|
||||
217
internal/handlers/auth_handler_delete_test.go
Normal file
217
internal/handlers/auth_handler_delete_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||
"github.com/treytartt/honeydue-api/internal/services"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func setupDeleteAccountHandler(t *testing.T) (*AuthHandler, *echo.Echo, *gorm.DB) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
SecretKey: "test-secret-key",
|
||||
PasswordResetExpiry: 15 * time.Minute,
|
||||
ConfirmationExpiry: 24 * time.Hour,
|
||||
MaxPasswordResetRate: 3,
|
||||
},
|
||||
}
|
||||
authService := services.NewAuthService(userRepo, cfg)
|
||||
handler := NewAuthHandler(authService, nil, nil)
|
||||
e := testutil.SetupTestRouter()
|
||||
return handler, e, db
|
||||
}
|
||||
|
||||
func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "password123")
|
||||
|
||||
// Create profile for the user
|
||||
profile := &models.UserProfile{UserID: user.ID, Verified: true}
|
||||
require.NoError(t, db.Create(profile).Error)
|
||||
|
||||
// Create auth token
|
||||
testutil.CreateTestToken(t, db, user.ID)
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/account/", handler.DeleteAccount)
|
||||
|
||||
t.Run("successful deletion with correct password", func(t *testing.T) {
|
||||
password := "password123"
|
||||
req := map[string]interface{}{
|
||||
"password": password,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response["message"], "Account deleted successfully")
|
||||
|
||||
// Verify user is actually deleted
|
||||
var count int64
|
||||
db.Model(&models.User{}).Where("id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Verify profile is deleted
|
||||
db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Verify auth token is deleted
|
||||
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_DeleteAccount_WrongPassword(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "wrongpw", "wrongpw@test.com", "password123")
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/account/", handler.DeleteAccount)
|
||||
|
||||
t.Run("wrong password returns 401", func(t *testing.T) {
|
||||
wrongPw := "wrongpassword"
|
||||
req := map[string]interface{}{
|
||||
"password": wrongPw,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_DeleteAccount_MissingPassword(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "password123")
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/account/", handler.DeleteAccount)
|
||||
|
||||
t.Run("missing password returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_DeleteAccount_SocialAuthUser(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "randompassword")
|
||||
|
||||
// Create Apple social auth record
|
||||
appleAuth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_123",
|
||||
Email: "apple@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(appleAuth).Error)
|
||||
|
||||
// Create profile
|
||||
profile := &models.UserProfile{UserID: user.ID, Verified: true}
|
||||
require.NoError(t, db.Create(profile).Error)
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/account/", handler.DeleteAccount)
|
||||
|
||||
t.Run("successful deletion with DELETE confirmation", func(t *testing.T) {
|
||||
confirmation := "DELETE"
|
||||
req := map[string]interface{}{
|
||||
"confirmation": confirmation,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
// Verify user is deleted
|
||||
var count int64
|
||||
db.Model(&models.User{}).Where("id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Verify apple auth is deleted
|
||||
db.Model(&models.AppleSocialAuth{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_DeleteAccount_SocialAuthMissingConfirmation(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "randompassword")
|
||||
|
||||
// Create Google social auth record
|
||||
googleAuth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_456",
|
||||
Email: "google@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(googleAuth).Error)
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/account/", handler.DeleteAccount)
|
||||
|
||||
t.Run("missing confirmation returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("wrong confirmation returns 400", func(t *testing.T) {
|
||||
wrongConfirmation := "delete"
|
||||
req := map[string]interface{}{
|
||||
"confirmation": wrongConfirmation,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_DeleteAccount_Unauthenticated(t *testing.T) {
|
||||
handler, e, _ := setupDeleteAccountHandler(t)
|
||||
|
||||
// No auth middleware - unauthenticated request
|
||||
e.DELETE("/api/auth/account/", handler.DeleteAccount)
|
||||
|
||||
t.Run("unauthenticated request returns 401", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"password": "password123",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -60,15 +62,18 @@ func (h *MediaHandler) ServeDocument(c echo.Context) error {
|
||||
return apperrors.Forbidden("error.access_denied")
|
||||
}
|
||||
|
||||
// Serve the file
|
||||
filePath := h.resolveFilePath(doc.FileURL)
|
||||
if filePath == "" {
|
||||
// Serve the file (supports encrypted files transparently)
|
||||
data, mimeType, err := h.storageSvc.ReadFile(doc.FileURL)
|
||||
if err != nil {
|
||||
return apperrors.NotFound("error.file_not_found")
|
||||
}
|
||||
|
||||
// Set caching headers (private, 1 hour)
|
||||
// Set caching and disposition headers
|
||||
c.Response().Header().Set("Cache-Control", "private, max-age=3600")
|
||||
return c.File(filePath)
|
||||
if doc.FileName != "" {
|
||||
c.Response().Header().Set("Content-Disposition", "inline; filename=\""+doc.FileName+"\"")
|
||||
}
|
||||
return c.Blob(http.StatusOK, mimeType, data)
|
||||
}
|
||||
|
||||
// ServeDocumentImage serves a document image with access control
|
||||
@@ -102,14 +107,15 @@ func (h *MediaHandler) ServeDocumentImage(c echo.Context) error {
|
||||
return apperrors.Forbidden("error.access_denied")
|
||||
}
|
||||
|
||||
// Serve the file
|
||||
filePath := h.resolveFilePath(img.ImageURL)
|
||||
if filePath == "" {
|
||||
// Serve the file (supports encrypted files transparently)
|
||||
data, mimeType, err := h.storageSvc.ReadFile(img.ImageURL)
|
||||
if err != nil {
|
||||
return apperrors.NotFound("error.file_not_found")
|
||||
}
|
||||
|
||||
c.Response().Header().Set("Cache-Control", "private, max-age=3600")
|
||||
return c.File(filePath)
|
||||
c.Response().Header().Set("Content-Disposition", "inline; filename=\""+filepath.Base(img.ImageURL)+"\"")
|
||||
return c.Blob(http.StatusOK, mimeType, data)
|
||||
}
|
||||
|
||||
// ServeCompletionImage serves a task completion image with access control
|
||||
@@ -149,14 +155,15 @@ func (h *MediaHandler) ServeCompletionImage(c echo.Context) error {
|
||||
return apperrors.Forbidden("error.access_denied")
|
||||
}
|
||||
|
||||
// Serve the file
|
||||
filePath := h.resolveFilePath(img.ImageURL)
|
||||
if filePath == "" {
|
||||
// Serve the file (supports encrypted files transparently)
|
||||
data, mimeType, err := h.storageSvc.ReadFile(img.ImageURL)
|
||||
if err != nil {
|
||||
return apperrors.NotFound("error.file_not_found")
|
||||
}
|
||||
|
||||
c.Response().Header().Set("Cache-Control", "private, max-age=3600")
|
||||
return c.File(filePath)
|
||||
c.Response().Header().Set("Content-Disposition", "inline; filename=\""+filepath.Base(img.ImageURL)+"\"")
|
||||
return c.Blob(http.StatusOK, mimeType, data)
|
||||
}
|
||||
|
||||
// resolveFilePath converts a stored URL to an actual file path.
|
||||
|
||||
@@ -34,6 +34,12 @@ func NewUserRepository(db *gorm.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
}
|
||||
|
||||
// DB returns the underlying *gorm.DB connection. This is useful when callers
|
||||
// need to pass the connection (e.g., a transaction) to methods that accept *gorm.DB.
|
||||
func (r *UserRepository) DB() *gorm.DB {
|
||||
return r.db
|
||||
}
|
||||
|
||||
// Transaction runs fn inside a database transaction. The callback receives a
|
||||
// new UserRepository backed by the transaction so all operations within fn
|
||||
// share the same transactional connection.
|
||||
@@ -509,6 +515,195 @@ func (r *UserRepository) FindProfilesInSharedResidences(userID uint) ([]models.U
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
// --- Auth Provider Detection ---
|
||||
|
||||
// FindAuthProvider determines the auth provider for a user.
|
||||
// Returns "apple", "google", or "email".
|
||||
func (r *UserRepository) FindAuthProvider(userID uint) (string, error) {
|
||||
var count int64
|
||||
if err := r.db.Model(&models.AppleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
if count > 0 {
|
||||
return "apple", nil
|
||||
}
|
||||
|
||||
if err := r.db.Model(&models.GoogleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
if count > 0 {
|
||||
return "google", nil
|
||||
}
|
||||
|
||||
return "email", nil
|
||||
}
|
||||
|
||||
// --- Account Deletion ---
|
||||
|
||||
// DeleteUserCascade deletes a user and all related records in dependency order.
|
||||
// Should be called on a repository backed by a transaction (via Transaction callback).
|
||||
// Returns a list of file URLs that need to be deleted from disk after the transaction commits.
|
||||
func (r *UserRepository) DeleteUserCascade(userID uint) ([]string, error) {
|
||||
var fileURLs []string
|
||||
db := r.db
|
||||
|
||||
// 1. Push notification devices
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.APNSDevice{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.GCMDevice{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. Notifications
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.Notification{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Notification preferences
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.NotificationPreference{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. Task reminder logs
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.TaskReminderLog{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5. Find residences owned by user
|
||||
var ownedResidences []models.Residence
|
||||
if err := db.Where("owner_id = ?", userID).Find(&ownedResidences).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, residence := range ownedResidences {
|
||||
// Collect file URLs before deleting
|
||||
|
||||
// Task completion images (via completion_id -> task_id -> residence_id)
|
||||
var completionImageURLs []string
|
||||
db.Model(&models.TaskCompletionImage{}).
|
||||
Joins("JOIN task_taskcompletion ON task_taskcompletion.id = task_taskcompletionimage.completion_id").
|
||||
Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
|
||||
Where("task_task.residence_id = ?", residence.ID).
|
||||
Pluck("task_taskcompletionimage.image_url", &completionImageURLs)
|
||||
fileURLs = append(fileURLs, completionImageURLs...)
|
||||
|
||||
// Delete task completion images
|
||||
db.Exec(`DELETE FROM task_taskcompletionimage WHERE completion_id IN (
|
||||
SELECT tc.id FROM task_taskcompletion tc
|
||||
JOIN task_task t ON t.id = tc.task_id
|
||||
WHERE t.residence_id = ?
|
||||
)`, residence.ID)
|
||||
|
||||
// Delete task completions
|
||||
db.Exec(`DELETE FROM task_taskcompletion WHERE task_id IN (
|
||||
SELECT id FROM task_task WHERE residence_id = ?
|
||||
)`, residence.ID)
|
||||
|
||||
// Document images (via document_id -> residence_id)
|
||||
var docImageURLs []string
|
||||
db.Model(&models.DocumentImage{}).
|
||||
Joins("JOIN task_document ON task_document.id = task_documentimage.document_id").
|
||||
Where("task_document.residence_id = ?", residence.ID).
|
||||
Pluck("task_documentimage.image_url", &docImageURLs)
|
||||
fileURLs = append(fileURLs, docImageURLs...)
|
||||
|
||||
// Delete document images
|
||||
db.Exec(`DELETE FROM task_documentimage WHERE document_id IN (
|
||||
SELECT id FROM task_document WHERE residence_id = ?
|
||||
)`, residence.ID)
|
||||
|
||||
// Document file URLs
|
||||
var docFileURLs []string
|
||||
db.Model(&models.Document{}).Where("residence_id = ?", residence.ID).Pluck("file_url", &docFileURLs)
|
||||
fileURLs = append(fileURLs, docFileURLs...)
|
||||
|
||||
// Delete documents
|
||||
if err := db.Where("residence_id = ?", residence.ID).Delete(&models.Document{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Delete tasks
|
||||
if err := db.Where("residence_id = ?", residence.ID).Delete(&models.Task{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Delete contractor specialties (many-to-many join table)
|
||||
db.Exec(`DELETE FROM task_contractor_specialties WHERE contractor_id IN (
|
||||
SELECT id FROM task_contractor WHERE residence_id = ?
|
||||
)`, residence.ID)
|
||||
|
||||
// Delete contractors
|
||||
if err := db.Where("residence_id = ?", residence.ID).Delete(&models.Contractor{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Delete share codes
|
||||
if err := db.Where("residence_id = ?", residence.ID).Delete(&models.ResidenceShareCode{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove residence membership records (many-to-many join table)
|
||||
db.Exec("DELETE FROM residence_residence_users WHERE residence_id = ?", residence.ID)
|
||||
|
||||
// Delete the residence itself
|
||||
if err := db.Delete(&residence).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Remove user from shared residences they don't own (membership only)
|
||||
db.Exec("DELETE FROM residence_residence_users WHERE user_id = ?", userID)
|
||||
|
||||
// 7. Subscription
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.UserSubscription{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 8. Social auth records
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.AppleSocialAuth{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.GoogleSocialAuth{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 9. Confirmation codes
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.ConfirmationCode{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 10. Password reset codes
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.PasswordResetCode{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 11. Auth tokens
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 12. User profile
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.UserProfile{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 13. User
|
||||
if err := db.Where("id = ?", userID).Delete(&models.User{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Filter out empty URLs
|
||||
var cleanURLs []string
|
||||
for _, url := range fileURLs {
|
||||
if url != "" {
|
||||
cleanURLs = append(cleanURLs, url)
|
||||
}
|
||||
}
|
||||
|
||||
return cleanURLs, nil
|
||||
}
|
||||
|
||||
// --- Apple Social Auth Methods ---
|
||||
|
||||
// FindByAppleID finds an Apple social auth by Apple ID
|
||||
|
||||
@@ -187,3 +187,169 @@ func TestUserRepository_GetOrCreateProfile(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, profile1.ID, profile2.ID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindAuthProvider(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
t.Run("email user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "emailuser", "email@test.com", "password123")
|
||||
provider, err := repo.FindAuthProvider(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "email", provider)
|
||||
})
|
||||
|
||||
t.Run("apple user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "password123")
|
||||
appleAuth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_test",
|
||||
Email: "apple@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(appleAuth).Error)
|
||||
|
||||
provider, err := repo.FindAuthProvider(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "apple", provider)
|
||||
})
|
||||
|
||||
t.Run("google user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "password123")
|
||||
googleAuth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_test",
|
||||
Email: "google@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(googleAuth).Error)
|
||||
|
||||
provider, err := repo.FindAuthProvider(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "google", provider)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserRepository_DeleteUserCascade(t *testing.T) {
|
||||
t.Run("deletes user with no residences", func(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "deletebare", "deletebare@test.com", "password123")
|
||||
|
||||
// Create profile and token
|
||||
profile := &models.UserProfile{UserID: user.ID, Verified: true}
|
||||
require.NoError(t, db.Create(profile).Error)
|
||||
_, err := models.GetOrCreateToken(db, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var fileURLs []string
|
||||
txErr := repo.Transaction(func(txRepo *UserRepository) error {
|
||||
urls, err := txRepo.DeleteUserCascade(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fileURLs = urls
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, txErr)
|
||||
assert.Empty(t, fileURLs)
|
||||
|
||||
// Verify user is gone
|
||||
var count int64
|
||||
db.Model(&models.User{}).Where("id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Verify profile is gone
|
||||
db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Verify token is gone
|
||||
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
|
||||
t.Run("returns file URLs for cleanup", func(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "deletefiles", "deletefiles@test.com", "password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test Home")
|
||||
|
||||
// Create document with file
|
||||
doc := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Test Doc",
|
||||
FileURL: "/uploads/documents/test.pdf",
|
||||
}
|
||||
require.NoError(t, db.Create(doc).Error)
|
||||
|
||||
// Create document image
|
||||
docImage := &models.DocumentImage{
|
||||
DocumentID: doc.ID,
|
||||
ImageURL: "/uploads/images/docimg.jpg",
|
||||
}
|
||||
require.NoError(t, db.Create(docImage).Error)
|
||||
|
||||
var fileURLs []string
|
||||
txErr := repo.Transaction(func(txRepo *UserRepository) error {
|
||||
urls, err := txRepo.DeleteUserCascade(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fileURLs = urls
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, txErr)
|
||||
|
||||
// Should return the file URLs
|
||||
assert.Contains(t, fileURLs, "/uploads/documents/test.pdf")
|
||||
assert.Contains(t, fileURLs, "/uploads/images/docimg.jpg")
|
||||
|
||||
// Verify everything deleted
|
||||
var count int64
|
||||
db.Model(&models.User{}).Where("id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
db.Model(&models.Residence{}).Where("id = ?", residence.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
db.Model(&models.Document{}).Where("id = ?", doc.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
|
||||
t.Run("handles user with owned and shared residences", func(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "deleteowner", "deleteowner@test.com", "password123")
|
||||
otherUser := testutil.CreateTestUser(t, db, "otheruser", "other@test.com", "password123")
|
||||
otherResidence := testutil.CreateTestResidence(t, db, otherUser.ID, "Other Home")
|
||||
|
||||
// Owner's residence
|
||||
ownedResidence := testutil.CreateTestResidence(t, db, owner.ID, "Owner Home")
|
||||
|
||||
// Add owner as member of other user's residence
|
||||
db.Exec("INSERT INTO residence_residence_users (residence_id, user_id) VALUES (?, ?)", otherResidence.ID, owner.ID)
|
||||
|
||||
txErr := repo.Transaction(func(txRepo *UserRepository) error {
|
||||
_, err := txRepo.DeleteUserCascade(owner.ID)
|
||||
return err
|
||||
})
|
||||
require.NoError(t, txErr)
|
||||
|
||||
// Owner's residence should be deleted
|
||||
var count int64
|
||||
db.Model(&models.Residence{}).Where("id = ?", ownedResidence.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Other user's residence should still exist
|
||||
db.Model(&models.Residence{}).Where("id = ?", otherResidence.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
// Owner should no longer be a member of other's residence
|
||||
db.Raw("SELECT COUNT(*) FROM residence_residence_users WHERE user_id = ?", owner.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Other user should still exist
|
||||
db.Model(&models.User{}).Where("id = ?", otherUser.ID).Count(&count)
|
||||
assert.Equal(t, int64(1), count)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -182,6 +182,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache)
|
||||
authHandler.SetAppleAuthService(appleAuthService)
|
||||
authHandler.SetGoogleAuthService(googleAuthService)
|
||||
authHandler.SetStorageService(deps.StorageService)
|
||||
userHandler := handlers.NewUserHandler(userService)
|
||||
residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService, cfg.Features.PDFReportsEnabled)
|
||||
taskHandler := handlers.NewTaskHandler(taskService, deps.StorageService)
|
||||
@@ -347,6 +348,7 @@ func setupProtectedAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler
|
||||
auth.POST("/verify/", authHandler.VerifyEmail) // Alias for mobile app compatibility
|
||||
auth.POST("/verify-email/", authHandler.VerifyEmail) // Original route
|
||||
auth.POST("/resend-verification/", authHandler.ResendVerification)
|
||||
auth.DELETE("/account/", authHandler.DeleteAccount)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -200,10 +200,69 @@ func (s *AuthService) GetCurrentUser(userID uint) (*responses.CurrentUserRespons
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := responses.NewCurrentUserResponse(user)
|
||||
authProvider, err := s.userRepo.FindAuthProvider(userID)
|
||||
if err != nil {
|
||||
// Log but don't fail - default to "email"
|
||||
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider")
|
||||
authProvider = "email"
|
||||
}
|
||||
|
||||
response := responses.NewCurrentUserResponse(user, authProvider)
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// DeleteAccount deletes a user's account and all associated data.
|
||||
// For email auth users, password verification is required.
|
||||
// For social auth users, confirmation string "DELETE" is required.
|
||||
// Returns a list of file URLs that need to be deleted from disk.
|
||||
func (s *AuthService) DeleteAccount(userID uint, password, confirmation *string) ([]string, error) {
|
||||
// Fetch user
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, repositories.ErrUserNotFound) {
|
||||
return nil, apperrors.NotFound("error.user_not_found")
|
||||
}
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Determine auth provider
|
||||
authProvider, err := s.userRepo.FindAuthProvider(userID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Validate credentials based on auth provider
|
||||
if authProvider == "email" {
|
||||
if password == nil || *password == "" {
|
||||
return nil, apperrors.BadRequest("error.password_required")
|
||||
}
|
||||
if !user.CheckPassword(*password) {
|
||||
return nil, apperrors.Unauthorized("error.invalid_credentials")
|
||||
}
|
||||
} else {
|
||||
// Social auth (apple or google) - require confirmation
|
||||
if confirmation == nil || *confirmation != "DELETE" {
|
||||
return nil, apperrors.BadRequest("error.confirmation_required")
|
||||
}
|
||||
}
|
||||
|
||||
// Start transaction and cascade delete
|
||||
var fileURLs []string
|
||||
txErr := s.userRepo.Transaction(func(txRepo *repositories.UserRepository) error {
|
||||
urls, err := txRepo.DeleteUserCascade(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fileURLs = urls
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
return nil, apperrors.Internal(txErr)
|
||||
}
|
||||
|
||||
return fileURLs, nil
|
||||
}
|
||||
|
||||
// UpdateProfile updates a user's profile
|
||||
func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) {
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
@@ -240,7 +299,13 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := responses.NewCurrentUserResponse(user)
|
||||
authProvider, err := s.userRepo.FindAuthProvider(userID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider")
|
||||
authProvider = "email"
|
||||
}
|
||||
|
||||
response := responses.NewCurrentUserResponse(user, authProvider)
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
|
||||
179
internal/services/encryption_service.go
Normal file
179
internal/services/encryption_service.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
// encryptionVersion is the current file format version byte.
|
||||
encryptionVersion byte = 0x01
|
||||
|
||||
// aes256KeyLen is the required key length in bytes for AES-256.
|
||||
aes256KeyLen = 32
|
||||
|
||||
// gcmNonceSize is the standard GCM nonce size (12 bytes).
|
||||
gcmNonceSize = 12
|
||||
|
||||
// gcmTagSize is the standard GCM authentication tag size (16 bytes).
|
||||
gcmTagSize = 16
|
||||
|
||||
// encryptedDEKSize is the encrypted DEK length: 32-byte DEK + 16-byte GCM tag.
|
||||
encryptedDEKSize = aes256KeyLen + gcmTagSize
|
||||
|
||||
// headerSize is the fixed header: version(1) + KEK nonce(12) + encrypted DEK(48) + DEK nonce(12).
|
||||
headerSize = 1 + gcmNonceSize + encryptedDEKSize + gcmNonceSize
|
||||
)
|
||||
|
||||
// EncryptionService provides AES-256-GCM envelope encryption for files at rest.
|
||||
type EncryptionService struct {
|
||||
kek []byte // Key Encryption Key (32 bytes)
|
||||
}
|
||||
|
||||
// NewEncryptionService creates an EncryptionService from a 64-character hex-encoded KEK.
|
||||
func NewEncryptionService(hexKey string) (*EncryptionService, error) {
|
||||
if len(hexKey) != 64 {
|
||||
return nil, fmt.Errorf("encryption key must be exactly 64 hex characters (got %d)", len(hexKey))
|
||||
}
|
||||
|
||||
kek, err := hex.DecodeString(hexKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid hex in encryption key: %w", err)
|
||||
}
|
||||
|
||||
if len(kek) != aes256KeyLen {
|
||||
return nil, fmt.Errorf("decoded key must be %d bytes", aes256KeyLen)
|
||||
}
|
||||
|
||||
return &EncryptionService{kek: kek}, nil
|
||||
}
|
||||
|
||||
// IsEnabled returns true if the encryption service is configured and ready.
|
||||
func (s *EncryptionService) IsEnabled() bool {
|
||||
return s != nil && len(s.kek) == aes256KeyLen
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext using envelope encryption (random DEK encrypted with KEK).
|
||||
//
|
||||
// File format:
|
||||
//
|
||||
// [1-byte version 0x01]
|
||||
// [12-byte KEK nonce]
|
||||
// [48-byte encrypted DEK (32-byte DEK + 16-byte GCM tag)]
|
||||
// [12-byte DEK nonce]
|
||||
// [ciphertext + 16-byte GCM tag]
|
||||
func (s *EncryptionService) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
// Generate a random Data Encryption Key (DEK)
|
||||
dek := make([]byte, aes256KeyLen)
|
||||
if _, err := io.ReadFull(rand.Reader, dek); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate DEK: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the DEK with the KEK
|
||||
kekBlock, err := aes.NewCipher(s.kek)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create KEK cipher: %w", err)
|
||||
}
|
||||
kekGCM, err := cipher.NewGCM(kekBlock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create KEK GCM: %w", err)
|
||||
}
|
||||
|
||||
kekNonce := make([]byte, gcmNonceSize)
|
||||
if _, err := io.ReadFull(rand.Reader, kekNonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate KEK nonce: %w", err)
|
||||
}
|
||||
encryptedDEK := kekGCM.Seal(nil, kekNonce, dek, nil)
|
||||
|
||||
// Encrypt the plaintext with the DEK
|
||||
dekBlock, err := aes.NewCipher(dek)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create DEK cipher: %w", err)
|
||||
}
|
||||
dekGCM, err := cipher.NewGCM(dekBlock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create DEK GCM: %w", err)
|
||||
}
|
||||
|
||||
dekNonce := make([]byte, gcmNonceSize)
|
||||
if _, err := io.ReadFull(rand.Reader, dekNonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate DEK nonce: %w", err)
|
||||
}
|
||||
ciphertext := dekGCM.Seal(nil, dekNonce, plaintext, nil)
|
||||
|
||||
// Pack the output: version + kekNonce + encryptedDEK + dekNonce + ciphertext
|
||||
out := make([]byte, 0, headerSize+len(ciphertext))
|
||||
out = append(out, encryptionVersion)
|
||||
out = append(out, kekNonce...)
|
||||
out = append(out, encryptedDEK...)
|
||||
out = append(out, dekNonce...)
|
||||
out = append(out, ciphertext...)
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Decrypt reverses the Encrypt operation, recovering the original plaintext.
|
||||
func (s *EncryptionService) Decrypt(blob []byte) ([]byte, error) {
|
||||
if len(blob) < headerSize {
|
||||
return nil, fmt.Errorf("ciphertext too short (%d bytes, minimum %d)", len(blob), headerSize)
|
||||
}
|
||||
|
||||
// Parse version
|
||||
version := blob[0]
|
||||
if version != encryptionVersion {
|
||||
return nil, fmt.Errorf("unsupported encryption version: 0x%02x", version)
|
||||
}
|
||||
|
||||
offset := 1
|
||||
|
||||
// Parse KEK nonce
|
||||
kekNonce := blob[offset : offset+gcmNonceSize]
|
||||
offset += gcmNonceSize
|
||||
|
||||
// Parse encrypted DEK
|
||||
encryptedDEK := blob[offset : offset+encryptedDEKSize]
|
||||
offset += encryptedDEKSize
|
||||
|
||||
// Parse DEK nonce
|
||||
dekNonce := blob[offset : offset+gcmNonceSize]
|
||||
offset += gcmNonceSize
|
||||
|
||||
// Remaining bytes are the ciphertext + GCM tag
|
||||
ciphertext := blob[offset:]
|
||||
|
||||
// Decrypt the DEK with the KEK
|
||||
kekBlock, err := aes.NewCipher(s.kek)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create KEK cipher: %w", err)
|
||||
}
|
||||
kekGCM, err := cipher.NewGCM(kekBlock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create KEK GCM: %w", err)
|
||||
}
|
||||
|
||||
dek, err := kekGCM.Open(nil, kekNonce, encryptedDEK, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt DEK (wrong key?): %w", err)
|
||||
}
|
||||
|
||||
// Decrypt the plaintext with the DEK
|
||||
dekBlock, err := aes.NewCipher(dek)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create DEK cipher: %w", err)
|
||||
}
|
||||
dekGCM, err := cipher.NewGCM(dekBlock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create DEK GCM: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := dekGCM.Open(nil, dekNonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt data (tampered?): %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
218
internal/services/encryption_service_test.go
Normal file
218
internal/services/encryption_service_test.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// validTestKey returns a deterministic 64-char hex key for tests.
|
||||
func validTestKey() string {
|
||||
return "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
}
|
||||
|
||||
// randomHexKey generates a random 64-char hex key.
|
||||
func randomHexKey(t *testing.T) string {
|
||||
t.Helper()
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func TestNewEncryptionService_Valid(t *testing.T) {
|
||||
svc, err := NewEncryptionService(validTestKey())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if svc == nil {
|
||||
t.Fatal("expected non-nil service")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEncryptionService_InvalidHex(t *testing.T) {
|
||||
// 64 chars but not valid hex
|
||||
_, err := NewEncryptionService("zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid hex")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEncryptionService_WrongLength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
}{
|
||||
{"too short", "0123456789abcdef"},
|
||||
{"too long", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef00"},
|
||||
{"empty", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewEncryptionService(tt.key)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for wrong length key")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt_RoundTrip(t *testing.T) {
|
||||
svc, err := NewEncryptionService(validTestKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plaintext := []byte("Hello, encryption at rest!")
|
||||
|
||||
ciphertext, err := svc.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt failed: %v", err)
|
||||
}
|
||||
|
||||
decrypted, err := svc.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(plaintext, decrypted) {
|
||||
t.Fatalf("round-trip mismatch: got %q, want %q", decrypted, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt_EmptyPlaintext(t *testing.T) {
|
||||
svc, err := NewEncryptionService(validTestKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ciphertext, err := svc.Encrypt([]byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt failed: %v", err)
|
||||
}
|
||||
|
||||
decrypted, err := svc.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt failed: %v", err)
|
||||
}
|
||||
|
||||
if len(decrypted) != 0 {
|
||||
t.Fatalf("expected empty plaintext, got %d bytes", len(decrypted))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt_DifferentCiphertexts(t *testing.T) {
|
||||
svc, err := NewEncryptionService(validTestKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plaintext := []byte("same input")
|
||||
ct1, err := svc.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ct2, err := svc.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if bytes.Equal(ct1, ct2) {
|
||||
t.Fatal("encrypting the same plaintext twice should produce different ciphertexts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecrypt_TamperDetection(t *testing.T) {
|
||||
svc, err := NewEncryptionService(validTestKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ciphertext, err := svc.Encrypt([]byte("sensitive data"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Flip a byte near the end (in the ciphertext portion)
|
||||
tampered := make([]byte, len(ciphertext))
|
||||
copy(tampered, ciphertext)
|
||||
tampered[len(tampered)-1] ^= 0xFF
|
||||
|
||||
_, err = svc.Decrypt(tampered)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when decrypting tampered ciphertext")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecrypt_WrongKEK(t *testing.T) {
|
||||
svc1, err := NewEncryptionService(validTestKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ciphertext, err := svc1.Encrypt([]byte("secret"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a second service with a different key
|
||||
svc2, err := NewEncryptionService(randomHexKey(t))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = svc2.Decrypt(ciphertext)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when decrypting with wrong KEK")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecrypt_TooShort(t *testing.T) {
|
||||
svc, err := NewEncryptionService(validTestKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = svc.Decrypt([]byte("short"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for too-short ciphertext")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecrypt_BadVersion(t *testing.T) {
|
||||
svc, err := NewEncryptionService(validTestKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ciphertext, err := svc.Encrypt([]byte("data"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Change version byte
|
||||
ciphertext[0] = 0xFF
|
||||
|
||||
_, err = svc.Decrypt(ciphertext)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for bad version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsEnabled(t *testing.T) {
|
||||
svc, err := NewEncryptionService(validTestKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !svc.IsEnabled() {
|
||||
t.Fatal("expected IsEnabled() to return true")
|
||||
}
|
||||
|
||||
var nilSvc *EncryptionService
|
||||
if nilSvc.IsEnabled() {
|
||||
t.Fatal("expected IsEnabled() to return false for nil service")
|
||||
}
|
||||
}
|
||||
@@ -18,8 +18,9 @@ import (
|
||||
|
||||
// StorageService handles file uploads to local filesystem
|
||||
type StorageService struct {
|
||||
cfg *config.StorageConfig
|
||||
allowedTypes map[string]struct{} // P-12: Parsed once at init for O(1) lookups
|
||||
cfg *config.StorageConfig
|
||||
allowedTypes map[string]struct{} // P-12: Parsed once at init for O(1) lookups
|
||||
encryptionSvc *EncryptionService
|
||||
}
|
||||
|
||||
// UploadResult contains information about an uploaded file
|
||||
@@ -124,28 +125,39 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
|
||||
subdir = "completions"
|
||||
}
|
||||
|
||||
// If encryption is enabled, append .enc suffix to the stored filename
|
||||
storedFilename := newFilename
|
||||
if s.encryptionSvc.IsEnabled() {
|
||||
storedFilename = newFilename + ".enc"
|
||||
}
|
||||
|
||||
// S-18: Sanitize path to prevent traversal attacks
|
||||
destPath, err := SafeResolvePath(s.cfg.UploadDir, filepath.Join(subdir, newFilename))
|
||||
destPath, err := SafeResolvePath(s.cfg.UploadDir, filepath.Join(subdir, storedFilename))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid upload path: %w", err)
|
||||
}
|
||||
|
||||
// Create destination file
|
||||
dst, err := os.Create(destPath)
|
||||
// Read all file content into memory for potential encryption
|
||||
fileData, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create destination file: %w", err)
|
||||
return nil, fmt.Errorf("failed to read file content: %w", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
// Copy file content
|
||||
written, err := io.Copy(dst, src)
|
||||
if err != nil {
|
||||
// Clean up on error
|
||||
os.Remove(destPath)
|
||||
// Encrypt if encryption is enabled
|
||||
if s.encryptionSvc.IsEnabled() {
|
||||
fileData, err = s.encryptionSvc.Encrypt(fileData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Write file content to disk
|
||||
if err := os.WriteFile(destPath, fileData, 0644); err != nil {
|
||||
return nil, fmt.Errorf("failed to save file: %w", err)
|
||||
}
|
||||
written := int64(len(fileData))
|
||||
|
||||
// Generate URL
|
||||
// Generate URL (always uses the original filename without .enc suffix for the public URL)
|
||||
url := fmt.Sprintf("%s/%s/%s", s.cfg.BaseURL, subdir, newFilename)
|
||||
|
||||
log.Info().
|
||||
@@ -163,7 +175,61 @@ func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*U
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Delete removes a file from storage
|
||||
// ReadFile reads and optionally decrypts a stored file. It returns the plaintext
|
||||
// bytes and the detected MIME type. If the file is stored with an .enc suffix,
|
||||
// it is decrypted automatically.
|
||||
func (s *StorageService) ReadFile(storedURL string) ([]byte, string, error) {
|
||||
if storedURL == "" {
|
||||
return nil, "", fmt.Errorf("empty file URL")
|
||||
}
|
||||
|
||||
// Strip base URL prefix to get relative path
|
||||
relativePath := strings.TrimPrefix(storedURL, s.cfg.BaseURL)
|
||||
relativePath = strings.TrimPrefix(relativePath, "/")
|
||||
|
||||
// Try .enc variant first, then plain file
|
||||
var fullPath string
|
||||
var encrypted bool
|
||||
|
||||
encPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath+".enc")
|
||||
if err == nil {
|
||||
if _, statErr := os.Stat(encPath); statErr == nil {
|
||||
fullPath = encPath
|
||||
encrypted = true
|
||||
}
|
||||
}
|
||||
|
||||
if fullPath == "" {
|
||||
plainPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("invalid file path: %w", err)
|
||||
}
|
||||
fullPath = plainPath
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt if this is an encrypted file
|
||||
if encrypted {
|
||||
if s.encryptionSvc == nil || !s.encryptionSvc.IsEnabled() {
|
||||
return nil, "", fmt.Errorf("encrypted file found but encryption service is not configured")
|
||||
}
|
||||
data, err = s.encryptionSvc.Decrypt(data)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to decrypt file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Detect MIME type from decrypted content
|
||||
mimeType := http.DetectContentType(data)
|
||||
|
||||
return data, mimeType, nil
|
||||
}
|
||||
|
||||
// Delete removes a file from storage, handling both plain and .enc variants
|
||||
func (s *StorageService) Delete(fileURL string) error {
|
||||
// Convert URL to file path
|
||||
relativePath := strings.TrimPrefix(fileURL, s.cfg.BaseURL)
|
||||
@@ -175,14 +241,35 @@ func (s *StorageService) Delete(fileURL string) error {
|
||||
return fmt.Errorf("invalid file path: %w", err)
|
||||
}
|
||||
|
||||
// Try to delete the plain file
|
||||
plainDeleted := false
|
||||
if err := os.Remove(fullPath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // File already doesn't exist
|
||||
if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete file: %w", err)
|
||||
}
|
||||
return fmt.Errorf("failed to delete file: %w", err)
|
||||
} else {
|
||||
plainDeleted = true
|
||||
log.Info().Str("path", fullPath).Msg("File deleted")
|
||||
}
|
||||
|
||||
// Also try to delete the .enc variant
|
||||
encPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath+".enc")
|
||||
if err == nil {
|
||||
if err := os.Remove(encPath); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete encrypted file: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Info().Str("path", encPath).Msg("Encrypted file deleted")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if !plainDeleted {
|
||||
// Neither file existed — that's OK
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info().Str("path", fullPath).Msg("File deleted")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -225,6 +312,11 @@ func (s *StorageService) GetUploadDir() string {
|
||||
return s.cfg.UploadDir
|
||||
}
|
||||
|
||||
// SetEncryptionService sets the encryption service for encrypting files at rest
|
||||
func (s *StorageService) SetEncryptionService(svc *EncryptionService) {
|
||||
s.encryptionSvc = svc
|
||||
}
|
||||
|
||||
// NewStorageServiceForTest creates a StorageService without creating directories.
|
||||
// This is intended only for unit tests that need a StorageService with a known config.
|
||||
func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService {
|
||||
|
||||
164
internal/services/storage_service_test.go
Normal file
164
internal/services/storage_service_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
)
|
||||
|
||||
func setupTestStorage(t *testing.T, encrypt bool) (*StorageService, string) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg := &config.StorageConfig{
|
||||
UploadDir: tmpDir,
|
||||
BaseURL: "/uploads",
|
||||
MaxFileSize: 10 * 1024 * 1024,
|
||||
AllowedTypes: "image/jpeg,image/png,application/pdf",
|
||||
}
|
||||
|
||||
svc, err := NewStorageService(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create storage service: %v", err)
|
||||
}
|
||||
|
||||
if encrypt {
|
||||
encSvc, err := NewEncryptionService(validTestKey())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create encryption service: %v", err)
|
||||
}
|
||||
svc.SetEncryptionService(encSvc)
|
||||
}
|
||||
|
||||
return svc, tmpDir
|
||||
}
|
||||
|
||||
func TestReadFile_PlainFile(t *testing.T) {
|
||||
svc, tmpDir := setupTestStorage(t, false)
|
||||
|
||||
// Write a plain file
|
||||
content := []byte("hello world")
|
||||
dir := filepath.Join(tmpDir, "images")
|
||||
if err := os.WriteFile(filepath.Join(dir, "test.jpg"), content, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, mimeType, err := svc.ReadFile("/uploads/images/test.jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(data, content) {
|
||||
t.Fatalf("data mismatch: got %q, want %q", data, content)
|
||||
}
|
||||
|
||||
if mimeType == "" {
|
||||
t.Fatal("expected non-empty MIME type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFile_EncryptedFile(t *testing.T) {
|
||||
svc, tmpDir := setupTestStorage(t, true)
|
||||
|
||||
// Encrypt and write a file manually (simulating what Upload does)
|
||||
originalContent := []byte("sensitive document content here - must be long enough for detection")
|
||||
encSvc, _ := NewEncryptionService(validTestKey())
|
||||
encrypted, err := encSvc.Encrypt(originalContent)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dir := filepath.Join(tmpDir, "documents")
|
||||
if err := os.WriteFile(filepath.Join(dir, "test.pdf.enc"), encrypted, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// ReadFile should find the .enc file and decrypt it
|
||||
data, _, err := svc.ReadFile("/uploads/documents/test.pdf")
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(data, originalContent) {
|
||||
t.Fatalf("decrypted data mismatch: got %q, want %q", data, originalContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFile_EncFilePreferredOverPlain(t *testing.T) {
|
||||
svc, tmpDir := setupTestStorage(t, true)
|
||||
|
||||
encSvc, _ := NewEncryptionService(validTestKey())
|
||||
|
||||
plainContent := []byte("plain version")
|
||||
encContent := []byte("encrypted version - the correct one")
|
||||
encrypted, _ := encSvc.Encrypt(encContent)
|
||||
|
||||
dir := filepath.Join(tmpDir, "images")
|
||||
os.WriteFile(filepath.Join(dir, "photo.jpg"), plainContent, 0644)
|
||||
os.WriteFile(filepath.Join(dir, "photo.jpg.enc"), encrypted, 0644)
|
||||
|
||||
// Should prefer the .enc file
|
||||
data, _, err := svc.ReadFile("/uploads/images/photo.jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(data, encContent) {
|
||||
t.Fatalf("expected encrypted version content, got %q", data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFile_EmptyURL(t *testing.T) {
|
||||
svc, _ := setupTestStorage(t, false)
|
||||
|
||||
_, _, err := svc.ReadFile("")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFile_MissingFile(t *testing.T) {
|
||||
svc, _ := setupTestStorage(t, false)
|
||||
|
||||
_, _, err := svc.ReadFile("/uploads/images/nonexistent.jpg")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelete_HandlesEncAndPlain(t *testing.T) {
|
||||
svc, tmpDir := setupTestStorage(t, false)
|
||||
|
||||
dir := filepath.Join(tmpDir, "images")
|
||||
|
||||
// Create both plain and .enc files
|
||||
os.WriteFile(filepath.Join(dir, "photo.jpg"), []byte("plain"), 0644)
|
||||
os.WriteFile(filepath.Join(dir, "photo.jpg.enc"), []byte("encrypted"), 0644)
|
||||
|
||||
err := svc.Delete("/uploads/images/photo.jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
// Both should be gone
|
||||
if _, err := os.Stat(filepath.Join(dir, "photo.jpg")); !os.IsNotExist(err) {
|
||||
t.Fatal("plain file should be deleted")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, "photo.jpg.enc")); !os.IsNotExist(err) {
|
||||
t.Fatal("encrypted file should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelete_NonexistentFile(t *testing.T) {
|
||||
svc, _ := setupTestStorage(t, false)
|
||||
|
||||
// Should not error for non-existent files
|
||||
err := svc.Delete("/uploads/images/nope.jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete should not error for non-existent file: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -848,39 +847,33 @@ func (s *TaskService) sendTaskCompletedNotification(task *models.Task, completio
|
||||
}
|
||||
}
|
||||
|
||||
// loadCompletionImagesForEmail reads completion images from disk and prepares them for email embedding
|
||||
// loadCompletionImagesForEmail reads completion images from disk and prepares them for email embedding.
|
||||
// Uses StorageService.ReadFile to transparently handle encrypted files.
|
||||
func (s *TaskService) loadCompletionImagesForEmail(images []models.TaskCompletionImage) []EmbeddedImage {
|
||||
var emailImages []EmbeddedImage
|
||||
|
||||
uploadDir := s.storageService.GetUploadDir()
|
||||
|
||||
for i, img := range images {
|
||||
// Resolve file path from stored URL
|
||||
filePath := s.resolveImageFilePath(img.ImageURL, uploadDir)
|
||||
if filePath == "" {
|
||||
log.Warn().Str("image_url", img.ImageURL).Msg("Could not resolve image file path")
|
||||
continue
|
||||
}
|
||||
|
||||
// Read file from disk
|
||||
data, err := os.ReadFile(filePath)
|
||||
// Read file via storage service (handles encryption transparently)
|
||||
data, mimeType, err := s.storageService.ReadFile(img.ImageURL)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("path", filePath).Msg("Failed to read completion image for email")
|
||||
log.Warn().Err(err).Str("image_url", img.ImageURL).Msg("Failed to read completion image for email")
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine content type from extension
|
||||
contentType := s.getContentTypeFromPath(filePath)
|
||||
// Use detected MIME type, fall back to extension-based detection
|
||||
if mimeType == "application/octet-stream" {
|
||||
mimeType = s.getContentTypeFromPath(img.ImageURL)
|
||||
}
|
||||
|
||||
// Create embedded image with unique Content-ID
|
||||
emailImages = append(emailImages, EmbeddedImage{
|
||||
ContentID: fmt.Sprintf("completion-image-%d", i+1),
|
||||
Filename: filepath.Base(filePath),
|
||||
ContentType: contentType,
|
||||
Filename: filepath.Base(img.ImageURL),
|
||||
ContentType: mimeType,
|
||||
Data: data,
|
||||
})
|
||||
|
||||
log.Debug().Str("path", filePath).Int("size", len(data)).Msg("Loaded completion image for email")
|
||||
log.Debug().Str("image_url", img.ImageURL).Int("size", len(data)).Msg("Loaded completion image for email")
|
||||
}
|
||||
|
||||
return emailImages
|
||||
|
||||
@@ -60,6 +60,9 @@ func SetupTestDB(t *testing.T) *gorm.DB {
|
||||
&models.NotificationPreference{},
|
||||
&models.APNSDevice{},
|
||||
&models.GCMDevice{},
|
||||
&models.AppleSocialAuth{},
|
||||
&models.GoogleSocialAuth{},
|
||||
&models.TaskReminderLog{},
|
||||
&models.UserSubscription{},
|
||||
&models.SubscriptionSettings{},
|
||||
&models.TierLimits{},
|
||||
|
||||
Reference in New Issue
Block a user