import csv
import io
import json
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File
from sqlalchemy.orm import Session
from sqlalchemy import func, text

from app.database import get_db
from app.models import Species, Image
from app.schemas.species import (
    SpeciesCreate,
    SpeciesUpdate,
    SpeciesResponse,
    SpeciesListResponse,
    SpeciesImportResponse,
)

router = APIRouter()


def get_species_with_count(db: Session, species: Species) -> SpeciesResponse:
    """Build a SpeciesResponse for *species*, including its downloaded-image count.

    Only images with status == "downloaded" are counted.
    """
    image_count = db.query(func.count(Image.id)).filter(
        Image.species_id == species.id,
        Image.status == "downloaded",
    ).scalar()
    return SpeciesResponse(
        id=species.id,
        scientific_name=species.scientific_name,
        common_name=species.common_name,
        genus=species.genus,
        family=species.family,
        created_at=species.created_at,
        image_count=image_count or 0,
    )


@router.get("", response_model=SpeciesListResponse)
def list_species(
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=500),
    search: Optional[str] = None,
    genus: Optional[str] = None,
    has_images: Optional[bool] = None,
    max_images: Optional[int] = Query(None, description="Filter species with less than N images"),
    min_images: Optional[int] = Query(None, description="Filter species with at least N images"),
    db: Session = Depends(get_db),
):
    """List species with pagination and filters.

    Filters:
    - search: Search by scientific or common name
    - genus: Filter by genus
    - has_images: True for species with images, False for species without
    - max_images: Filter species with fewer than N downloaded images
    - min_images: Filter species with at least N downloaded images
    """
    # If filtering by image count, we need to use a subquery approach
    if max_images is not None or min_images is not None:
        # Build a subquery with image counts per species.  The outer join keeps
        # species with zero downloaded images (img_count == 0) in the result so
        # max_images can match them too.
        image_counts = (
            db.query(
                Species.id.label("species_id"),
                func.count(Image.id).label("img_count"),
            )
            .outerjoin(Image, (Image.species_id == Species.id) & (Image.status == "downloaded"))
            .group_by(Species.id)
            .subquery()
        )
        # Join species with their counts
        query = db.query(Species).join(
            image_counts, Species.id == image_counts.c.species_id
        )
        if max_images is not None:
            query = query.filter(image_counts.c.img_count < max_images)
        if min_images is not None:
            query = query.filter(image_counts.c.img_count >= min_images)
    else:
        query = db.query(Species)

    if search:
        search_term = f"%{search}%"
        query = query.filter(
            (Species.scientific_name.ilike(search_term))
            | (Species.common_name.ilike(search_term))
        )

    if genus:
        query = query.filter(Species.genus == genus)

    # Filter by whether species has downloaded images (only if not using min/max filters)
    if has_images is not None and max_images is None and min_images is None:
        # Get IDs of species that have at least one downloaded image
        species_with_images = (
            db.query(Image.species_id)
            .filter(Image.status == "downloaded")
            .distinct()
            .subquery()
        )
        if has_images:
            query = query.filter(Species.id.in_(db.query(species_with_images.c.species_id)))
        else:
            query = query.filter(~Species.id.in_(db.query(species_with_images.c.species_id)))

    total = query.count()
    pages = (total + page_size - 1) // page_size  # ceiling division

    species_list = query.order_by(Species.scientific_name).offset(
        (page - 1) * page_size
    ).limit(page_size).all()

    # Fetch image counts in bulk for all species on this page (avoids one
    # COUNT query per row).
    species_ids = [s.id for s in species_list]
    if species_ids:
        count_query = db.query(
            Image.species_id, func.count(Image.id)
        ).filter(
            Image.species_id.in_(species_ids),
            Image.status == "downloaded",
        ).group_by(Image.species_id).all()
        count_map = {species_id: count for species_id, count in count_query}
    else:
        count_map = {}

    items = [
        SpeciesResponse(
            id=s.id,
            scientific_name=s.scientific_name,
            common_name=s.common_name,
            genus=s.genus,
            family=s.family,
            created_at=s.created_at,
            image_count=count_map.get(s.id, 0),
        )
        for s in species_list
    ]

    return SpeciesListResponse(
        items=items,
        total=total,
        page=page,
        page_size=page_size,
        pages=pages,
    )


@router.post("", response_model=SpeciesResponse)
def create_species(species: SpeciesCreate, db: Session = Depends(get_db)):
    """Create a new species.

    Raises 400 if a species with the same scientific name already exists.
    """
    existing = db.query(Species).filter(
        Species.scientific_name == species.scientific_name
    ).first()
    if existing:
        raise HTTPException(status_code=400, detail="Species already exists")

    # Auto-extract genus from scientific name if not provided
    genus = species.genus
    if not genus and " " in species.scientific_name:
        genus = species.scientific_name.split()[0]

    db_species = Species(
        scientific_name=species.scientific_name,
        common_name=species.common_name,
        genus=genus,
        family=species.family,
    )
    db.add(db_species)
    db.commit()
    db.refresh(db_species)
    return get_species_with_count(db, db_species)


