import json
import os
import random
import shutil
import zipfile
from datetime import datetime
from pathlib import Path

from sqlalchemy import func

from app.workers.celery_app import celery_app
from app.database import SessionLocal
from app.models import Export, Image, Species
from app.config import get_settings

settings = get_settings()


def _apply_image_filters(query, licenses, min_quality, species_ids):
    """Return *query* narrowed by the optional export criteria.

    Each criterion is applied only when it is truthy, so ``None``, an empty
    list, or a ``min_quality`` of ``0`` leave the query unrestricted on that
    dimension.
    """
    if licenses:
        query = query.filter(Image.license.in_(licenses))
    if min_quality:
        query = query.filter(Image.quality_score >= min_quality)
    if species_ids:
        query = query.filter(Image.species_id.in_(species_ids))
    return query


def _copy_images(images, dest_dir):
    """Copy every image whose local file exists into *dest_dir*.

    Files are named ``img_<index><ext>`` where ``index`` is the image's
    position in *images* (so skipped images leave gaps in the numbering) and
    ``ext`` is the source suffix, falling back to ``.jpg``.

    Returns the number of files actually copied.
    """
    copied = 0
    for index, img in enumerate(images):
        if img.local_path and os.path.exists(img.local_path):
            ext = Path(img.local_path).suffix or ".jpg"
            shutil.copy2(img.local_path, dest_dir / f"img_{index:05d}{ext}")
            copied += 1
    return copied


def _remove_partial_output(export_dir, zip_path):
    """Best-effort removal of the staging directory and zip after a failure."""
    if export_dir is not None:
        shutil.rmtree(export_dir, ignore_errors=True)
    if zip_path is not None:
        try:
            zip_path.unlink()
        except OSError:
            # Nothing to remove, or not removable; the original error matters more.
            pass


@celery_app.task(bind=True)
def generate_export(self, export_id: int):
    """Generate a zip export for CoreML training.

    Loads the ``Export`` row with id *export_id*, selects downloaded images
    matching the row's JSON ``filter_criteria`` (``licenses``, ``min_quality``,
    ``species_ids``, ``min_images_per_species`` defaulting to 100), and for
    every species with at least ``min_images_per_species`` matching images
    shuffles them, splits them by ``export.train_split`` and copies them into
    ``Training/<species>`` and ``Testing/<species>`` folders. The folders are
    zipped to ``<exports_path>/export_<id>.zip`` and then removed.

    Progress is reported through ``self.update_state`` with state
    ``"PROGRESS"`` after each exported species.

    Returns:
        ``{"error": ...}`` when the export row does not exist or no species
        meets the criteria; otherwise a dict with ``status``,
        ``species_count``, ``image_count`` and ``file_size``.

    Raises:
        Any exception raised while generating the export is re-raised after
        the export row has been marked ``"failed"`` and partial output has
        been removed.
    """
    db = SessionLocal()
    # Bound before the ``try`` so the ``except`` handler can always inspect them.
    export = None
    export_dir = None
    zip_path = None
    try:
        export = db.query(Export).filter(Export.id == export_id).first()
        if not export:
            return {"error": "Export not found"}

        # Update status
        export.status = "generating"
        export.celery_task_id = self.request.id
        db.commit()

        # Parse filter criteria
        criteria = json.loads(export.filter_criteria) if export.filter_criteria else {}
        min_images = criteria.get("min_images_per_species", 100)
        licenses = criteria.get("licenses")
        min_quality = criteria.get("min_quality")
        species_ids = criteria.get("species_ids")

        # Build query for images
        query = _apply_image_filters(
            db.query(Image).filter(Image.status == "downloaded"),
            licenses,
            min_quality,
            species_ids,
        )

        # Count images per species using the SAME filters as the image query,
        # so a species only qualifies on images that will actually be exported.
        count_query = _apply_image_filters(
            db.query(Image.species_id, func.count(Image.id)).filter(
                Image.status == "downloaded"
            ),
            licenses,
            min_quality,
            species_ids,
        )
        # Rows are unpacked positionally: attribute access by a label named
        # "count" can collide with the tuple method of the same name.
        valid_species_ids = [
            species_id
            for species_id, image_count in count_query.group_by(Image.species_id).all()
            if image_count >= min_images
        ]

        if not valid_species_ids:
            export.status = "failed"
            export.error_message = "No species meet the criteria"
            export.completed_at = datetime.utcnow()
            db.commit()
            return {"error": "No species meet the criteria"}

        # Create export directory, discarding leftovers from an earlier attempt
        # so stale files cannot end up in the new archive.
        export_dir = Path(settings.exports_path) / f"export_{export_id}"
        if export_dir.exists():
            shutil.rmtree(export_dir)
        train_dir = export_dir / "Training"
        test_dir = export_dir / "Testing"
        train_dir.mkdir(parents=True, exist_ok=True)
        test_dir.mkdir(parents=True, exist_ok=True)

        total_images = 0
        species_count = 0

        # Process each valid species
        for i, species_id in enumerate(valid_species_ids):
            species = db.query(Species).filter(Species.id == species_id).first()
            if not species:
                continue

            # ``query`` already carries the license/quality filters.
            images = query.filter(Image.species_id == species_id).all()

            # Guard against rows changing between the count and this fetch.
            if len(images) < min_images:
                continue

            species_count += 1

            # Create species folders
            species_name = species.scientific_name.replace(" ", "_")
            species_train_dir = train_dir / species_name
            species_test_dir = test_dir / species_name
            species_train_dir.mkdir(exist_ok=True)
            species_test_dir.mkdir(exist_ok=True)

            # Shuffle and split
            random.shuffle(images)
            split_idx = int(len(images) * export.train_split)

            # Copy images
            total_images += _copy_images(images[:split_idx], species_train_dir)
            total_images += _copy_images(images[split_idx:], species_test_dir)

            # Update progress
            self.update_state(
                state="PROGRESS",
                meta={
                    "current": i + 1,
                    "total": len(valid_species_ids),
                    "species": species.scientific_name,
                },
            )

        # Create zip file
        zip_path = Path(settings.exports_path) / f"export_{export_id}.zip"
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for root, _dirs, files in os.walk(export_dir):
                for file in files:
                    file_path = Path(root) / file
                    arcname = file_path.relative_to(export_dir)
                    zipf.write(file_path, arcname)

        # Clean up directory
        shutil.rmtree(export_dir)

        # Update export record
        export.status = "completed"
        export.file_path = str(zip_path)
        export.file_size = zip_path.stat().st_size
        export.species_count = species_count
        export.image_count = total_images
        export.completed_at = datetime.utcnow()
        db.commit()

        return {
            "status": "completed",
            "species_count": species_count,
            "image_count": total_images,
            "file_size": export.file_size,
        }

    except Exception as e:
        # Drop half-written output so a retry starts from a clean slate.
        _remove_partial_output(export_dir, zip_path)
        # A failed flush/commit leaves the session unusable until rolled back.
        db.rollback()
        if export is not None:
            export.status = "failed"
            export.error_message = str(e)
            export.completed_at = datetime.utcnow()
            db.commit()
        raise
    finally:
        db.close()