package services

import (
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"path/filepath"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/rs/zerolog/log"

	"github.com/treytartt/honeydue-api/internal/config"
)

// StorageService handles file uploads, validation, encryption, and URL generation.
// It delegates raw I/O to a StorageBackend (local filesystem or S3-compatible).
type StorageService struct {
	cfg          *config.StorageConfig
	backend      StorageBackend
	allowedTypes map[string]struct{} // P-12: Parsed once at init for O(1) lookups
	// encryptionSvc is optional: NewStorageService leaves it nil and
	// SetEncryptionService installs it later, so every use must nil-check it.
	encryptionSvc *EncryptionService
}

// UploadResult contains information about an uploaded file.
type UploadResult struct {
	// URL is the public URL "<BaseURL>/<subdir>/<name>"; it never carries the
	// ".enc" suffix even when the stored object is encrypted.
	URL string `json:"url"`
	// FileName is the original, client-supplied file name.
	FileName string `json:"file_name"`
	// FileSize is the number of bytes written to storage, i.e. the ciphertext
	// size when encryption is enabled.
	FileSize int64 `json:"file_size"`
	// MimeType is the client-declared Content-Type after validation.
	MimeType string `json:"mime_type"`
}

// NewStorageService creates a storage service with the appropriate backend.
// When cfg.IsS3() reports true it uses S3-compatible storage (B2, MinIO);
// otherwise it uses the local filesystem rooted at cfg.UploadDir.
//
// Encryption at rest is not configured here; call SetEncryptionService.
// Returns an error if cfg is nil or the backend cannot be initialized.
func NewStorageService(cfg *config.StorageConfig) (*StorageService, error) {
	if cfg == nil {
		return nil, fmt.Errorf("storage config is nil")
	}

	var (
		backend StorageBackend
		err     error
	)
	if cfg.IsS3() {
		backend, err = NewS3Backend(cfg.S3Endpoint, cfg.S3KeyID, cfg.S3AppKey, cfg.S3Bucket, cfg.S3UseSSL, cfg.S3Region)
		if err != nil {
			return nil, fmt.Errorf("failed to initialize S3 storage: %w", err)
		}
		log.Info().
			Str("endpoint", cfg.S3Endpoint).
			Str("bucket", cfg.S3Bucket).
			Bool("ssl", cfg.S3UseSSL).
			Msg("Storage service initialized (S3)")
	} else {
		backend, err = NewLocalBackend(cfg.UploadDir)
		if err != nil {
			return nil, fmt.Errorf("failed to initialize local storage: %w", err)
		}
		log.Info().
			Str("upload_dir", cfg.UploadDir).
			Msg("Storage service initialized (local)")
	}

	return &StorageService{
		cfg:     cfg,
		backend: backend,
		// P-12: Parse AllowedTypes once at initialization for O(1) lookups
		allowedTypes: parseAllowedTypes(cfg.AllowedTypes),
	}, nil
}

