176 lines
5.4 KiB
Python
176 lines
5.4 KiB
Python
import json
|
|
import os
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from fastapi.responses import FileResponse
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import func
|
|
|
|
from app.database import get_db
|
|
from app.models import Export, Image, Species
|
|
from app.schemas.export import (
|
|
ExportCreate,
|
|
ExportResponse,
|
|
ExportListResponse,
|
|
ExportPreview,
|
|
)
|
|
from app.workers.export_tasks import generate_export
|
|
|
|
# Router holding the export endpoints. Its paths are relative ("" and
# "/{export_id}..."), so it is presumably mounted under a prefix elsewhere
# (e.g. "/exports") — confirm in the application setup.
router = APIRouter()
|
|
|
|
|
|
@router.get("", response_model=ExportListResponse)
def list_exports(
    limit: int = Query(50, ge=1, le=200),
    db: Session = Depends(get_db),
):
    """List exports, newest first.

    Args:
        limit: Maximum number of exports returned (1-200, default 50).
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        ExportListResponse with up to ``limit`` exports ordered by
        ``created_at`` descending. ``total`` counts every export, not just
        the returned page.
    """
    all_exports = db.query(Export)
    grand_total = all_exports.count()

    newest = all_exports.order_by(Export.created_at.desc()).limit(limit).all()
    serialized = list(map(ExportResponse.model_validate, newest))

    return ExportListResponse(items=serialized, total=grand_total)
|
|
|
|
|
|
# Rough average archive footprint of one image in KB; used only for the
# preview's size estimate.
_ESTIMATED_KB_PER_IMAGE = 50


def _apply_image_filters(query, criteria):
    """Return *query* narrowed by the optional filters in *criteria*.

    A filter is applied only when its criteria field is truthy, so ``None``,
    empty lists and a ``min_quality`` of 0 all mean "no restriction".

    Args:
        query: SQLAlchemy query selecting ``Image`` columns.
        criteria: Export filter criteria. ``licenses``, ``min_quality`` and
            ``species_ids`` are read.

    Returns:
        The narrowed query.
    """
    if criteria.licenses:
        query = query.filter(Image.license.in_(criteria.licenses))
    if criteria.min_quality:
        query = query.filter(Image.quality_score >= criteria.min_quality)
    if criteria.species_ids:
        query = query.filter(Image.species_id.in_(criteria.species_ids))
    return query


@router.post("/preview", response_model=ExportPreview)
def preview_export(export: ExportCreate, db: Session = Depends(get_db)):
    """Preview an export's size without creating it.

    Counts downloaded images per species under the export's filter criteria,
    then keeps only species with at least ``min_images_per_species`` images.
    NOTE(review): this presumably mirrors the selection done by
    ``generate_export``; verify the two stay in sync.

    Args:
        export: The export request whose ``filter_criteria`` are previewed.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        ExportPreview with the number of qualifying species, their total
        image count, and a rough size estimate in MB.
    """
    criteria = export.filter_criteria
    min_images = criteria.min_images_per_species

    per_species = _apply_image_filters(
        db.query(Image.species_id, func.count(Image.id)).filter(
            Image.status == "downloaded"
        ),
        criteria,
    )
    rows = per_species.group_by(Image.species_id).all()

    # Unpack rows positionally. A column labelled "count" is shadowed by the
    # tuple-style Row.count() method in SQLAlchemy 1.4+, so ``row.count``
    # would be a bound method rather than the aggregated number.
    qualifying_counts = [n for _species_id, n in rows if n >= min_images]
    total_images = sum(qualifying_counts)

    # KB -> MB.
    estimated_size_mb = (total_images * _ESTIMATED_KB_PER_IMAGE) / 1024

    return ExportPreview(
        species_count=len(qualifying_counts),
        image_count=total_images,
        estimated_size_mb=estimated_size_mb,
    )
