# Source-listing metadata from the repository UI (not part of the module):
# 2026-04-12 09:54:27 -05:00 · 367 lines · 11 KiB · Python
import csv
import io
import json
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File
from sqlalchemy.orm import Session
from sqlalchemy import func, text
from app.database import get_db
from app.models import Species, Image
from app.schemas.species import (
SpeciesCreate,
SpeciesUpdate,
SpeciesResponse,
SpeciesListResponse,
SpeciesImportResponse,
)
router = APIRouter()
def get_species_with_count(db: Session, species: Species) -> SpeciesResponse:
"""Get species response with image count."""
image_count = db.query(func.count(Image.id)).filter(
Image.species_id == species.id,
Image.status == "downloaded"
).scalar()
return SpeciesResponse(
id=species.id,
scientific_name=species.scientific_name,
common_name=species.common_name,
genus=species.genus,
family=species.family,
created_at=species.created_at,
image_count=image_count or 0,
)
@router.get("", response_model=SpeciesListResponse)
def list_species(
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=500),
search: Optional[str] = None,
genus: Optional[str] = None,
has_images: Optional[bool] = None,
max_images: Optional[int] = Query(None, description="Filter species with less than N images"),
min_images: Optional[int] = Query(None, description="Filter species with at least N images"),
db: Session = Depends(get_db),
):
"""List species with pagination and filters.
Filters:
- search: Search by scientific or common name
- genus: Filter by genus
- has_images: True for species with images, False for species without
- max_images: Filter species with fewer than N downloaded images
- min_images: Filter species with at least N downloaded images
"""
# If filtering by image count, we need to use a subquery approach
if max_images is not None or min_images is not None:
# Build a subquery with image counts per species
image_counts = (
db.query(
Species.id.label("species_id"),
func.count(Image.id).label("img_count")
)
.outerjoin(Image, (Image.species_id == Species.id) & (Image.status == "downloaded"))
.group_by(Species.id)
.subquery()
)
# Join species with their counts
query = db.query(Species).join(
image_counts, Species.id == image_counts.c.species_id
)
if max_images is not None:
query = query.filter(image_counts.c.img_count < max_images)
if min_images is not None:
query = query.filter(image_counts.c.img_count >= min_images)
else:
query = db.query(Species)
if search:
search_term = f"%{search}%"
query = query.filter(
(Species.scientific_name.ilike(search_term)) |
(Species.common_name.ilike(search_term))
)
if genus:
query = query.filter(Species.genus == genus)
# Filter by whether species has downloaded images (only if not using min/max filters)
if has_images is not None and max_images is None and min_images is None:
# Get IDs of species that have at least one downloaded image
species_with_images = (
db.query(Image.species_id)
.filter(Image.status == "downloaded")
.distinct()
.subquery()
)
if has_images:
query = query.filter(Species.id.in_(db.query(species_with_images.c.species_id)))
else:
query = query.filter(~Species.id.in_(db.query(species_with_images.c.species_id)))
total = query.count()
pages = (total + page_size - 1) // page_size
species_list = query.order_by(Species.scientific_name).offset(
(page - 1) * page_size
).limit(page_size).all()
# Fetch image counts in bulk for all species on this page
species_ids = [s.id for s in species_list]
if species_ids:
count_query = db.query(
Image.species_id,
func.count(Image.id)
).filter(
Image.species_id.in_(species_ids),
Image.status == "downloaded"
).group_by(Image.species_id).all()
count_map = {species_id: count for species_id, count in count_query}
else:
count_map = {}
items = [
SpeciesResponse(
id=s.id,
scientific_name=s.scientific_name,
common_name=s.common_name,
genus=s.genus,
family=s.family,
created_at=s.created_at,
image_count=count_map.get(s.id, 0),
)
for s in species_list
]
return SpeciesListResponse(
items=items,
total=total,
page=page,
page_size=page_size,
pages=pages,
)
@router.post("", response_model=SpeciesResponse)
def create_species(species: SpeciesCreate, db: Session = Depends(get_db)):
"""Create a new species."""
existing = db.query(Species).filter(
Species.scientific_name == species.scientific_name
).first()
if existing:
raise HTTPException(status_code=400, detail="Species already exists")
# Auto-extract genus from scientific name if not provided
genus = species.genus
if not genus and " " in species.scientific_name:
genus = species.scientific_name.split()[0]
db_species = Species(
scientific_name=species.scientific_name,
common_name=species.common_name,
genus=genus,
family=species.family,
)
db.add(db_species)
db.commit()
db.refresh(db_species)
return get_species_with_count(db, db_species)
@router.post("/import", response_model=SpeciesImportResponse)
async def import_species(
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""Import species from CSV file.
Expected columns: scientific_name, common_name (optional), genus (optional), family (optional)
"""
if not file.filename.endswith(".csv"):
raise HTTPException(status_code=400, detail="File must be a CSV")
content = await file.read()
text = content.decode("utf-8")
reader = csv.DictReader(io.StringIO(text))
imported = 0
skipped = 0
errors = []
for row_num, row in enumerate(reader, start=2):
scientific_name = row.get("scientific_name", "").strip()
if not scientific_name:
errors.append(f"Row {row_num}: Missing scientific_name")
continue
# Check if already exists
existing = db.query(Species).filter(
Species.scientific_name == scientific_name
).first()
if existing:
skipped += 1
continue
# Auto-extract genus if not provided
genus = row.get("genus", "").strip()
if not genus and " " in scientific_name:
genus = scientific_name.split()[0]
try:
species = Species(
scientific_name=scientific_name,
common_name=row.get("common_name", "").strip() or None,
genus=genus or None,
family=row.get("family", "").strip() or None,
)
db.add(species)
imported += 1
except Exception as e:
errors.append(f"Row {row_num}: {str(e)}")
db.commit()
return SpeciesImportResponse(
imported=imported,
skipped=skipped,
errors=errors[:10], # Limit error messages
)
@router.post("/import-json", response_model=SpeciesImportResponse)
async def import_species_json(
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""Import species from JSON file.
Expected format: {"plants": [{"scientific_name": "...", "common_names": [...], "family": "..."}]}
"""
if not file.filename.endswith(".json"):
raise HTTPException(status_code=400, detail="File must be a JSON")
content = await file.read()
try:
data = json.loads(content.decode("utf-8"))
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
plants = data.get("plants", [])
if not plants:
raise HTTPException(status_code=400, detail="No plants found in JSON")
imported = 0
skipped = 0
errors = []
for idx, plant in enumerate(plants):
scientific_name = plant.get("scientific_name", "").strip()
if not scientific_name:
errors.append(f"Plant {idx}: Missing scientific_name")
continue
# Check if already exists
existing = db.query(Species).filter(
Species.scientific_name == scientific_name
).first()
if existing:
skipped += 1
continue
# Auto-extract genus from scientific name
genus = None
if " " in scientific_name:
genus = scientific_name.split()[0]
# Get first common name if array provided
common_names = plant.get("common_names", [])
common_name = common_names[0] if common_names else None
try:
species = Species(
scientific_name=scientific_name,
common_name=common_name,
genus=genus,
family=plant.get("family"),
)
db.add(species)
imported += 1
except Exception as e:
errors.append(f"Plant {idx}: {str(e)}")
db.commit()
return SpeciesImportResponse(
imported=imported,
skipped=skipped,
errors=errors[:10],
)
@router.get("/{species_id}", response_model=SpeciesResponse)
def get_species(species_id: int, db: Session = Depends(get_db)):
"""Get a species by ID."""
species = db.query(Species).filter(Species.id == species_id).first()
if not species:
raise HTTPException(status_code=404, detail="Species not found")
return get_species_with_count(db, species)
@router.put("/{species_id}", response_model=SpeciesResponse)
def update_species(
species_id: int,
species_update: SpeciesUpdate,
db: Session = Depends(get_db),
):
"""Update a species."""
species = db.query(Species).filter(Species.id == species_id).first()
if not species:
raise HTTPException(status_code=404, detail="Species not found")
update_data = species_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(species, field, value)
db.commit()
db.refresh(species)
return get_species_with_count(db, species)
@router.delete("/{species_id}")
def delete_species(species_id: int, db: Session = Depends(get_db)):
"""Delete a species and all its images."""
species = db.query(Species).filter(Species.id == species_id).first()
if not species:
raise HTTPException(status_code=404, detail="Species not found")
db.delete(species)
db.commit()
return {"status": "deleted"}
@router.get("/genera/list")
def list_genera(db: Session = Depends(get_db)):
"""List all unique genera."""
genera = db.query(Species.genus).filter(
Species.genus.isnot(None)
).distinct().order_by(Species.genus).all()
return [g[0] for g in genera]