// Upload validates a multipart file and saves it to storage (local or S3).
//
// Validation, in order:
//  1. file.Size must not exceed cfg.MaxFileSize.
//  2. S-09: the first 512 bytes are sniffed. Unless the sniff is inconclusive
//     ("application/octet-stream"), the sniffed type must share its top-level
//     type with the declared Content-Type.
//  3. The declared Content-Type must be in the allowed list.
//
// category selects the subdirectory: "document"/"documents" -> documents,
// "completion"/"completions" -> completions, anything else -> images.
// When an enabled encryption service is set, the content is encrypted and
// stored under "<name>.enc". The returned URL omits that suffix; ReadFile and
// Delete look for both variants.
func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*UploadResult, error) {
	if file.Size > s.cfg.MaxFileSize {
		return nil, fmt.Errorf("file size %d exceeds maximum allowed %d bytes", file.Size, s.cfg.MaxFileSize)
	}

	claimedMimeType := file.Header.Get("Content-Type")
	if claimedMimeType == "" {
		claimedMimeType = "application/octet-stream"
	}

	src, err := file.Open()
	if err != nil {
		return nil, fmt.Errorf("failed to open uploaded file: %w", err)
	}
	defer src.Close()

	// S-09: Detect the actual content type from the file bytes to prevent
	// disguised uploads. io.ReadFull is used because a single Read may return
	// fewer than 512 bytes even when more are available, weakening the sniff.
	// Files shorter than 512 bytes yield io.ErrUnexpectedEOF with n > 0, which
	// is acceptable; only a read that produced no bytes at all is fatal.
	sniffBuf := make([]byte, 512)
	n, err := io.ReadFull(src, sniffBuf)
	if err != nil && n == 0 {
		return nil, fmt.Errorf("failed to read file for content type detection: %w", err)
	}
	detectedMimeType := http.DetectContentType(sniffBuf[:n])

	// "application/octet-stream" means the sniffer could not classify the
	// bytes; the declared type is then trusted, subject to the allow list.
	if detectedMimeType != "application/octet-stream" && !s.mimeTypesCompatible(claimedMimeType, detectedMimeType) {
		return nil, fmt.Errorf("file content type mismatch: claimed %s but detected %s", claimedMimeType, detectedMimeType)
	}

	mimeType := claimedMimeType
	if !s.isAllowedType(mimeType) {
		return nil, fmt.Errorf("file type %s is not allowed", mimeType)
	}

	// Rewind so the full content (including the sniffed prefix) is stored.
	if _, err := src.Seek(0, io.SeekStart); err != nil {
		return nil, fmt.Errorf("failed to seek file: %w", err)
	}

	// Prefer the canonical extension for the validated MIME type so the stored
	// key cannot carry a client-chosen extension (e.g. ".html") that disagrees
	// with the content. The client's extension is used only for allowed types
	// without a known mapping.
	ext := s.getExtensionFromMimeType(mimeType)
	if ext == "" {
		ext = strings.ToLower(filepath.Ext(file.Filename))
	}
	newFilename := fmt.Sprintf("%s_%s%s", time.Now().Format("20060102"), uuid.New().String(), ext)

	subdir := "images"
	switch category {
	case "document", "documents":
		subdir = "documents"
	case "completion", "completions":
		subdir = "completions"
	}

	// Decide once whether to encrypt so the ".enc" suffix and the stored bytes
	// always agree. The nil check mirrors ReadFile: NewStorageService does not
	// set an encryption service.
	encrypt := s.encryptionSvc != nil && s.encryptionSvc.IsEnabled()

	storedFilename := newFilename
	if encrypt {
		storedFilename += ".enc"
	}

	// Storage keys are slash-separated on every backend (e.g.
	// "images/20240101_uuid.jpg"), so path/filepath is deliberately not used.
	key := subdir + "/" + storedFilename

	// Buffer the whole file; its size is bounded by cfg.MaxFileSize above and
	// encryption needs the full plaintext.
	fileData, err := io.ReadAll(src)
	if err != nil {
		return nil, fmt.Errorf("failed to read file content: %w", err)
	}

	if encrypt {
		fileData, err = s.encryptionSvc.Encrypt(fileData)
		if err != nil {
			return nil, fmt.Errorf("failed to encrypt file: %w", err)
		}
	}

	if err := s.backend.Write(key, fileData); err != nil {
		return nil, fmt.Errorf("failed to save file: %w", err)
	}
	written := int64(len(fileData))

	// The URL always uses the name without the .enc suffix.
	fileURL := fmt.Sprintf("%s/%s/%s", s.cfg.BaseURL, subdir, newFilename)

	log.Info().
		Str("filename", newFilename).
		Str("category", category).
		Int64("size", written).
		Str("mime_type", mimeType).
		Bool("s3", s.cfg.IsS3()).
		Msg("File uploaded successfully")

	return &UploadResult{
		URL:      fileURL,
		FileName: file.Filename,
		FileSize: written,
		MimeType: mimeType,
	}, nil
}