|
|
|
|
|
|
@router.post("", response_model=ExportResponse)
def create_export(export: ExportCreate, db: Session = Depends(get_db)):
    """Create an export record and queue its background generation job.

    The row is committed first so it has an id to pass to the Celery task.
    The resulting task id is then stored on the row.

    Args:
        export: Requested export name, filter criteria and train split.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        The created export as an ``ExportResponse``.

    Raises:
        HTTPException: 503 if the generation job cannot be queued. In that
            case the just-created row is removed, so no "pending" export is
            left behind without a task to process it.
    """
    db_export = Export(
        name=export.name,
        filter_criteria=export.filter_criteria.model_dump_json(),
        train_split=export.train_split,
        status="pending",
    )
    db.add(db_export)
    db.commit()
    db.refresh(db_export)

    try:
        task = generate_export.delay(db_export.id)
    except Exception as err:  # e.g. broker unreachable; re-raised below
        db.delete(db_export)
        db.commit()
        raise HTTPException(
            status_code=503, detail="Failed to queue export job"
        ) from err

    db_export.celery_task_id = task.id
    db.commit()

    return ExportResponse.model_validate(db_export)
|
|
|
|
|
|
@router.get("/{export_id}", response_model=ExportResponse)
def get_export(export_id: int, db: Session = Depends(get_db)):
    """Return the stored record, including its status, for one export.

    Raises:
        HTTPException: 404 if no export has ``export_id``.
    """
    found = (
        db.query(Export)
        .filter(Export.id == export_id)
        .first()
    )
    if not found:
        raise HTTPException(status_code=404, detail="Export not found")
    return ExportResponse.model_validate(found)
|
|
|
|
|
|
@router.get("/{export_id}/progress")
def get_export_progress(export_id: int, db: Session = Depends(get_db)):
    """Get real-time export progress.

    While the Celery task reports the ``PROGRESS`` state, its progress meta
    is returned. Otherwise the status stored on the export row is returned.

    Args:
        export_id: Primary key of the export.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        ``{"status": ...}``. While the task is in ``PROGRESS`` the dict also
        has ``current``, ``total`` and ``current_species``.

    Raises:
        HTTPException: 404 if no export has ``export_id``.
    """
    # Lazy import; presumably avoids an import cycle with the worker
    # package — confirm before hoisting to module level.
    from app.workers.celery_app import celery_app

    export = db.query(Export).filter(Export.id == export_id).first()
    if not export:
        raise HTTPException(status_code=404, detail="Export not found")

    if not export.celery_task_id:
        return {"status": export.status}

    result = celery_app.AsyncResult(export.celery_task_id)

    if result.state == "PROGRESS":
        # ``state`` and ``info`` may be fetched from the result backend in
        # separate round-trips. If the task finishes in between, ``info``
        # holds its return value (possibly None) instead of the progress
        # dict, so fall back to defaults rather than raising AttributeError.
        meta = result.info
        if not isinstance(meta, dict):
            meta = {}
        return {
            "status": "generating",
            "current": meta.get("current", 0),
            "total": meta.get("total", 0),
            "current_species": meta.get("species", ""),
        }

    return {"status": export.status}
|
|
|
|
|
|
@router.get("/{export_id}/download")
def download_export(export_id: int, db: Session = Depends(get_db)):
    """Send a finished export's zip archive to the client.

    The download is named after the export (``<name>.zip``).

    Raises:
        HTTPException: 404 if the export does not exist; 400 if its status is
            not ``"completed"``; 404 if its archive path is unset or missing
            on disk.
    """
    record = (
        db.query(Export)
        .filter(Export.id == export_id)
        .first()
    )
    if not record:
        raise HTTPException(status_code=404, detail="Export not found")
    if record.status != "completed":
        raise HTTPException(status_code=400, detail="Export not ready")

    archive = record.file_path
    if not (archive and os.path.exists(archive)):
        raise HTTPException(status_code=404, detail="Export file not found")

    download_name = f"{record.name}.zip"
    return FileResponse(
        archive,
        media_type="application/zip",
        filename=download_name,
    )
|
|
|
|
|
|
@router.delete("/{export_id}")
def delete_export(export_id: int, db: Session = Depends(get_db)):
    """Delete an export and its archive file.

    Args:
        export_id: Primary key of the export.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        ``{"status": "deleted"}`` on success.

    Raises:
        HTTPException: 404 if no export has ``export_id``.
    """
    export = db.query(Export).filter(Export.id == export_id).first()
    if not export:
        raise HTTPException(status_code=404, detail="Export not found")

    # Remove the file EAFP-style. Checking os.path.exists first races with
    # concurrent deletes, and the loser would fail with FileNotFoundError.
    if export.file_path:
        try:
            os.remove(export.file_path)
        except FileNotFoundError:
            pass  # Already gone; nothing to clean up.

    db.delete(export)
    db.commit()

    return {"status": "deleted"}
|