Initial commit — PlantGuideScraper project
This commit is contained in:
1
backend/app/api/__init__.py
Normal file
1
backend/app/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# API routes
|
||||
175
backend/app/api/exports.py
Normal file
175
backend/app/api/exports.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Export, Image, Species
|
||||
from app.schemas.export import (
|
||||
ExportCreate,
|
||||
ExportResponse,
|
||||
ExportListResponse,
|
||||
ExportPreview,
|
||||
)
|
||||
from app.workers.export_tasks import generate_export
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=ExportListResponse)
|
||||
def list_exports(
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""List all exports."""
|
||||
total = db.query(Export).count()
|
||||
exports = db.query(Export).order_by(Export.created_at.desc()).limit(limit).all()
|
||||
|
||||
return ExportListResponse(
|
||||
items=[ExportResponse.model_validate(e) for e in exports],
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/preview", response_model=ExportPreview)
|
||||
def preview_export(export: ExportCreate, db: Session = Depends(get_db)):
|
||||
"""Preview export without creating it."""
|
||||
criteria = export.filter_criteria
|
||||
min_images = criteria.min_images_per_species
|
||||
|
||||
# Build query
|
||||
query = db.query(Image).filter(Image.status == "downloaded")
|
||||
|
||||
if criteria.licenses:
|
||||
query = query.filter(Image.license.in_(criteria.licenses))
|
||||
|
||||
if criteria.min_quality:
|
||||
query = query.filter(Image.quality_score >= criteria.min_quality)
|
||||
|
||||
if criteria.species_ids:
|
||||
query = query.filter(Image.species_id.in_(criteria.species_ids))
|
||||
|
||||
# Count images per species
|
||||
species_counts = db.query(
|
||||
Image.species_id,
|
||||
func.count(Image.id).label("count")
|
||||
).filter(Image.status == "downloaded")
|
||||
|
||||
if criteria.licenses:
|
||||
species_counts = species_counts.filter(Image.license.in_(criteria.licenses))
|
||||
if criteria.min_quality:
|
||||
species_counts = species_counts.filter(Image.quality_score >= criteria.min_quality)
|
||||
if criteria.species_ids:
|
||||
species_counts = species_counts.filter(Image.species_id.in_(criteria.species_ids))
|
||||
|
||||
species_counts = species_counts.group_by(Image.species_id).all()
|
||||
|
||||
valid_species = [s for s in species_counts if s.count >= min_images]
|
||||
total_images = sum(s.count for s in valid_species)
|
||||
|
||||
# Estimate file size (rough: 50KB per image)
|
||||
estimated_size_mb = (total_images * 50) / 1024
|
||||
|
||||
return ExportPreview(
|
||||
species_count=len(valid_species),
|
||||
image_count=total_images,
|
||||
estimated_size_mb=estimated_size_mb,
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=ExportResponse)
|
||||
def create_export(export: ExportCreate, db: Session = Depends(get_db)):
|
||||
"""Create and start a new export job."""
|
||||
db_export = Export(
|
||||
name=export.name,
|
||||
filter_criteria=export.filter_criteria.model_dump_json(),
|
||||
train_split=export.train_split,
|
||||
status="pending",
|
||||
)
|
||||
db.add(db_export)
|
||||
db.commit()
|
||||
db.refresh(db_export)
|
||||
|
||||
# Start Celery task
|
||||
task = generate_export.delay(db_export.id)
|
||||
db_export.celery_task_id = task.id
|
||||
db.commit()
|
||||
|
||||
return ExportResponse.model_validate(db_export)
|
||||
|
||||
|
||||
@router.get("/{export_id}", response_model=ExportResponse)
|
||||
def get_export(export_id: int, db: Session = Depends(get_db)):
|
||||
"""Get export status."""
|
||||
export = db.query(Export).filter(Export.id == export_id).first()
|
||||
if not export:
|
||||
raise HTTPException(status_code=404, detail="Export not found")
|
||||
|
||||
return ExportResponse.model_validate(export)
|
||||
|
||||
|
||||
@router.get("/{export_id}/progress")
|
||||
def get_export_progress(export_id: int, db: Session = Depends(get_db)):
|
||||
"""Get real-time export progress."""
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
export = db.query(Export).filter(Export.id == export_id).first()
|
||||
if not export:
|
||||
raise HTTPException(status_code=404, detail="Export not found")
|
||||
|
||||
if not export.celery_task_id:
|
||||
return {"status": export.status}
|
||||
|
||||
result = celery_app.AsyncResult(export.celery_task_id)
|
||||
|
||||
if result.state == "PROGRESS":
|
||||
meta = result.info
|
||||
return {
|
||||
"status": "generating",
|
||||
"current": meta.get("current", 0),
|
||||
"total": meta.get("total", 0),
|
||||
"current_species": meta.get("species", ""),
|
||||
}
|
||||
|
||||
return {"status": export.status}
|
||||
|
||||
|
||||
@router.get("/{export_id}/download")
|
||||
def download_export(export_id: int, db: Session = Depends(get_db)):
|
||||
"""Download export zip file."""
|
||||
export = db.query(Export).filter(Export.id == export_id).first()
|
||||
if not export:
|
||||
raise HTTPException(status_code=404, detail="Export not found")
|
||||
|
||||
if export.status != "completed":
|
||||
raise HTTPException(status_code=400, detail="Export not ready")
|
||||
|
||||
if not export.file_path or not os.path.exists(export.file_path):
|
||||
raise HTTPException(status_code=404, detail="Export file not found")
|
||||
|
||||
return FileResponse(
|
||||
export.file_path,
|
||||
media_type="application/zip",
|
||||
filename=f"{export.name}.zip",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{export_id}")
|
||||
def delete_export(export_id: int, db: Session = Depends(get_db)):
|
||||
"""Delete an export and its file."""
|
||||
export = db.query(Export).filter(Export.id == export_id).first()
|
||||
if not export:
|
||||
raise HTTPException(status_code=404, detail="Export not found")
|
||||
|
||||
# Delete file if exists
|
||||
if export.file_path and os.path.exists(export.file_path):
|
||||
os.remove(export.file_path)
|
||||
|
||||
db.delete(export)
|
||||
db.commit()
|
||||
|
||||
return {"status": "deleted"}
|
||||
441
backend/app/api/images.py
Normal file
441
backend/app/api/images.py
Normal file
@@ -0,0 +1,441 @@
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Image, Species
|
||||
from app.schemas.image import ImageResponse, ImageListResponse
|
||||
from app.config import get_settings
|
||||
|
||||
router = APIRouter()
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@router.get("", response_model=ImageListResponse)
|
||||
def list_images(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
species_id: Optional[int] = None,
|
||||
source: Optional[str] = None,
|
||||
license: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
min_quality: Optional[float] = None,
|
||||
search: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""List images with pagination and filters."""
|
||||
# Use joinedload to fetch species in single query
|
||||
from sqlalchemy.orm import joinedload
|
||||
query = db.query(Image).options(joinedload(Image.species))
|
||||
|
||||
if species_id:
|
||||
query = query.filter(Image.species_id == species_id)
|
||||
|
||||
if source:
|
||||
query = query.filter(Image.source == source)
|
||||
|
||||
if license:
|
||||
query = query.filter(Image.license == license)
|
||||
|
||||
if status:
|
||||
query = query.filter(Image.status == status)
|
||||
|
||||
if min_quality:
|
||||
query = query.filter(Image.quality_score >= min_quality)
|
||||
|
||||
if search:
|
||||
search_term = f"%{search}%"
|
||||
query = query.join(Species).filter(
|
||||
(Species.scientific_name.ilike(search_term)) |
|
||||
(Species.common_name.ilike(search_term))
|
||||
)
|
||||
|
||||
# Use faster count for simple queries
|
||||
if not search:
|
||||
# Build count query without join for better performance
|
||||
count_query = db.query(func.count(Image.id))
|
||||
if species_id:
|
||||
count_query = count_query.filter(Image.species_id == species_id)
|
||||
if source:
|
||||
count_query = count_query.filter(Image.source == source)
|
||||
if license:
|
||||
count_query = count_query.filter(Image.license == license)
|
||||
if status:
|
||||
count_query = count_query.filter(Image.status == status)
|
||||
if min_quality:
|
||||
count_query = count_query.filter(Image.quality_score >= min_quality)
|
||||
total = count_query.scalar()
|
||||
else:
|
||||
total = query.count()
|
||||
|
||||
pages = (total + page_size - 1) // page_size
|
||||
|
||||
images = query.order_by(Image.created_at.desc()).offset(
|
||||
(page - 1) * page_size
|
||||
).limit(page_size).all()
|
||||
|
||||
items = [
|
||||
ImageResponse(
|
||||
id=img.id,
|
||||
species_id=img.species_id,
|
||||
species_name=img.species.scientific_name if img.species else None,
|
||||
source=img.source,
|
||||
source_id=img.source_id,
|
||||
url=img.url,
|
||||
local_path=img.local_path,
|
||||
license=img.license,
|
||||
attribution=img.attribution,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
quality_score=img.quality_score,
|
||||
status=img.status,
|
||||
created_at=img.created_at,
|
||||
)
|
||||
for img in images
|
||||
]
|
||||
|
||||
return ImageListResponse(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
pages=pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sources")
|
||||
def list_sources(db: Session = Depends(get_db)):
|
||||
"""List all unique image sources."""
|
||||
sources = db.query(Image.source).distinct().all()
|
||||
return [s[0] for s in sources]
|
||||
|
||||
|
||||
@router.get("/licenses")
|
||||
def list_licenses(db: Session = Depends(get_db)):
|
||||
"""List all unique licenses."""
|
||||
licenses = db.query(Image.license).distinct().all()
|
||||
return [l[0] for l in licenses]
|
||||
|
||||
|
||||
@router.post("/process-pending")
|
||||
def process_pending_images(
|
||||
source: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Queue all pending images for download and processing."""
|
||||
from app.workers.quality_tasks import batch_process_pending_images
|
||||
|
||||
query = db.query(func.count(Image.id)).filter(Image.status == "pending")
|
||||
if source:
|
||||
query = query.filter(Image.source == source)
|
||||
pending_count = query.scalar()
|
||||
|
||||
task = batch_process_pending_images.delay(source=source)
|
||||
|
||||
return {
|
||||
"pending_count": pending_count,
|
||||
"task_id": task.id,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/process-pending/status/{task_id}")
|
||||
def process_pending_status(task_id: str):
|
||||
"""Check status of a batch processing task."""
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
result = celery_app.AsyncResult(task_id)
|
||||
state = result.state # PENDING, STARTED, PROGRESS, SUCCESS, FAILURE
|
||||
|
||||
response = {"task_id": task_id, "state": state}
|
||||
|
||||
if state == "PROGRESS" and isinstance(result.info, dict):
|
||||
response["queued"] = result.info.get("queued", 0)
|
||||
response["total"] = result.info.get("total", 0)
|
||||
elif state == "SUCCESS" and isinstance(result.result, dict):
|
||||
response["queued"] = result.result.get("queued", 0)
|
||||
response["total"] = result.result.get("total", 0)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/{image_id}", response_model=ImageResponse)
|
||||
def get_image(image_id: int, db: Session = Depends(get_db)):
|
||||
"""Get an image by ID."""
|
||||
image = db.query(Image).filter(Image.id == image_id).first()
|
||||
if not image:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
return ImageResponse(
|
||||
id=image.id,
|
||||
species_id=image.species_id,
|
||||
species_name=image.species.scientific_name if image.species else None,
|
||||
source=image.source,
|
||||
source_id=image.source_id,
|
||||
url=image.url,
|
||||
local_path=image.local_path,
|
||||
license=image.license,
|
||||
attribution=image.attribution,
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
quality_score=image.quality_score,
|
||||
status=image.status,
|
||||
created_at=image.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{image_id}/file")
|
||||
def get_image_file(image_id: int, db: Session = Depends(get_db)):
|
||||
"""Get the actual image file."""
|
||||
image = db.query(Image).filter(Image.id == image_id).first()
|
||||
if not image:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
if not image.local_path:
|
||||
raise HTTPException(status_code=404, detail="Image file not available")
|
||||
|
||||
return FileResponse(image.local_path, media_type="image/jpeg")
|
||||
|
||||
|
||||
@router.delete("/{image_id}")
|
||||
def delete_image(image_id: int, db: Session = Depends(get_db)):
|
||||
"""Delete an image."""
|
||||
image = db.query(Image).filter(Image.id == image_id).first()
|
||||
if not image:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
# Delete file if exists
|
||||
if image.local_path:
|
||||
import os
|
||||
if os.path.exists(image.local_path):
|
||||
os.remove(image.local_path)
|
||||
|
||||
db.delete(image)
|
||||
db.commit()
|
||||
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.post("/bulk-delete")
|
||||
def bulk_delete_images(
|
||||
image_ids: List[int],
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete multiple images."""
|
||||
import os
|
||||
|
||||
images = db.query(Image).filter(Image.id.in_(image_ids)).all()
|
||||
|
||||
deleted = 0
|
||||
for image in images:
|
||||
if image.local_path and os.path.exists(image.local_path):
|
||||
os.remove(image.local_path)
|
||||
db.delete(image)
|
||||
deleted += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
return {"deleted": deleted}
|
||||
|
||||
|
||||
@router.get("/import/scan")
|
||||
def scan_imports(db: Session = Depends(get_db)):
|
||||
"""Scan the imports folder and return what can be imported.
|
||||
|
||||
Expected structure: imports/{source}/{species_name}/*.jpg
|
||||
"""
|
||||
imports_path = Path(settings.imports_path)
|
||||
|
||||
if not imports_path.exists():
|
||||
return {
|
||||
"available": False,
|
||||
"message": f"Imports folder not found: {imports_path}",
|
||||
"sources": [],
|
||||
"total_images": 0,
|
||||
"matched_species": 0,
|
||||
"unmatched_species": [],
|
||||
}
|
||||
|
||||
results = {
|
||||
"available": True,
|
||||
"sources": [],
|
||||
"total_images": 0,
|
||||
"matched_species": 0,
|
||||
"unmatched_species": [],
|
||||
}
|
||||
|
||||
# Get all species for matching
|
||||
species_map = {}
|
||||
for species in db.query(Species).all():
|
||||
# Map by scientific name with underscores and spaces
|
||||
species_map[species.scientific_name.lower()] = species
|
||||
species_map[species.scientific_name.replace(" ", "_").lower()] = species
|
||||
|
||||
seen_unmatched = set()
|
||||
|
||||
# Scan source folders
|
||||
for source_dir in imports_path.iterdir():
|
||||
if not source_dir.is_dir():
|
||||
continue
|
||||
|
||||
source_name = source_dir.name
|
||||
source_info = {
|
||||
"name": source_name,
|
||||
"species_count": 0,
|
||||
"image_count": 0,
|
||||
}
|
||||
|
||||
# Scan species folders within source
|
||||
for species_dir in source_dir.iterdir():
|
||||
if not species_dir.is_dir():
|
||||
continue
|
||||
|
||||
species_name = species_dir.name.replace("_", " ")
|
||||
species_key = species_name.lower()
|
||||
|
||||
# Count images
|
||||
image_files = list(species_dir.glob("*.jpg")) + \
|
||||
list(species_dir.glob("*.jpeg")) + \
|
||||
list(species_dir.glob("*.png"))
|
||||
|
||||
if not image_files:
|
||||
continue
|
||||
|
||||
source_info["image_count"] += len(image_files)
|
||||
results["total_images"] += len(image_files)
|
||||
|
||||
if species_key in species_map or species_dir.name.lower() in species_map:
|
||||
source_info["species_count"] += 1
|
||||
results["matched_species"] += 1
|
||||
else:
|
||||
if species_name not in seen_unmatched:
|
||||
seen_unmatched.add(species_name)
|
||||
results["unmatched_species"].append(species_name)
|
||||
|
||||
if source_info["image_count"] > 0:
|
||||
results["sources"].append(source_info)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@router.post("/import/run")
|
||||
def run_import(
|
||||
move_files: bool = Query(False, description="Move files instead of copy"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Import images from the imports folder.
|
||||
|
||||
Expected structure: imports/{source}/{species_name}/*.jpg
|
||||
Images are copied/moved to: images/{species_name}/{source}_{filename}
|
||||
"""
|
||||
imports_path = Path(settings.imports_path)
|
||||
images_path = Path(settings.images_path)
|
||||
|
||||
if not imports_path.exists():
|
||||
raise HTTPException(status_code=400, detail="Imports folder not found")
|
||||
|
||||
# Get all species for matching
|
||||
species_map = {}
|
||||
for species in db.query(Species).all():
|
||||
species_map[species.scientific_name.lower()] = species
|
||||
species_map[species.scientific_name.replace(" ", "_").lower()] = species
|
||||
|
||||
imported = 0
|
||||
skipped = 0
|
||||
errors = []
|
||||
|
||||
# Scan source folders
|
||||
for source_dir in imports_path.iterdir():
|
||||
if not source_dir.is_dir():
|
||||
continue
|
||||
|
||||
source_name = source_dir.name
|
||||
|
||||
# Scan species folders within source
|
||||
for species_dir in source_dir.iterdir():
|
||||
if not species_dir.is_dir():
|
||||
continue
|
||||
|
||||
species_name = species_dir.name.replace("_", " ")
|
||||
species_key = species_name.lower()
|
||||
|
||||
# Find matching species
|
||||
species = species_map.get(species_key) or species_map.get(species_dir.name.lower())
|
||||
if not species:
|
||||
continue
|
||||
|
||||
# Create target directory
|
||||
target_dir = images_path / species.scientific_name.replace(" ", "_")
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process images
|
||||
image_files = list(species_dir.glob("*.jpg")) + \
|
||||
list(species_dir.glob("*.jpeg")) + \
|
||||
list(species_dir.glob("*.png"))
|
||||
|
||||
for img_file in image_files:
|
||||
try:
|
||||
# Generate unique filename
|
||||
ext = img_file.suffix.lower()
|
||||
if ext == ".jpeg":
|
||||
ext = ".jpg"
|
||||
new_filename = f"{source_name}_{img_file.stem}_{uuid.uuid4().hex[:8]}{ext}"
|
||||
target_path = target_dir / new_filename
|
||||
|
||||
# Check if already imported (by original filename pattern)
|
||||
existing = db.query(Image).filter(
|
||||
Image.species_id == species.id,
|
||||
Image.source == source_name,
|
||||
Image.source_id == img_file.stem,
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Get image dimensions
|
||||
try:
|
||||
with PILImage.open(img_file) as pil_img:
|
||||
width, height = pil_img.size
|
||||
except Exception:
|
||||
width, height = None, None
|
||||
|
||||
# Copy or move file
|
||||
if move_files:
|
||||
shutil.move(str(img_file), str(target_path))
|
||||
else:
|
||||
shutil.copy2(str(img_file), str(target_path))
|
||||
|
||||
# Create database record
|
||||
image = Image(
|
||||
species_id=species.id,
|
||||
source=source_name,
|
||||
source_id=img_file.stem,
|
||||
url=f"file://{img_file}",
|
||||
local_path=str(target_path),
|
||||
license="unknown",
|
||||
width=width,
|
||||
height=height,
|
||||
status="downloaded",
|
||||
)
|
||||
db.add(image)
|
||||
imported += 1
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"{img_file}: {str(e)}")
|
||||
|
||||
# Commit after each species to avoid large transactions
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"imported": imported,
|
||||
"skipped": skipped,
|
||||
"errors": errors[:20],
|
||||
}
|
||||
173
backend/app/api/jobs.py
Normal file
173
backend/app/api/jobs.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Job
|
||||
from app.schemas.job import JobCreate, JobResponse, JobListResponse
|
||||
from app.workers.scrape_tasks import run_scrape_job
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=JobListResponse)
|
||||
def list_jobs(
|
||||
status: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""List all jobs."""
|
||||
query = db.query(Job)
|
||||
|
||||
if status:
|
||||
query = query.filter(Job.status == status)
|
||||
|
||||
if source:
|
||||
query = query.filter(Job.source == source)
|
||||
|
||||
total = query.count()
|
||||
jobs = query.order_by(Job.created_at.desc()).limit(limit).all()
|
||||
|
||||
return JobListResponse(
|
||||
items=[JobResponse.model_validate(j) for j in jobs],
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=JobResponse)
|
||||
def create_job(job: JobCreate, db: Session = Depends(get_db)):
|
||||
"""Create and start a new scrape job."""
|
||||
species_filter = None
|
||||
if job.species_ids:
|
||||
species_filter = json.dumps(job.species_ids)
|
||||
|
||||
db_job = Job(
|
||||
name=job.name,
|
||||
source=job.source,
|
||||
species_filter=species_filter,
|
||||
only_without_images=job.only_without_images,
|
||||
max_images=job.max_images,
|
||||
status="pending",
|
||||
)
|
||||
db.add(db_job)
|
||||
db.commit()
|
||||
db.refresh(db_job)
|
||||
|
||||
# Start the Celery task
|
||||
task = run_scrape_job.delay(db_job.id)
|
||||
db_job.celery_task_id = task.id
|
||||
db.commit()
|
||||
|
||||
return JobResponse.model_validate(db_job)
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=JobResponse)
|
||||
def get_job(job_id: int, db: Session = Depends(get_db)):
|
||||
"""Get job status."""
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
return JobResponse.model_validate(job)
|
||||
|
||||
|
||||
@router.get("/{job_id}/progress")
|
||||
def get_job_progress(job_id: int, db: Session = Depends(get_db)):
|
||||
"""Get real-time job progress from Celery."""
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if not job.celery_task_id:
|
||||
return {
|
||||
"status": job.status,
|
||||
"progress_current": job.progress_current,
|
||||
"progress_total": job.progress_total,
|
||||
}
|
||||
|
||||
# Get Celery task state
|
||||
result = celery_app.AsyncResult(job.celery_task_id)
|
||||
|
||||
if result.state == "PROGRESS":
|
||||
meta = result.info
|
||||
return {
|
||||
"status": "running",
|
||||
"progress_current": meta.get("current", 0),
|
||||
"progress_total": meta.get("total", 0),
|
||||
"current_species": meta.get("species", ""),
|
||||
}
|
||||
|
||||
return {
|
||||
"status": job.status,
|
||||
"progress_current": job.progress_current,
|
||||
"progress_total": job.progress_total,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{job_id}/pause")
|
||||
def pause_job(job_id: int, db: Session = Depends(get_db)):
|
||||
"""Pause a running job."""
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.status != "running":
|
||||
raise HTTPException(status_code=400, detail="Job is not running")
|
||||
|
||||
# Revoke Celery task
|
||||
if job.celery_task_id:
|
||||
celery_app.control.revoke(job.celery_task_id, terminate=True)
|
||||
|
||||
job.status = "paused"
|
||||
db.commit()
|
||||
|
||||
return {"status": "paused"}
|
||||
|
||||
|
||||
@router.post("/{job_id}/resume")
|
||||
def resume_job(job_id: int, db: Session = Depends(get_db)):
|
||||
"""Resume a paused job."""
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.status != "paused":
|
||||
raise HTTPException(status_code=400, detail="Job is not paused")
|
||||
|
||||
# Start new Celery task
|
||||
task = run_scrape_job.delay(job.id)
|
||||
job.celery_task_id = task.id
|
||||
job.status = "pending"
|
||||
db.commit()
|
||||
|
||||
return {"status": "resumed"}
|
||||
|
||||
|
||||
@router.post("/{job_id}/cancel")
|
||||
def cancel_job(job_id: int, db: Session = Depends(get_db)):
|
||||
"""Cancel a job."""
|
||||
from app.workers.celery_app import celery_app
|
||||
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.status in ["completed", "failed"]:
|
||||
raise HTTPException(status_code=400, detail="Job already finished")
|
||||
|
||||
# Revoke Celery task
|
||||
if job.celery_task_id:
|
||||
celery_app.control.revoke(job.celery_task_id, terminate=True)
|
||||
|
||||
job.status = "failed"
|
||||
job.error_message = "Cancelled by user"
|
||||
db.commit()
|
||||
|
||||
return {"status": "cancelled"}
|
||||
198
backend/app/api/sources.py
Normal file
198
backend/app/api/sources.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import ApiKey
|
||||
from app.schemas.api_key import ApiKeyCreate, ApiKeyUpdate, ApiKeyResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Available sources
|
||||
# auth_type: "none" (no auth), "api_key" (single key), "api_key_secret" (key + secret), "oauth" (client_id + client_secret + access_token)
|
||||
# default_rate: safe default requests per second for each API
|
||||
AVAILABLE_SOURCES = [
|
||||
{"name": "gbif", "label": "GBIF", "requires_secret": False, "auth_type": "none", "default_rate": 1.0}, # Free, no auth required
|
||||
{"name": "inaturalist", "label": "iNaturalist", "requires_secret": True, "auth_type": "api_key_secret", "default_rate": 1.0}, # 60/min limit
|
||||
{"name": "flickr", "label": "Flickr", "requires_secret": True, "auth_type": "api_key_secret", "default_rate": 0.5}, # 3600/hr shared limit
|
||||
{"name": "wikimedia", "label": "Wikimedia Commons", "requires_secret": True, "auth_type": "oauth", "default_rate": 1.0}, # generous limits
|
||||
{"name": "trefle", "label": "Trefle.io", "requires_secret": False, "auth_type": "api_key", "default_rate": 1.0}, # 120/min limit
|
||||
{"name": "duckduckgo", "label": "DuckDuckGo", "requires_secret": False, "auth_type": "none", "default_rate": 0.5}, # Web search, no API key
|
||||
{"name": "bing", "label": "Bing Image Search", "requires_secret": False, "auth_type": "api_key", "default_rate": 3.0}, # Azure Cognitive Services
|
||||
]
|
||||
|
||||
|
||||
def mask_api_key(key: str) -> str:
|
||||
"""Mask API key, showing only last 4 characters."""
|
||||
if not key or len(key) <= 4:
|
||||
return "****"
|
||||
return "*" * (len(key) - 4) + key[-4:]
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_sources(db: Session = Depends(get_db)):
|
||||
"""List all available sources with their configuration status."""
|
||||
api_keys = {k.source: k for k in db.query(ApiKey).all()}
|
||||
|
||||
result = []
|
||||
for source in AVAILABLE_SOURCES:
|
||||
api_key = api_keys.get(source["name"])
|
||||
default_rate = source.get("default_rate", 1.0)
|
||||
result.append({
|
||||
"name": source["name"],
|
||||
"label": source["label"],
|
||||
"requires_secret": source["requires_secret"],
|
||||
"auth_type": source.get("auth_type", "api_key"),
|
||||
"configured": api_key is not None,
|
||||
"enabled": api_key.enabled if api_key else False,
|
||||
"api_key_masked": mask_api_key(api_key.api_key) if api_key else None,
|
||||
"has_secret": bool(api_key.api_secret) if api_key else False,
|
||||
"has_access_token": bool(getattr(api_key, 'access_token', None)) if api_key else False,
|
||||
"rate_limit_per_sec": api_key.rate_limit_per_sec if api_key else default_rate,
|
||||
"default_rate": default_rate,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{source}")
|
||||
def get_source(source: str, db: Session = Depends(get_db)):
|
||||
"""Get source configuration."""
|
||||
source_info = next((s for s in AVAILABLE_SOURCES if s["name"] == source), None)
|
||||
if not source_info:
|
||||
raise HTTPException(status_code=404, detail="Unknown source")
|
||||
|
||||
api_key = db.query(ApiKey).filter(ApiKey.source == source).first()
|
||||
default_rate = source_info.get("default_rate", 1.0)
|
||||
|
||||
return {
|
||||
"name": source_info["name"],
|
||||
"label": source_info["label"],
|
||||
"requires_secret": source_info["requires_secret"],
|
||||
"auth_type": source_info.get("auth_type", "api_key"),
|
||||
"configured": api_key is not None,
|
||||
"enabled": api_key.enabled if api_key else False,
|
||||
"api_key_masked": mask_api_key(api_key.api_key) if api_key else None,
|
||||
"has_secret": bool(api_key.api_secret) if api_key else False,
|
||||
"has_access_token": bool(getattr(api_key, 'access_token', None)) if api_key else False,
|
||||
"rate_limit_per_sec": api_key.rate_limit_per_sec if api_key else default_rate,
|
||||
"default_rate": default_rate,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{source}")
|
||||
def update_source(
|
||||
source: str,
|
||||
config: ApiKeyCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Create or update source configuration."""
|
||||
source_info = next((s for s in AVAILABLE_SOURCES if s["name"] == source), None)
|
||||
if not source_info:
|
||||
raise HTTPException(status_code=404, detail="Unknown source")
|
||||
|
||||
# For sources that require auth, validate api_key is provided
|
||||
auth_type = source_info.get("auth_type", "api_key")
|
||||
if auth_type != "none" and not config.api_key:
|
||||
raise HTTPException(status_code=400, detail="API key is required for this source")
|
||||
|
||||
api_key = db.query(ApiKey).filter(ApiKey.source == source).first()
|
||||
|
||||
# Use placeholder for no-auth sources
|
||||
api_key_value = config.api_key or "no-auth"
|
||||
|
||||
if api_key:
|
||||
# Update existing
|
||||
api_key.api_key = api_key_value
|
||||
if config.api_secret:
|
||||
api_key.api_secret = config.api_secret
|
||||
if config.access_token:
|
||||
api_key.access_token = config.access_token
|
||||
api_key.rate_limit_per_sec = config.rate_limit_per_sec
|
||||
api_key.enabled = config.enabled
|
||||
else:
|
||||
# Create new
|
||||
api_key = ApiKey(
|
||||
source=source,
|
||||
api_key=api_key_value,
|
||||
api_secret=config.api_secret,
|
||||
access_token=config.access_token,
|
||||
rate_limit_per_sec=config.rate_limit_per_sec,
|
||||
enabled=config.enabled,
|
||||
)
|
||||
db.add(api_key)
|
||||
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
return {
|
||||
"name": source,
|
||||
"configured": True,
|
||||
"enabled": api_key.enabled,
|
||||
"api_key_masked": mask_api_key(api_key.api_key) if auth_type != "none" else None,
|
||||
"has_secret": bool(api_key.api_secret),
|
||||
"has_access_token": bool(api_key.access_token),
|
||||
"rate_limit_per_sec": api_key.rate_limit_per_sec,
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/{source}")
|
||||
def patch_source(
|
||||
source: str,
|
||||
config: ApiKeyUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Partially update source configuration."""
|
||||
api_key = db.query(ApiKey).filter(ApiKey.source == source).first()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="Source not configured")
|
||||
|
||||
update_data = config.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(api_key, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
return {
|
||||
"name": source,
|
||||
"configured": True,
|
||||
"enabled": api_key.enabled,
|
||||
"api_key_masked": mask_api_key(api_key.api_key),
|
||||
"has_secret": bool(api_key.api_secret),
|
||||
"has_access_token": bool(api_key.access_token),
|
||||
"rate_limit_per_sec": api_key.rate_limit_per_sec,
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{source}")
|
||||
def delete_source(source: str, db: Session = Depends(get_db)):
|
||||
"""Delete source configuration."""
|
||||
api_key = db.query(ApiKey).filter(ApiKey.source == source).first()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="Source not configured")
|
||||
|
||||
db.delete(api_key)
|
||||
db.commit()
|
||||
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.post("/{source}/test")
|
||||
def test_source(source: str, db: Session = Depends(get_db)):
|
||||
"""Test source API connection."""
|
||||
api_key = db.query(ApiKey).filter(ApiKey.source == source).first()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="Source not configured")
|
||||
|
||||
# Import and test the scraper
|
||||
from app.scrapers import get_scraper
|
||||
|
||||
scraper = get_scraper(source)
|
||||
if not scraper:
|
||||
raise HTTPException(status_code=400, detail="No scraper for this source")
|
||||
|
||||
try:
|
||||
result = scraper.test_connection(api_key)
|
||||
return {"status": "success", "message": result}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
366
backend/app/api/species.py
Normal file
366
backend/app/api/species.py
Normal file
@@ -0,0 +1,366 @@
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, text
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Species, Image
|
||||
from app.schemas.species import (
|
||||
SpeciesCreate,
|
||||
SpeciesUpdate,
|
||||
SpeciesResponse,
|
||||
SpeciesListResponse,
|
||||
SpeciesImportResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def get_species_with_count(db: Session, species: Species) -> SpeciesResponse:
|
||||
"""Get species response with image count."""
|
||||
image_count = db.query(func.count(Image.id)).filter(
|
||||
Image.species_id == species.id,
|
||||
Image.status == "downloaded"
|
||||
).scalar()
|
||||
|
||||
return SpeciesResponse(
|
||||
id=species.id,
|
||||
scientific_name=species.scientific_name,
|
||||
common_name=species.common_name,
|
||||
genus=species.genus,
|
||||
family=species.family,
|
||||
created_at=species.created_at,
|
||||
image_count=image_count or 0,
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=SpeciesListResponse)
|
||||
def list_species(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=500),
|
||||
search: Optional[str] = None,
|
||||
genus: Optional[str] = None,
|
||||
has_images: Optional[bool] = None,
|
||||
max_images: Optional[int] = Query(None, description="Filter species with less than N images"),
|
||||
min_images: Optional[int] = Query(None, description="Filter species with at least N images"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""List species with pagination and filters.
|
||||
|
||||
Filters:
|
||||
- search: Search by scientific or common name
|
||||
- genus: Filter by genus
|
||||
- has_images: True for species with images, False for species without
|
||||
- max_images: Filter species with fewer than N downloaded images
|
||||
- min_images: Filter species with at least N downloaded images
|
||||
"""
|
||||
# If filtering by image count, we need to use a subquery approach
|
||||
if max_images is not None or min_images is not None:
|
||||
# Build a subquery with image counts per species
|
||||
image_counts = (
|
||||
db.query(
|
||||
Species.id.label("species_id"),
|
||||
func.count(Image.id).label("img_count")
|
||||
)
|
||||
.outerjoin(Image, (Image.species_id == Species.id) & (Image.status == "downloaded"))
|
||||
.group_by(Species.id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Join species with their counts
|
||||
query = db.query(Species).join(
|
||||
image_counts, Species.id == image_counts.c.species_id
|
||||
)
|
||||
|
||||
if max_images is not None:
|
||||
query = query.filter(image_counts.c.img_count < max_images)
|
||||
|
||||
if min_images is not None:
|
||||
query = query.filter(image_counts.c.img_count >= min_images)
|
||||
else:
|
||||
query = db.query(Species)
|
||||
|
||||
if search:
|
||||
search_term = f"%{search}%"
|
||||
query = query.filter(
|
||||
(Species.scientific_name.ilike(search_term)) |
|
||||
(Species.common_name.ilike(search_term))
|
||||
)
|
||||
|
||||
if genus:
|
||||
query = query.filter(Species.genus == genus)
|
||||
|
||||
# Filter by whether species has downloaded images (only if not using min/max filters)
|
||||
if has_images is not None and max_images is None and min_images is None:
|
||||
# Get IDs of species that have at least one downloaded image
|
||||
species_with_images = (
|
||||
db.query(Image.species_id)
|
||||
.filter(Image.status == "downloaded")
|
||||
.distinct()
|
||||
.subquery()
|
||||
)
|
||||
if has_images:
|
||||
query = query.filter(Species.id.in_(db.query(species_with_images.c.species_id)))
|
||||
else:
|
||||
query = query.filter(~Species.id.in_(db.query(species_with_images.c.species_id)))
|
||||
|
||||
total = query.count()
|
||||
pages = (total + page_size - 1) // page_size
|
||||
|
||||
species_list = query.order_by(Species.scientific_name).offset(
|
||||
(page - 1) * page_size
|
||||
).limit(page_size).all()
|
||||
|
||||
# Fetch image counts in bulk for all species on this page
|
||||
species_ids = [s.id for s in species_list]
|
||||
if species_ids:
|
||||
count_query = db.query(
|
||||
Image.species_id,
|
||||
func.count(Image.id)
|
||||
).filter(
|
||||
Image.species_id.in_(species_ids),
|
||||
Image.status == "downloaded"
|
||||
).group_by(Image.species_id).all()
|
||||
count_map = {species_id: count for species_id, count in count_query}
|
||||
else:
|
||||
count_map = {}
|
||||
|
||||
items = [
|
||||
SpeciesResponse(
|
||||
id=s.id,
|
||||
scientific_name=s.scientific_name,
|
||||
common_name=s.common_name,
|
||||
genus=s.genus,
|
||||
family=s.family,
|
||||
created_at=s.created_at,
|
||||
image_count=count_map.get(s.id, 0),
|
||||
)
|
||||
for s in species_list
|
||||
]
|
||||
|
||||
return SpeciesListResponse(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
pages=pages,
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=SpeciesResponse)
|
||||
def create_species(species: SpeciesCreate, db: Session = Depends(get_db)):
|
||||
"""Create a new species."""
|
||||
existing = db.query(Species).filter(
|
||||
Species.scientific_name == species.scientific_name
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="Species already exists")
|
||||
|
||||
# Auto-extract genus from scientific name if not provided
|
||||
genus = species.genus
|
||||
if not genus and " " in species.scientific_name:
|
||||
genus = species.scientific_name.split()[0]
|
||||
|
||||
db_species = Species(
|
||||
scientific_name=species.scientific_name,
|
||||
common_name=species.common_name,
|
||||
genus=genus,
|
||||
family=species.family,
|
||||
)
|
||||
db.add(db_species)
|
||||
db.commit()
|
||||
db.refresh(db_species)
|
||||
|
||||
return get_species_with_count(db, db_species)
|
||||
|
||||
|
||||
@router.post("/import", response_model=SpeciesImportResponse)
|
||||
async def import_species(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Import species from CSV file.
|
||||
|
||||
Expected columns: scientific_name, common_name (optional), genus (optional), family (optional)
|
||||
"""
|
||||
if not file.filename.endswith(".csv"):
|
||||
raise HTTPException(status_code=400, detail="File must be a CSV")
|
||||
|
||||
content = await file.read()
|
||||
text = content.decode("utf-8")
|
||||
|
||||
reader = csv.DictReader(io.StringIO(text))
|
||||
|
||||
imported = 0
|
||||
skipped = 0
|
||||
errors = []
|
||||
|
||||
for row_num, row in enumerate(reader, start=2):
|
||||
scientific_name = row.get("scientific_name", "").strip()
|
||||
if not scientific_name:
|
||||
errors.append(f"Row {row_num}: Missing scientific_name")
|
||||
continue
|
||||
|
||||
# Check if already exists
|
||||
existing = db.query(Species).filter(
|
||||
Species.scientific_name == scientific_name
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Auto-extract genus if not provided
|
||||
genus = row.get("genus", "").strip()
|
||||
if not genus and " " in scientific_name:
|
||||
genus = scientific_name.split()[0]
|
||||
|
||||
try:
|
||||
species = Species(
|
||||
scientific_name=scientific_name,
|
||||
common_name=row.get("common_name", "").strip() or None,
|
||||
genus=genus or None,
|
||||
family=row.get("family", "").strip() or None,
|
||||
)
|
||||
db.add(species)
|
||||
imported += 1
|
||||
except Exception as e:
|
||||
errors.append(f"Row {row_num}: {str(e)}")
|
||||
|
||||
db.commit()
|
||||
|
||||
return SpeciesImportResponse(
|
||||
imported=imported,
|
||||
skipped=skipped,
|
||||
errors=errors[:10], # Limit error messages
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import-json", response_model=SpeciesImportResponse)
|
||||
async def import_species_json(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Import species from JSON file.
|
||||
|
||||
Expected format: {"plants": [{"scientific_name": "...", "common_names": [...], "family": "..."}]}
|
||||
"""
|
||||
if not file.filename.endswith(".json"):
|
||||
raise HTTPException(status_code=400, detail="File must be a JSON")
|
||||
|
||||
content = await file.read()
|
||||
try:
|
||||
data = json.loads(content.decode("utf-8"))
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
|
||||
|
||||
plants = data.get("plants", [])
|
||||
if not plants:
|
||||
raise HTTPException(status_code=400, detail="No plants found in JSON")
|
||||
|
||||
imported = 0
|
||||
skipped = 0
|
||||
errors = []
|
||||
|
||||
for idx, plant in enumerate(plants):
|
||||
scientific_name = plant.get("scientific_name", "").strip()
|
||||
if not scientific_name:
|
||||
errors.append(f"Plant {idx}: Missing scientific_name")
|
||||
continue
|
||||
|
||||
# Check if already exists
|
||||
existing = db.query(Species).filter(
|
||||
Species.scientific_name == scientific_name
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Auto-extract genus from scientific name
|
||||
genus = None
|
||||
if " " in scientific_name:
|
||||
genus = scientific_name.split()[0]
|
||||
|
||||
# Get first common name if array provided
|
||||
common_names = plant.get("common_names", [])
|
||||
common_name = common_names[0] if common_names else None
|
||||
|
||||
try:
|
||||
species = Species(
|
||||
scientific_name=scientific_name,
|
||||
common_name=common_name,
|
||||
genus=genus,
|
||||
family=plant.get("family"),
|
||||
)
|
||||
db.add(species)
|
||||
imported += 1
|
||||
except Exception as e:
|
||||
errors.append(f"Plant {idx}: {str(e)}")
|
||||
|
||||
db.commit()
|
||||
|
||||
return SpeciesImportResponse(
|
||||
imported=imported,
|
||||
skipped=skipped,
|
||||
errors=errors[:10],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{species_id}", response_model=SpeciesResponse)
|
||||
def get_species(species_id: int, db: Session = Depends(get_db)):
|
||||
"""Get a species by ID."""
|
||||
species = db.query(Species).filter(Species.id == species_id).first()
|
||||
if not species:
|
||||
raise HTTPException(status_code=404, detail="Species not found")
|
||||
|
||||
return get_species_with_count(db, species)
|
||||
|
||||
|
||||
@router.put("/{species_id}", response_model=SpeciesResponse)
|
||||
def update_species(
|
||||
species_id: int,
|
||||
species_update: SpeciesUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update a species."""
|
||||
species = db.query(Species).filter(Species.id == species_id).first()
|
||||
if not species:
|
||||
raise HTTPException(status_code=404, detail="Species not found")
|
||||
|
||||
update_data = species_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(species, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(species)
|
||||
|
||||
return get_species_with_count(db, species)
|
||||
|
||||
|
||||
@router.delete("/{species_id}")
|
||||
def delete_species(species_id: int, db: Session = Depends(get_db)):
|
||||
"""Delete a species and all its images."""
|
||||
species = db.query(Species).filter(Species.id == species_id).first()
|
||||
if not species:
|
||||
raise HTTPException(status_code=404, detail="Species not found")
|
||||
|
||||
db.delete(species)
|
||||
db.commit()
|
||||
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.get("/genera/list")
|
||||
def list_genera(db: Session = Depends(get_db)):
|
||||
"""List all unique genera."""
|
||||
genera = db.query(Species.genus).filter(
|
||||
Species.genus.isnot(None)
|
||||
).distinct().order_by(Species.genus).all()
|
||||
|
||||
return [g[0] for g in genera]
|
||||
190
backend/app/api/stats.py
Normal file
190
backend/app/api/stats.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, case
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Species, Image, Job
|
||||
from app.models.cached_stats import CachedStats
|
||||
from app.schemas.stats import StatsResponse, SourceStats, LicenseStats, SpeciesStats, JobStats
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=StatsResponse)
|
||||
def get_stats(db: Session = Depends(get_db)):
|
||||
"""Get dashboard statistics from cache (updated every 60s by Celery)."""
|
||||
# Try to get cached stats
|
||||
cached = db.query(CachedStats).filter(CachedStats.key == "dashboard_stats").first()
|
||||
|
||||
if cached:
|
||||
data = json.loads(cached.value)
|
||||
return StatsResponse(
|
||||
total_species=data["total_species"],
|
||||
total_images=data["total_images"],
|
||||
images_downloaded=data["images_downloaded"],
|
||||
images_pending=data["images_pending"],
|
||||
images_rejected=data["images_rejected"],
|
||||
disk_usage_mb=data["disk_usage_mb"],
|
||||
sources=[SourceStats(**s) for s in data["sources"]],
|
||||
licenses=[LicenseStats(**l) for l in data["licenses"]],
|
||||
jobs=JobStats(**data["jobs"]),
|
||||
top_species=[SpeciesStats(**s) for s in data["top_species"]],
|
||||
under_represented=[SpeciesStats(**s) for s in data["under_represented"]],
|
||||
)
|
||||
|
||||
# No cache yet - return empty stats (Celery will populate soon)
|
||||
# This only happens on first startup before Celery runs
|
||||
return StatsResponse(
|
||||
total_species=0,
|
||||
total_images=0,
|
||||
images_downloaded=0,
|
||||
images_pending=0,
|
||||
images_rejected=0,
|
||||
disk_usage_mb=0.0,
|
||||
sources=[],
|
||||
licenses=[],
|
||||
jobs=JobStats(running=0, pending=0, completed=0, failed=0),
|
||||
top_species=[],
|
||||
under_represented=[],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
def refresh_stats_now(db: Session = Depends(get_db)):
|
||||
"""Manually trigger a stats refresh."""
|
||||
from app.workers.stats_tasks import refresh_stats
|
||||
refresh_stats.delay()
|
||||
return {"status": "refresh_queued"}
|
||||
|
||||
|
||||
@router.get("/sources")
|
||||
def get_source_stats(db: Session = Depends(get_db)):
|
||||
"""Get per-source breakdown."""
|
||||
stats = db.query(
|
||||
Image.source,
|
||||
func.count(Image.id).label("total"),
|
||||
func.sum(case((Image.status == "downloaded", 1), else_=0)).label("downloaded"),
|
||||
func.sum(case((Image.status == "pending", 1), else_=0)).label("pending"),
|
||||
func.sum(case((Image.status == "rejected", 1), else_=0)).label("rejected"),
|
||||
).group_by(Image.source).all()
|
||||
|
||||
return [
|
||||
{
|
||||
"source": s.source,
|
||||
"total": s.total,
|
||||
"downloaded": s.downloaded or 0,
|
||||
"pending": s.pending or 0,
|
||||
"rejected": s.rejected or 0,
|
||||
}
|
||||
for s in stats
|
||||
]
|
||||
|
||||
|
||||
@router.get("/species")
|
||||
def get_species_stats(
|
||||
min_count: int = 0,
|
||||
max_count: int = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get per-species image counts."""
|
||||
query = db.query(
|
||||
Species.id,
|
||||
Species.scientific_name,
|
||||
Species.common_name,
|
||||
Species.genus,
|
||||
func.count(Image.id).label("image_count")
|
||||
).outerjoin(Image, (Image.species_id == Species.id) & (Image.status == "downloaded")
|
||||
).group_by(Species.id)
|
||||
|
||||
if min_count > 0:
|
||||
query = query.having(func.count(Image.id) >= min_count)
|
||||
|
||||
if max_count is not None:
|
||||
query = query.having(func.count(Image.id) <= max_count)
|
||||
|
||||
stats = query.order_by(func.count(Image.id).desc()).all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": s.id,
|
||||
"scientific_name": s.scientific_name,
|
||||
"common_name": s.common_name,
|
||||
"genus": s.genus,
|
||||
"image_count": s.image_count,
|
||||
}
|
||||
for s in stats
|
||||
]
|
||||
|
||||
|
||||
@router.get("/distribution")
|
||||
def get_image_distribution(db: Session = Depends(get_db)):
|
||||
"""Get distribution of images per species for ML training assessment.
|
||||
|
||||
Returns counts of species at various image thresholds to help
|
||||
determine dataset quality for training image classifiers.
|
||||
"""
|
||||
from sqlalchemy import text
|
||||
|
||||
# Get image counts per species using optimized raw SQL
|
||||
distribution_sql = text("""
|
||||
WITH species_counts AS (
|
||||
SELECT
|
||||
s.id,
|
||||
COUNT(i.id) as cnt
|
||||
FROM species s
|
||||
LEFT JOIN images i ON i.species_id = s.id AND i.status = 'downloaded'
|
||||
GROUP BY s.id
|
||||
)
|
||||
SELECT
|
||||
COUNT(*) as total_species,
|
||||
SUM(CASE WHEN cnt = 0 THEN 1 ELSE 0 END) as with_0,
|
||||
SUM(CASE WHEN cnt >= 1 AND cnt < 10 THEN 1 ELSE 0 END) as with_1_9,
|
||||
SUM(CASE WHEN cnt >= 10 AND cnt < 25 THEN 1 ELSE 0 END) as with_10_24,
|
||||
SUM(CASE WHEN cnt >= 25 AND cnt < 50 THEN 1 ELSE 0 END) as with_25_49,
|
||||
SUM(CASE WHEN cnt >= 50 AND cnt < 100 THEN 1 ELSE 0 END) as with_50_99,
|
||||
SUM(CASE WHEN cnt >= 100 AND cnt < 200 THEN 1 ELSE 0 END) as with_100_199,
|
||||
SUM(CASE WHEN cnt >= 200 THEN 1 ELSE 0 END) as with_200_plus,
|
||||
SUM(CASE WHEN cnt >= 10 THEN 1 ELSE 0 END) as trainable_10,
|
||||
SUM(CASE WHEN cnt >= 25 THEN 1 ELSE 0 END) as trainable_25,
|
||||
SUM(CASE WHEN cnt >= 50 THEN 1 ELSE 0 END) as trainable_50,
|
||||
SUM(CASE WHEN cnt >= 100 THEN 1 ELSE 0 END) as trainable_100,
|
||||
AVG(cnt) as avg_images,
|
||||
MAX(cnt) as max_images,
|
||||
MIN(cnt) as min_images,
|
||||
SUM(cnt) as total_images
|
||||
FROM species_counts
|
||||
""")
|
||||
|
||||
result = db.execute(distribution_sql).fetchone()
|
||||
|
||||
return {
|
||||
"total_species": result[0] or 0,
|
||||
"distribution": {
|
||||
"0_images": result[1] or 0,
|
||||
"1_to_9": result[2] or 0,
|
||||
"10_to_24": result[3] or 0,
|
||||
"25_to_49": result[4] or 0,
|
||||
"50_to_99": result[5] or 0,
|
||||
"100_to_199": result[6] or 0,
|
||||
"200_plus": result[7] or 0,
|
||||
},
|
||||
"trainable_species": {
|
||||
"min_10_images": result[8] or 0,
|
||||
"min_25_images": result[9] or 0,
|
||||
"min_50_images": result[10] or 0,
|
||||
"min_100_images": result[11] or 0,
|
||||
},
|
||||
"summary": {
|
||||
"avg_images_per_species": round(result[12] or 0, 1),
|
||||
"max_images": result[13] or 0,
|
||||
"min_images": result[14] or 0,
|
||||
"total_downloaded_images": result[15] or 0,
|
||||
},
|
||||
"recommendations": {
|
||||
"for_basic_model": f"{result[8] or 0} species with 10+ images",
|
||||
"for_good_model": f"{result[10] or 0} species with 50+ images",
|
||||
"for_excellent_model": f"{result[11] or 0} species with 100+ images",
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user