@router.post("/import", response_model=SpeciesImportResponse)
async def import_species(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    """Import species from CSV file.

    Expected columns: scientific_name, common_name (optional),
    genus (optional), family (optional)
    """
    # file.filename may be None when no filename was supplied with the upload.
    if not file.filename or not file.filename.endswith(".csv"):
        raise HTTPException(status_code=400, detail="File must be a CSV")

    content = await file.read()
    try:
        # NOTE: do not name this `text` — that would shadow sqlalchemy.text
        # imported at module level.
        decoded = content.decode("utf-8")
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="File must be UTF-8 encoded")
    reader = csv.DictReader(io.StringIO(decoded))

    imported = 0
    skipped = 0
    errors = []

    # start=2: row 1 is the header line, so data rows are numbered from 2.
    for row_num, row in enumerate(reader, start=2):
        # csv.DictReader fills missing trailing fields with None (restval),
        # so guard every cell with `or ""` before calling .strip().
        scientific_name = (row.get("scientific_name") or "").strip()
        if not scientific_name:
            errors.append(f"Row {row_num}: Missing scientific_name")
            continue

        # Check if already exists
        existing = db.query(Species).filter(
            Species.scientific_name == scientific_name
        ).first()
        if existing:
            skipped += 1
            continue

        # Auto-extract genus if not provided
        genus = (row.get("genus") or "").strip()
        if not genus and " " in scientific_name:
            genus = scientific_name.split()[0]

        try:
            species = Species(
                scientific_name=scientific_name,
                common_name=(row.get("common_name") or "").strip() or None,
                genus=genus or None,
                family=(row.get("family") or "").strip() or None,
            )
            db.add(species)
            imported += 1
        except Exception as e:
            errors.append(f"Row {row_num}: {str(e)}")

    db.commit()

    return SpeciesImportResponse(
        imported=imported,
        skipped=skipped,
        errors=errors[:10],  # Limit error messages
    )


@router.post("/import-json", response_model=SpeciesImportResponse)
async def import_species_json(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    """Import species from JSON file.

    Expected format:
    {"plants": [{"scientific_name": "...", "common_names": [...], "family": "..."}]}
    """
    # file.filename may be None when no filename was supplied with the upload.
    if not file.filename or not file.filename.endswith(".json"):
        raise HTTPException(status_code=400, detail="File must be a JSON")

    content = await file.read()
    try:
        # UnicodeDecodeError is handled alongside JSON errors so a non-UTF-8
        # upload yields a 400 instead of an unhandled 500.
        data = json.loads(content.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")

    plants = data.get("plants", [])
    if not plants:
        raise HTTPException(status_code=400, detail="No plants found in JSON")

    imported = 0
    skipped = 0
    errors = []

    for idx, plant in enumerate(plants):
        # `or ""` guards against an explicit JSON null for scientific_name.
        scientific_name = (plant.get("scientific_name") or "").strip()
        if not scientific_name:
            errors.append(f"Plant {idx}: Missing scientific_name")
            continue

        # Check if already exists
        existing = db.query(Species).filter(
            Species.scientific_name == scientific_name
        ).first()
        if existing:
            skipped += 1
            continue

        # Auto-extract genus from scientific name
        genus = None
        if " " in scientific_name:
            genus = scientific_name.split()[0]

        # Get first common name if array provided
        common_names = plant.get("common_names", [])
        common_name = common_names[0] if common_names else None

        try:
            species = Species(
                scientific_name=scientific_name,
                common_name=common_name,
                genus=genus,
                family=plant.get("family"),
            )
            db.add(species)
            imported += 1
        except Exception as e:
            errors.append(f"Plant {idx}: {str(e)}")

    db.commit()

    return SpeciesImportResponse(
        imported=imported,
        skipped=skipped,
        errors=errors[:10],
    )


@router.get("/{species_id}", response_model=SpeciesResponse)
def get_species(species_id: int, db: Session = Depends(get_db)):
    """Get a species by ID.  Raises 404 if not found."""
    species = db.query(Species).filter(Species.id == species_id).first()
    if not species:
        raise HTTPException(status_code=404, detail="Species not found")
    return get_species_with_count(db, species)


@router.put("/{species_id}", response_model=SpeciesResponse)
def update_species(
    species_id: int,
    species_update: SpeciesUpdate,
    db: Session = Depends(get_db),
):
    """Update a species.  Only fields explicitly set in the payload are changed."""
    species = db.query(Species).filter(Species.id == species_id).first()
    if not species:
        raise HTTPException(status_code=404, detail="Species not found")

    update_data = species_update.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(species, field, value)

    db.commit()
    db.refresh(species)
    return get_species_with_count(db, species)


@router.delete("/{species_id}")
def delete_species(species_id: int, db: Session = Depends(get_db)):
    """Delete a species and all its images."""
    species = db.query(Species).filter(Species.id == species_id).first()
    if not species:
        raise HTTPException(status_code=404, detail="Species not found")
    db.delete(species)
    db.commit()
    return {"status": "deleted"}


@router.get("/genera/list")
def list_genera(db: Session = Depends(get_db)):
    """List all unique non-null genera, sorted alphabetically."""
    genera = db.query(Species.genus).filter(
        Species.genus.isnot(None)
    ).distinct().order_by(Species.genus).all()
    return [g[0] for g in genera]