package services

import (
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/rs/zerolog/log"

	"github.com/treytartt/honeydue-api/internal/config"
)

// StorageService handles file uploads to the local filesystem.
type StorageService struct {
	cfg          *config.StorageConfig
	allowedTypes map[string]struct{} // P-12: Parsed once at init for O(1) lookups
}

// UploadResult contains information about an uploaded file.
type UploadResult struct {
	URL      string `json:"url"`       // BaseURL + "/" + subdirectory + "/" + generated file name
	FileName string `json:"file_name"` // Original, client-supplied file name (not the stored name)
	FileSize int64  `json:"file_size"` // Number of bytes actually written to disk
	MimeType string `json:"mime_type"` // MIME type claimed by the client and accepted by validation
}

// NewStorageService creates a new storage service.
//
// It creates cfg.UploadDir and the "images", "documents" and "completions"
// subdirectories (mode 0755) if they do not exist, and parses the
// comma-separated cfg.AllowedTypes into a lookup set. Entries are trimmed of
// surrounding whitespace and empty entries are ignored.
//
// An error is returned if any of the directories cannot be created.
func NewStorageService(cfg *config.StorageConfig) (*StorageService, error) {
	// Ensure upload directory exists
	if err := os.MkdirAll(cfg.UploadDir, 0755); err != nil {
		return nil, fmt.Errorf("failed to create upload directory: %w", err)
	}

	// Create subdirectories for organization
	subdirs := []string{"images", "documents", "completions"}
	for _, subdir := range subdirs {
		path := filepath.Join(cfg.UploadDir, subdir)
		if err := os.MkdirAll(path, 0755); err != nil {
			return nil, fmt.Errorf("failed to create subdirectory %s: %w", subdir, err)
		}
	}

	// P-12: Parse AllowedTypes once at initialization for O(1) lookups
	allowedTypes := make(map[string]struct{})
	for _, t := range strings.Split(cfg.AllowedTypes, ",") {
		trimmed := strings.TrimSpace(t)
		if trimmed != "" {
			allowedTypes[trimmed] = struct{}{}
		}
	}

	log.Info().Str("upload_dir", cfg.UploadDir).Int("allowed_types", len(allowedTypes)).Msg("Storage service initialized")

	return &StorageService{cfg: cfg, allowedTypes: allowedTypes}, nil
}

// Upload saves a file to the local filesystem.
//
// The upload is rejected when:
//   - file.Size exceeds cfg.MaxFileSize;
//   - the file cannot be opened or is empty (nothing to sniff);
//   - the content type sniffed from the leading bytes is something other than
//     "application/octet-stream" and does not share a primary type with the
//     Content-Type claimed in the part header;
//   - the claimed Content-Type is not in the allowed list.
//
// category selects the destination subdirectory: "document"/"documents" and
// "completion"/"completions" map to their own directories, anything else is
// stored under "images". The stored file name is generated from the current
// date and a random UUID; the original name is only returned in the result.
//
// If writing or closing the destination file fails, the partially written
// file is removed and an error is returned.
func (s *StorageService) Upload(file *multipart.FileHeader, category string) (*UploadResult, error) {
	// Validate file size
	if file.Size > s.cfg.MaxFileSize {
		return nil, fmt.Errorf("file size %d exceeds maximum allowed %d bytes", file.Size, s.cfg.MaxFileSize)
	}

	// Get claimed MIME type from header
	claimedMimeType := file.Header.Get("Content-Type")
	if claimedMimeType == "" {
		claimedMimeType = "application/octet-stream"
	}

	// S-09: Detect actual content type from file bytes to prevent disguised uploads
	src, err := file.Open()
	if err != nil {
		return nil, fmt.Errorf("failed to open uploaded file: %w", err)
	}
	defer src.Close()

	// Read up to the first 512 bytes for content type detection. io.ReadFull
	// keeps reading until the buffer is full or the file ends, so a short
	// read from the underlying reader cannot truncate the sniffed prefix.
	// A file shorter than the buffer yields n > 0 with io.ErrUnexpectedEOF,
	// which is fine; only "no bytes at all" is treated as a failure.
	sniffBuf := make([]byte, 512)
	n, err := io.ReadFull(src, sniffBuf)
	if err != nil && n == 0 {
		return nil, fmt.Errorf("failed to read file for content type detection: %w", err)
	}
	detectedMimeType := http.DetectContentType(sniffBuf[:n])

	// Validate that the detected type matches the claimed type (at the category level)
	// Allow application/octet-stream from detection since DetectContentType may not
	// recognize all valid types, but the claimed type must still be in our allowed list
	if detectedMimeType != "application/octet-stream" && !s.mimeTypesCompatible(claimedMimeType, detectedMimeType) {
		return nil, fmt.Errorf("file content type mismatch: claimed %s but detected %s", claimedMimeType, detectedMimeType)
	}

	// Use the claimed MIME type (which is more specific) if it's allowed
	mimeType := claimedMimeType

	// Validate MIME type against allowed list
	if !s.isAllowedType(mimeType) {
		return nil, fmt.Errorf("file type %s is not allowed", mimeType)
	}

	// Seek back to beginning after sniffing
	if _, err := src.Seek(0, io.SeekStart); err != nil {
		return nil, fmt.Errorf("failed to seek file: %w", err)
	}

	// Generate unique filename.
	// NOTE(review): the extension is taken verbatim from the client-supplied
	// file name and is not checked against the validated MIME type — confirm
	// that whatever serves these files does not trust the extension.
	ext := filepath.Ext(file.Filename)
	if ext == "" {
		ext = s.getExtensionFromMimeType(mimeType)
	}
	newFilename := fmt.Sprintf("%s_%s%s", time.Now().Format("20060102"), uuid.New().String(), ext)

	// Determine subdirectory based on category
	subdir := "images"
	switch category {
	case "document", "documents":
		subdir = "documents"
	case "completion", "completions":
		subdir = "completions"
	}

	// S-18: Sanitize path to prevent traversal attacks
	destPath, err := SafeResolvePath(s.cfg.UploadDir, filepath.Join(subdir, newFilename))
	if err != nil {
		return nil, fmt.Errorf("invalid upload path: %w", err)
	}

	// Create destination file
	dst, err := os.Create(destPath)
	if err != nil {
		return nil, fmt.Errorf("failed to create destination file: %w", err)
	}

	// Copy file content, then close explicitly: a write-back failure may only
	// surface from Close, and ignoring it would report success for a file
	// that was not fully persisted.
	written, copyErr := io.Copy(dst, src)
	closeErr := dst.Close()
	if copyErr != nil {
		// Best-effort cleanup of the partial file; the copy error is what matters.
		_ = os.Remove(destPath)
		return nil, fmt.Errorf("failed to save file: %w", copyErr)
	}
	if closeErr != nil {
		// Best-effort cleanup of the partial file; the close error is what matters.
		_ = os.Remove(destPath)
		return nil, fmt.Errorf("failed to save file: %w", closeErr)
	}

	// Generate URL
	url := fmt.Sprintf("%s/%s/%s", s.cfg.BaseURL, subdir, newFilename)

	log.Info().
		Str("filename", newFilename).
		Str("category", category).
		Int64("size", written).
		Str("mime_type", mimeType).
		Msg("File uploaded successfully")

	return &UploadResult{
		URL:      url,
		FileName: file.Filename,
		FileSize: written,
		MimeType: mimeType,
	}, nil
}

