Files
PlantGuideScraper/backend/app/api/exports.py
2026-04-12 09:54:27 -05:00

176 lines
5.4 KiB
Python

import json
import os
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.database import get_db
from app.models import Export, Image, Species
from app.schemas.export import (
ExportCreate,
ExportResponse,
ExportListResponse,
ExportPreview,
)
from app.workers.export_tasks import generate_export
# Router for the export endpoints (list, preview, create, status, progress,
# download, delete). Paths below are relative; the URL prefix is presumably
# applied where this router is included in the app — confirm.
router = APIRouter()
@router.get("", response_model=ExportListResponse)
def list_exports(
    limit: int = Query(50, ge=1, le=200),
    db: Session = Depends(get_db),
):
    """Return the most recent exports, newest first.

    Args:
        limit: Maximum number of exports placed in ``items`` (1-200,
            default 50).
        db: Database session injected by FastAPI.

    Returns:
        An ``ExportListResponse`` whose ``items`` holds at most ``limit``
        exports ordered by ``created_at`` descending, and whose ``total`` is
        the number of all stored exports, independent of ``limit``.
    """
    export_count = db.query(Export).count()

    newest_first = db.query(Export).order_by(Export.created_at.desc())
    page = newest_first.limit(limit).all()

    return ExportListResponse(
        items=list(map(ExportResponse.model_validate, page)),
        total=export_count,
    )
# Rough per-image size used for the preview's size estimate, in KB.
_ESTIMATED_KB_PER_IMAGE = 50


def _apply_image_filters(query, criteria):
    """Restrict an Image-based query to downloaded images matching ``criteria``.

    Optional criteria (``licenses``, ``min_quality``, ``species_ids``) are
    applied only when truthy, so e.g. a ``min_quality`` of 0 or an empty
    license list means "no filter".
    """
    query = query.filter(Image.status == "downloaded")
    if criteria.licenses:
        query = query.filter(Image.license.in_(criteria.licenses))
    if criteria.min_quality:
        query = query.filter(Image.quality_score >= criteria.min_quality)
    if criteria.species_ids:
        query = query.filter(Image.species_id.in_(criteria.species_ids))
    return query


@router.post("/preview", response_model=ExportPreview)
def preview_export(export: ExportCreate, db: Session = Depends(get_db)):
    """Preview an export's contents without creating it.

    Counts downloaded images per species under the request's filter
    criteria, keeps only species with at least ``min_images_per_species``
    matching images, and reports how many species/images would be exported
    plus a rough size estimate.

    Args:
        export: The export request; only ``filter_criteria`` is read.
        db: Database session injected by FastAPI.

    Returns:
        An ``ExportPreview`` with ``species_count``, ``image_count`` and
        ``estimated_size_mb`` (assuming ~50 KB per image).
    """
    criteria = export.filter_criteria
    # Assumes min_images_per_species is always an int — TODO confirm the
    # schema default; None would make the comparison below raise.
    min_images = criteria.min_images_per_species

    # Label deliberately avoids "count": on SQLAlchemy Row objects that name
    # collides with the tuple/Sequence ``count`` method.
    count_col = func.count(Image.id).label("image_count")
    per_species = (
        _apply_image_filters(db.query(Image.species_id, count_col), criteria)
        .group_by(Image.species_id)
        .all()
    )

    kept_counts = [
        row.image_count for row in per_species if row.image_count >= min_images
    ]
    total_images = sum(kept_counts)

    # KB -> MB.
    estimated_size_mb = (total_images * _ESTIMATED_KB_PER_IMAGE) / 1024

    return ExportPreview(
        species_count=len(kept_counts),
        image_count=total_images,
        estimated_size_mb=estimated_size_mb,
    )
@router.post("", response_model=ExportResponse)
def create_export(export: ExportCreate, db: Session = Depends(get_db)):
    """Store a new export in the ``pending`` state and queue its generation.

    The record is committed and refreshed first so it has an id, because
    only that id is passed to the ``generate_export`` Celery task. The
    returned task id is then saved on the record in a second commit.

    Args:
        export: Requested name, filter criteria and train split.
        db: Database session injected by FastAPI.

    Returns:
        The created export as an ``ExportResponse``.
    """
    record = Export(
        name=export.name,
        filter_criteria=export.filter_criteria.model_dump_json(),
        train_split=export.train_split,
        status="pending",
    )
    db.add(record)
    db.commit()
    db.refresh(record)

    # NOTE(review): if .delay() raises (e.g. broker unreachable) the record
    # stays "pending" with no celery_task_id; consider handling that case.
    record.celery_task_id = generate_export.delay(record.id).id
    db.commit()

    return ExportResponse.model_validate(record)
