import contextlib
import json
import os
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from sqlalchemy import func
from sqlalchemy.orm import Session

from app.database import get_db
from app.models import Export, Image, Species
from app.schemas.export import (
    ExportCreate,
    ExportResponse,
    ExportListResponse,
    ExportPreview,
)
from app.workers.export_tasks import generate_export

router = APIRouter()

# Rough per-image size (KB) used only for the preview's size estimate.
_ESTIMATED_KB_PER_IMAGE = 50


def _get_export_or_404(db: Session, export_id: int) -> Export:
    """Return the Export with *export_id*, or raise HTTP 404 if it does not exist."""
    export = db.query(Export).filter(Export.id == export_id).first()
    if not export:
        raise HTTPException(status_code=404, detail="Export not found")
    return export


def _apply_image_filters(query, criteria):
    """Restrict an Image query to downloaded images matching *criteria*.

    Each optional criterion is applied only when it is truthy, so an empty
    license/species list or a falsy ``min_quality`` leaves that dimension
    unfiltered.
    """
    query = query.filter(Image.status == "downloaded")
    if criteria.licenses:
        query = query.filter(Image.license.in_(criteria.licenses))
    if criteria.min_quality:
        query = query.filter(Image.quality_score >= criteria.min_quality)
    if criteria.species_ids:
        query = query.filter(Image.species_id.in_(criteria.species_ids))
    return query


@router.get("", response_model=ExportListResponse)
def list_exports(
    limit: int = Query(50, ge=1, le=200),
    db: Session = Depends(get_db),
):
    """List exports, newest first.

    Args:
        limit: Maximum number of exports returned (1-200, default 50).

    Returns:
        ExportListResponse with up to ``limit`` items and ``total``, the
        count of all exports (not just the returned page).
    """
    total = db.query(Export).count()
    exports = db.query(Export).order_by(Export.created_at.desc()).limit(limit).all()
    return ExportListResponse(
        items=[ExportResponse.model_validate(e) for e in exports],
        total=total,
    )


@router.post("/preview", response_model=ExportPreview)
def preview_export(export: ExportCreate, db: Session = Depends(get_db)):
    """Preview an export's size without creating it.

    Counts downloaded images per species that match ``export.filter_criteria``.
    Only species with at least ``min_images_per_species`` matching images are
    kept.

    Returns:
        ExportPreview with the number of qualifying species, their total image
        count, and a rough size estimate in MB.
    """
    criteria = export.filter_criteria
    min_images = criteria.min_images_per_species

    per_species = (
        _apply_image_filters(
            db.query(Image.species_id, func.count(Image.id).label("count")),
            criteria,
        )
        .group_by(Image.species_id)
        .all()
    )

    valid_species = [row for row in per_species if row.count >= min_images]
    total_images = sum(row.count for row in valid_species)

    # Estimate file size (rough: fixed KB per image).
    estimated_size_mb = (total_images * _ESTIMATED_KB_PER_IMAGE) / 1024

    return ExportPreview(
        species_count=len(valid_species),
        image_count=total_images,
        estimated_size_mb=estimated_size_mb,
    )


@router.post("", response_model=ExportResponse)
def create_export(export: ExportCreate, db: Session = Depends(get_db)):
    """Create an Export row in "pending" state and enqueue its Celery job.

    The row is committed first so it has an id to hand to the task. The
    returned Celery task id is then stored on the row.
    """
    db_export = Export(
        name=export.name,
        filter_criteria=export.filter_criteria.model_dump_json(),
        train_split=export.train_split,
        status="pending",
    )
    db.add(db_export)
    db.commit()
    db.refresh(db_export)

    # NOTE(review): if enqueueing fails here the row stays "pending" with no
    # task id; consider marking it failed — confirm the valid status values.
    task = generate_export.delay(db_export.id)
    db_export.celery_task_id = task.id
    db.commit()

    return ExportResponse.model_validate(db_export)


@router.get("/{export_id}", response_model=ExportResponse)
def get_export(export_id: int, db: Session = Depends(get_db)):
    """Return a single export (404 if it does not exist)."""
    export = _get_export_or_404(db, export_id)
    return ExportResponse.model_validate(export)


@router.get("/{export_id}/progress")
def get_export_progress(export_id: int, db: Session = Depends(get_db)):
    """Report export progress, using live Celery task metadata when available.

    While the task is in the "PROGRESS" state and its metadata is a dict,
    returns ``current``/``total``/``current_species`` from it. Otherwise
    returns just the status stored on the Export row.
    """
    # Imported lazily, presumably to avoid an import cycle with the worker app.
    from app.workers.celery_app import celery_app

    export = _get_export_or_404(db, export_id)

    if not export.celery_task_id:
        return {"status": export.status}

    result = celery_app.AsyncResult(export.celery_task_id)
    if result.state == "PROGRESS":
        meta = result.info
        # Guard against missing or non-dict metadata rather than crashing on .get().
        if isinstance(meta, dict):
            return {
                "status": "generating",
                "current": meta.get("current", 0),
                "total": meta.get("total", 0),
                "current_species": meta.get("species", ""),
            }

    return {"status": export.status}


@router.get("/{export_id}/download")
def download_export(export_id: int, db: Session = Depends(get_db)):
    """Stream a completed export's zip file.

    Raises:
        HTTPException: 404 if the export or its file is missing, 400 if the
            export is not yet "completed".
    """
    export = _get_export_or_404(db, export_id)

    if export.status != "completed":
        raise HTTPException(status_code=400, detail="Export not ready")

    if not export.file_path or not os.path.exists(export.file_path):
        raise HTTPException(status_code=404, detail="Export file not found")

    return FileResponse(
        export.file_path,
        media_type="application/zip",
        filename=f"{export.name}.zip",
    )


@router.delete("/{export_id}")
def delete_export(export_id: int, db: Session = Depends(get_db)):
    """Delete an export row and, if present, its file on disk."""
    export = _get_export_or_404(db, export_id)

    if export.file_path:
        # Remove directly instead of exists()+remove(): the file may already be
        # gone or vanish concurrently, which should not fail the delete.
        with contextlib.suppress(FileNotFoundError):
            os.remove(export.file_path)

    db.delete(export)
    db.commit()

    return {"status": "deleted"}