Initial commit — PlantGuideScraper project
This commit is contained in:
170
backend/app/workers/export_tasks.py
Normal file
170
backend/app/workers/export_tasks.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from app.workers.celery_app import celery_app
|
||||
from app.database import SessionLocal
|
||||
from app.models import Export, Image, Species
|
||||
from app.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@celery_app.task(bind=True)
def generate_export(self, export_id: int):
    """Generate a zip export for CoreML training.

    Loads the ``Export`` row identified by *export_id*, selects downloaded
    images matching the export's JSON ``filter_criteria`` (``licenses``,
    ``min_quality``, ``species_ids``, ``min_images_per_species``), copies
    them into ``Training/<species>/`` and ``Testing/<species>/`` folders
    split by ``export.train_split``, zips the result under
    ``settings.exports_path`` and records the outcome on the ``Export`` row.

    Args:
        export_id: Primary key of the ``Export`` row to process.

    Returns:
        A dict with ``status``, ``species_count``, ``image_count`` and
        ``file_size`` on success, or a dict with an ``error`` key when the
        export row is missing or no species meets the criteria.

    Raises:
        Exception: Any error raised while generating the export is re-raised
            after the export row has been marked as ``"failed"``.
    """
    db = SessionLocal()
    # Bound before the try block so the except handler can always test them,
    # even when the very first query raises.
    export = None
    export_dir = None
    try:
        export = db.query(Export).filter(Export.id == export_id).first()
        if not export:
            return {"error": "Export not found"}

        # Update status
        export.status = "generating"
        export.celery_task_id = self.request.id
        db.commit()

        # Parse filter criteria
        criteria = json.loads(export.filter_criteria) if export.filter_criteria else {}
        min_images = criteria.get("min_images_per_species", 100)
        licenses = criteria.get("licenses")
        min_quality = criteria.get("min_quality")
        species_ids = criteria.get("species_ids")

        # Build query for images
        query = db.query(Image).filter(Image.status == "downloaded")

        if licenses:
            query = query.filter(Image.license.in_(licenses))

        if min_quality:
            query = query.filter(Image.quality_score >= min_quality)

        if species_ids:
            query = query.filter(Image.species_id.in_(species_ids))

        # Group by species and filter by min count. This is only a coarse
        # pre-filter on unfiltered counts; the per-species check in the loop
        # below enforces the minimum on the filtered images.
        from sqlalchemy import func

        species_counts = db.query(
            Image.species_id,
            func.count(Image.id).label("count"),
        ).filter(Image.status == "downloaded").group_by(Image.species_id).all()

        valid_species_ids = [
            row.species_id for row in species_counts if row.count >= min_images
        ]

        if species_ids:
            requested = set(species_ids)
            valid_species_ids = [s for s in valid_species_ids if s in requested]

        if not valid_species_ids:
            export.status = "failed"
            export.error_message = "No species meet the criteria"
            export.completed_at = datetime.utcnow()
            db.commit()
            return {"error": "No species meet the criteria"}

        # Create export directory. Remove leftovers from an earlier attempt
        # first so stale files never end up in the new archive.
        export_dir = Path(settings.exports_path) / f"export_{export_id}"
        if export_dir.exists():
            shutil.rmtree(export_dir)
        train_dir = export_dir / "Training"
        test_dir = export_dir / "Testing"
        train_dir.mkdir(parents=True, exist_ok=True)
        test_dir.mkdir(parents=True, exist_ok=True)

        total_images = 0
        species_count = 0

        # Process each valid species
        for i, species_id in enumerate(valid_species_ids):
            species = db.query(Species).filter(Species.id == species_id).first()
            if not species:
                continue

            # `query` already carries the license / quality filters, so only
            # the species restriction is added here.
            images = query.filter(Image.species_id == species_id).all()
            if len(images) < min_images:
                continue

            species_count += 1

            # Create species folders. Path separators are replaced as well so
            # an unusual name cannot escape or break the folder layout.
            species_name = (
                species.scientific_name
                .replace(" ", "_")
                .replace("/", "_")
                .replace("\\", "_")
            )
            species_train_dir = train_dir / species_name
            species_test_dir = test_dir / species_name
            species_train_dir.mkdir(exist_ok=True)
            species_test_dir.mkdir(exist_ok=True)

            # Shuffle and split
            random.shuffle(images)
            split_idx = int(len(images) * export.train_split)

            # Copy images
            total_images += _copy_images(images[:split_idx], species_train_dir)
            total_images += _copy_images(images[split_idx:], species_test_dir)

            # Update progress
            self.update_state(
                state="PROGRESS",
                meta={
                    "current": i + 1,
                    "total": len(valid_species_ids),
                    "species": species.scientific_name,
                },
            )

        # Create zip file
        zip_path = Path(settings.exports_path) / f"export_{export_id}.zip"
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for root, _dirs, files in os.walk(export_dir):
                for file in files:
                    file_path = Path(root) / file
                    zipf.write(file_path, file_path.relative_to(export_dir))

        # Clean up directory
        shutil.rmtree(export_dir)

        # Update export record
        export.status = "completed"
        export.file_path = str(zip_path)
        export.file_size = zip_path.stat().st_size
        export.species_count = species_count
        export.image_count = total_images
        export.completed_at = datetime.utcnow()
        db.commit()

        return {
            "status": "completed",
            "species_count": species_count,
            "image_count": total_images,
            "file_size": export.file_size,
        }

    except Exception as e:
        # A failed flush/commit leaves the session unusable until it is
        # rolled back; do that before recording the failure.
        db.rollback()
        if export_dir is not None:
            # Do not leave a half-built staging directory behind.
            shutil.rmtree(export_dir, ignore_errors=True)
        if export:
            export.status = "failed"
            export.error_message = str(e)
            export.completed_at = datetime.utcnow()
            db.commit()
        raise
    finally:
        db.close()


def _copy_images(images, dest_dir):
    """Copy every image whose local file exists into *dest_dir*.

    Files are named ``img_<index>`` using the image's position in *images*
    and keep the source suffix (``.jpg`` when the source has none). Images
    without an existing local file are skipped.

    Returns:
        The number of files actually copied.
    """
    copied = 0
    for index, img in enumerate(images):
        if img.local_path and os.path.exists(img.local_path):
            ext = Path(img.local_path).suffix or ".jpg"
            shutil.copy2(img.local_path, dest_dir / f"img_{index:05d}{ext}")
            copied += 1
    return copied
|
||||
Reference in New Issue
Block a user