@router.get("/{export_id}", response_model=ExportResponse)
def get_export(export_id: int, db: Session = Depends(get_db)):
    """Return the stored state of a single export.

    Args:
        export_id: Id of the export to look up.
        db: Database session injected by FastAPI.

    Raises:
        HTTPException: 404 when no export has id ``export_id``.
    """
    found = db.query(Export).filter(Export.id == export_id).first()
    if found:
        return ExportResponse.model_validate(found)
    raise HTTPException(status_code=404, detail="Export not found")
@router.get("/{export_id}/progress")
def get_export_progress(export_id: int, db: Session = Depends(get_db)):
    """Report real-time progress of an export's Celery task.

    While the task reports state ``PROGRESS`` the response carries the
    task's ``current``/``total``/``species`` meta; otherwise only the
    export's stored ``status`` is returned.

    Args:
        export_id: Id of the export to inspect.
        db: Database session injected by FastAPI.

    Raises:
        HTTPException: 404 when no export has id ``export_id``.
    """
    # Imported lazily (as before); presumably to avoid an import cycle
    # between the API and worker packages — confirm.
    from app.workers.celery_app import celery_app

    export = db.query(Export).filter(Export.id == export_id).first()
    if not export:
        raise HTTPException(status_code=404, detail="Export not found")
    if not export.celery_task_id:
        return {"status": export.status}

    result = celery_app.AsyncResult(export.celery_task_id)
    if result.state == "PROGRESS":
        # ``state`` and ``info`` are separate backend reads: if the task
        # finishes in between, ``info`` becomes the task's return value,
        # which need not be a dict. Read it once and guard its type so a
        # race cannot turn into an AttributeError / 500.
        info = result.info
        meta = info if isinstance(info, dict) else {}
        return {
            "status": "generating",
            "current": meta.get("current", 0),
            "total": meta.get("total", 0),
            "current_species": meta.get("species", ""),
        }
    return {"status": export.status}
@router.get("/{export_id}/download")
def download_export(export_id: int, db: Session = Depends(get_db)):
    """Send a finished export's zip archive to the client.

    Args:
        export_id: Id of the export to download.
        db: Database session injected by FastAPI.

    Returns:
        A ``FileResponse`` for the archive, served as ``<export name>.zip``.

    Raises:
        HTTPException: 404 if the export does not exist; 400 if its status
            is not ``"completed"``; 404 if it has no file path or the file
            is missing on disk.
    """
    record = db.query(Export).filter(Export.id == export_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="Export not found")
    if record.status != "completed":
        raise HTTPException(status_code=400, detail="Export not ready")

    archive_path = record.file_path
    if not archive_path or not os.path.exists(archive_path):
        raise HTTPException(status_code=404, detail="Export file not found")

    download_name = f"{record.name}.zip"
    return FileResponse(
        archive_path,
        media_type="application/zip",
        filename=download_name,
    )
@router.delete("/{export_id}")
def delete_export(export_id: int, db: Session = Depends(get_db)):
    """Delete an export record and its archive file, if any.

    A file that is already gone is not an error.

    Args:
        export_id: Id of the export to delete.
        db: Database session injected by FastAPI.

    Returns:
        ``{"status": "deleted"}`` on success.

    Raises:
        HTTPException: 404 when no export has id ``export_id``.
    """
    export = db.query(Export).filter(Export.id == export_id).first()
    if not export:
        raise HTTPException(status_code=404, detail="Export not found")

    # NOTE(review): the export's Celery task (celery_task_id) is not revoked
    # here; a still-running task might write its file after the record is
    # gone — confirm whether that is acceptable.
    if export.file_path:
        # EAFP instead of exists()+remove(): a separate existence check races
        # with concurrent deletes; a missing file is deliberately ignored.
        try:
            os.remove(export.file_path)
        except FileNotFoundError:
            pass

    db.delete(export)
    db.commit()
    return {"status": "deleted"}