// Delete removes a file from storage.
//
// fileURL is converted to a path relative to the upload directory by
// stripping the cfg.BaseURL prefix and one leading "/". The result is then
// resolved through SafeResolvePath before removal. Deleting a file that does
// not exist is not an error.
func (s *StorageService) Delete(fileURL string) error {
	// Convert URL to file path
	relativePath := strings.TrimPrefix(fileURL, s.cfg.BaseURL)
	relativePath = strings.TrimPrefix(relativePath, "/")

	// S-18: Use SafeResolvePath to prevent path traversal
	fullPath, err := SafeResolvePath(s.cfg.UploadDir, relativePath)
	if err != nil {
		return fmt.Errorf("invalid file path: %w", err)
	}

	if err := os.Remove(fullPath); err != nil {
		if os.IsNotExist(err) {
			return nil // File already doesn't exist
		}
		return fmt.Errorf("failed to delete file: %w", err)
	}

	log.Info().Str("path", fullPath).Msg("File deleted")
	return nil
}

// isAllowedType checks if the MIME type is in the allowed list.
// P-12: Uses the pre-parsed allowedTypes map for O(1) lookups instead of
// splitting the config string on every call.
// The lookup is an exact string match, so it is case-sensitive and a value
// carrying parameters (e.g. "; charset=...") only matches if the configured
// entry is spelled the same way.
func (s *StorageService) isAllowedType(mimeType string) bool {
	_, ok := s.allowedTypes[mimeType]
	return ok
}

// mimeTypesCompatible checks if the claimed and detected MIME types are compatible.
// Two MIME types are compatible if they share the same primary type (e.g., both "image/*").
func (s *StorageService) mimeTypesCompatible(claimed, detected string) bool { claimedParts := strings.SplitN(claimed, "/", 2) detectedParts := strings.SplitN(detected, "/", 2) if len(claimedParts) < 1 || len(detectedParts) < 1 { return false } return claimedParts[0] == detectedParts[0] } // getExtensionFromMimeType returns a file extension for common MIME types func (s *StorageService) getExtensionFromMimeType(mimeType string) string { extensions := map[string]string{ "image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif", "image/webp": ".webp", "application/pdf": ".pdf", } if ext, ok := extensions[mimeType]; ok { return ext } return "" } // GetUploadDir returns the upload directory path func (s *StorageService) GetUploadDir() string { return s.cfg.UploadDir } // NewStorageServiceForTest creates a StorageService without creating directories. // This is intended only for unit tests that need a StorageService with a known config. func NewStorageServiceForTest(cfg *config.StorageConfig) *StorageService { allowedTypes := make(map[string]struct{}) for _, t := range strings.Split(cfg.AllowedTypes, ",") { trimmed := strings.TrimSpace(t) if trimmed != "" { allowedTypes[trimmed] = struct{}{} } } return &StorageService{cfg: cfg, allowedTypes: allowedTypes} }