// ReadFile reads a stored file by its public URL and decrypts it when needed.
// It returns the plaintext bytes and a MIME type sniffed from those bytes.
// The ".enc" variant of the key is tried first; if it cannot be read, the
// plain key is read instead.
func (s *StorageService) ReadFile(storedURL string) ([]byte, string, error) { if storedURL == "" { return nil, "", fmt.Errorf("empty file URL") } // Strip base URL prefix to get relative key relativeKey := strings.TrimPrefix(storedURL, s.cfg.BaseURL) relativeKey = strings.TrimPrefix(relativeKey, "/") // Try .enc variant first, then plain file var data []byte var encrypted bool var err error data, err = s.backend.Read(relativeKey + ".enc") if err == nil { encrypted = true } else { // Fall back to plain file data, err = s.backend.Read(relativeKey) if err != nil { return nil, "", fmt.Errorf("failed to read file: %w", err) } } // Decrypt if this is an encrypted file if encrypted { if s.encryptionSvc == nil || !s.encryptionSvc.IsEnabled() { return nil, "", fmt.Errorf("encrypted file found but encryption service is not configured") } data, err = s.encryptionSvc.Decrypt(data) if err != nil { return nil, "", fmt.Errorf("failed to decrypt file: %w", err) } } // Detect MIME type from decrypted content mimeType := http.DetectContentType(data) return data, mimeType, nil } // Delete removes a file from storage, handling both plain and .enc variants func (s *StorageService) Delete(fileURL string) error { relativePath := strings.TrimPrefix(fileURL, s.cfg.BaseURL) relativePath = strings.TrimPrefix(relativePath, "/") // Delete both plain and .enc variants (ignore not-found errors) plainErr := s.backend.Delete(relativePath) encErr := s.backend.Delete(relativePath + ".enc") // Only return an error if both failed for reasons other than not-found if plainErr != nil { log.Debug().Err(plainErr).Str("key", relativePath).Msg("Delete plain file") } if encErr != nil { log.Debug().Err(encErr).Str("key", relativePath+".enc").Msg("Delete enc file") } return nil } // GetUploadDir returns the upload directory path. // For S3 backends, returns empty string. 
func (s *StorageService) GetUploadDir() string { if lb, ok := s.backend.(*LocalBackend); ok { return lb.BaseDir() } return s.cfg.UploadDir } // SetEncryptionService sets the encryption service for encrypting files at rest func (s *StorageService) SetEncryptionService(svc *EncryptionService) { s.encryptionSvc = svc } // isAllowedType checks if the MIME type is in the allowed list. func (s *StorageService) isAllowedType(mimeType string) bool { _, ok := s.allowedTypes[mimeType] return ok } // mimeTypesCompatible checks if the claimed and detected MIME types are compatible. func (s *StorageService) mimeTypesCompatible(claimed, detected string) bool { claimedParts := strings.SplitN(claimed, "/", 2) detectedParts := strings.SplitN(detected, "/", 2) if len(claimedParts) < 1 || len(detectedParts) < 1 { return false } return claimedParts[0] == detectedParts[0] } // getExtensionFromMimeType returns a file extension for common MIME types func (s *StorageService) getExtensionFromMimeType(mimeType string) string { extensions := map[string]string{ "image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif", "image/webp": ".webp", "application/pdf": ".pdf", } if ext, ok := extensions[mimeType]; ok { return ext } return "" } // parseAllowedTypes splits a comma-separated MIME type string into a set. func parseAllowedTypes(types string) map[string]struct{} { allowed := make(map[string]struct{}) for _, t := range strings.Split(types, ",") { trimmed := strings.TrimSpace(t) if trimmed != "" { allowed[trimmed] = struct{}{} } } return allowed } // NewStorageServiceForTest creates a StorageService without creating directories. // This is intended only for unit tests that need a StorageService with a known config. func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService { return &StorageService{ cfg: cfg, backend: nil, // tests that need a backend must set it up allowedTypes: parseAllowedTypes(cfg.AllowedTypes), } }