171 lines
5.9 KiB
Python
171 lines
5.9 KiB
Python
import json
|
|
import os
|
|
import random
|
|
import shutil
|
|
import zipfile
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
from app.workers.celery_app import celery_app
|
|
from app.database import SessionLocal
|
|
from app.models import Export, Image, Species
|
|
from app.config import get_settings
|
|
|
|
# Application settings, resolved once when this module is imported and then
# shared by every task invocation running in this worker process. The export
# task reads `settings.exports_path` to decide where work directories and
# zip archives are written.
settings = get_settings()
|
|
|
|
|
|
def _build_image_query(db, criteria):
    """Build the query of exportable images for one export.

    Starts from images whose ``status`` is ``"downloaded"`` and narrows it by
    the optional ``licenses``, ``min_quality`` and ``species_ids`` entries of
    *criteria*. Falsy values (missing, ``None``, empty list, ``0``) apply no
    restriction.
    """
    query = db.query(Image).filter(Image.status == "downloaded")

    licenses = criteria.get("licenses")
    if licenses:
        query = query.filter(Image.license.in_(licenses))

    min_quality = criteria.get("min_quality")
    if min_quality:
        query = query.filter(Image.quality_score >= min_quality)

    species_ids = criteria.get("species_ids")
    if species_ids:
        query = query.filter(Image.species_id.in_(species_ids))

    return query


def _eligible_species_ids(query, min_images):
    """Return the species ids having at least *min_images* rows in *query*.

    The count is grouped over the already-filtered *query*, so license,
    quality and species filters are honoured. Order is whatever the database
    returns (no ORDER BY).
    """
    from sqlalchemy import func

    counts = (
        query.with_entities(Image.species_id, func.count(Image.id).label("count"))
        .group_by(Image.species_id)
        .all()
    )
    return [row.species_id for row in counts if row.count >= min_images]


def _copy_images(images, dest_dir):
    """Copy images that exist on disk into *dest_dir* as ``img_NNNNN<ext>``.

    Images with an empty ``local_path`` or a path missing on disk are skipped
    but still consume their index, so file numbers follow list position.
    The extension defaults to ``.jpg`` when the source path has none.

    Returns:
        The number of files actually copied.
    """
    copied = 0
    for index, img in enumerate(images):
        if not img.local_path or not os.path.exists(img.local_path):
            continue
        ext = Path(img.local_path).suffix or ".jpg"
        shutil.copy2(img.local_path, dest_dir / f"img_{index:05d}{ext}")
        copied += 1
    return copied


def _zip_directory(src_dir, zip_path):
    """Zip every file under *src_dir* into *zip_path* (paths relative to *src_dir*).

    The archive is written to a sibling ``.part`` file and moved into place with
    ``os.replace``, so a failure never leaves a truncated archive at
    *zip_path*. Empty directories are not stored (only files are walked).
    """
    tmp_path = zip_path.with_name(zip_path.name + ".part")
    try:
        with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for root, _dirs, files in os.walk(src_dir):
                for name in files:
                    file_path = Path(root) / name
                    zipf.write(file_path, file_path.relative_to(src_dir))
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
    os.replace(tmp_path, zip_path)


@celery_app.task(bind=True)
def generate_export(self, export_id: int):
    """Generate a zip export for CoreML training.

    Loads the ``Export`` row *export_id* and selects downloaded images matching
    its JSON ``filter_criteria``. It copies them into per-species ``Training/``
    and ``Testing/`` folders, split by ``export.train_split``. The result is
    zipped to ``<exports_path>/export_<id>.zip`` and the outcome is recorded on
    the row.

    Args:
        export_id: Primary key of the ``Export`` row to build.

    Returns:
        On success, a dict with ``status``, ``species_count``, ``image_count``
        and ``file_size``. ``{"error": ...}`` when the export row does not
        exist or no species meets the criteria.

    Raises:
        Re-raises any exception after marking the export ``"failed"`` and
        removing the partially built working directory.
    """
    db = SessionLocal()
    # Bound before ``try`` so the error handler never touches an unbound name
    # (e.g. when the very first query fails).
    export = None
    export_dir = None
    try:
        export = db.query(Export).filter(Export.id == export_id).first()
        if not export:
            return {"error": "Export not found"}

        export.status = "generating"
        export.celery_task_id = self.request.id
        db.commit()

        criteria = json.loads(export.filter_criteria) if export.filter_criteria else {}
        min_images = criteria.get("min_images_per_species", 100)

        query = _build_image_query(db, criteria)
        # Count against the *filtered* query so the eligibility gate matches
        # the images that will actually be exported.
        valid_species_ids = _eligible_species_ids(query, min_images)

        if not valid_species_ids:
            export.status = "failed"
            export.error_message = "No species meet the criteria"
            export.completed_at = datetime.utcnow()
            db.commit()
            return {"error": "No species meet the criteria"}

        exports_root = Path(settings.exports_path)
        export_dir = exports_root / f"export_{export_id}"
        # A previous failed run may have left files here; they would otherwise
        # be swept into this archive.
        shutil.rmtree(export_dir, ignore_errors=True)
        train_dir = export_dir / "Training"
        test_dir = export_dir / "Testing"
        train_dir.mkdir(parents=True, exist_ok=True)
        test_dir.mkdir(parents=True, exist_ok=True)

        total_images = 0
        species_count = 0

        for i, species_id in enumerate(valid_species_ids):
            species = db.query(Species).filter(Species.id == species_id).first()
            if not species:
                continue

            images = query.filter(Image.species_id == species_id).all()
            # Defensive re-check: the count above and this fetch are separate queries.
            if len(images) < min_images:
                continue

            species_count += 1

            # NOTE(review): scientific_name is used verbatim as a folder name;
            # a "/" in it would make mkdir fail — confirm names are clean upstream.
            species_name = species.scientific_name.replace(" ", "_")
            train_species_dir = train_dir / species_name
            test_species_dir = test_dir / species_name
            train_species_dir.mkdir(exist_ok=True)
            test_species_dir.mkdir(exist_ok=True)

            random.shuffle(images)
            split_idx = int(len(images) * export.train_split)
            total_images += _copy_images(images[:split_idx], train_species_dir)
            total_images += _copy_images(images[split_idx:], test_species_dir)

            self.update_state(
                state="PROGRESS",
                meta={
                    "current": i + 1,
                    "total": len(valid_species_ids),
                    "species": species.scientific_name,
                },
            )

        zip_path = exports_root / f"export_{export_id}.zip"
        _zip_directory(export_dir, zip_path)
        shutil.rmtree(export_dir)

        export.status = "completed"
        export.file_path = str(zip_path)
        export.file_size = zip_path.stat().st_size
        export.species_count = species_count
        export.image_count = total_images
        export.completed_at = datetime.utcnow()
        db.commit()

        return {
            "status": "completed",
            "species_count": species_count,
            "image_count": total_images,
            "file_size": export.file_size,
        }

    except Exception as e:
        if export_dir is not None:
            shutil.rmtree(export_dir, ignore_errors=True)
        if export is not None:
            # After a database error the session's transaction is unusable;
            # roll back first so the failure status can still be committed.
            db.rollback()
            export.status = "failed"
            export.error_message = str(e)
            export.completed_at = datetime.utcnow()
            db.commit()
        raise
    finally:
        db.close()
|