From 40cda4aa84d801f908d66730a74b379aee57aec1 Mon Sep 17 00:00:00 2001 From: RonanMorgan <49660557+RonanMorgan@users.noreply.github.com> Date: Fri, 1 Nov 2024 22:57:49 +0100 Subject: [PATCH] fix(detections): use correct bucket for URL fetching (#366) * fix error when using admin creds * add pagination parameters * 15 event by default * refactor(detections): merged SQL query into a single one * style(detections): silences mypy warnings * fix(detections): fix route logic * fix(detections): fix syntax typo --------- Co-authored-by: Ronan Co-authored-by: F-G Fernandez <26927750+frgfm@users.noreply.github.com> --- src/app/api/api_v1/endpoints/detections.py | 54 ++++++++++++++-------- src/app/api/dependencies.py | 2 +- src/app/crud/base.py | 8 ++++ 3 files changed, 44 insertions(+), 20 deletions(-) diff --git a/src/app/api/api_v1/endpoints/detections.py b/src/app/api/api_v1/endpoints/detections.py index 31c95e93..635f92a6 100644 --- a/src/app/api/api_v1/endpoints/detections.py +++ b/src/app/api/api_v1/endpoints/detections.py @@ -4,7 +4,7 @@ # See LICENSE or go to for full license details. from datetime import datetime -from typing import List, cast +from typing import List, Optional, cast from fastapi import ( APIRouter, @@ -19,6 +19,8 @@ UploadFile, status, ) +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession from app.api.dependencies import ( dispatch_webhook, @@ -30,6 +32,7 @@ ) from app.core.config import settings from app.crud import CameraCRUD, DetectionCRUD, OrganizationCRUD, WebhookCRUD +from app.db import get_session from app.models import Camera, Detection, Organization, Role, UserRole from app.schemas.detections import ( BOXES_PATTERN, @@ -154,32 +157,45 @@ async def fetch_detections( @router.get("/unlabeled/fromdate", status_code=status.HTTP_200_OK, summary="Fetch all the unlabeled detections") async def fetch_unlabeled_detections( from_date: datetime = Query(), - detections: DetectionCRUD = Depends(get_detection_crud), - cameras: CameraCRUD = Depends(get_camera_crud), + limit: Optional[int] = Query(15, description="Maximum number of detections to fetch"), + offset: Optional[int] = Query(0, description="Number of detections to skip before starting to fetch"), + session: AsyncSession = Depends(get_session), token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]), ) -> List[DetectionWithUrl]: - telemetry_client.capture(token_payload.sub, event="unacknowledged-fetch") - - bucket = s3_service.get_bucket(s3_service.resolve_bucket_name(token_payload.organization_id)) - - def get_url(detection: Detection) -> str: - return bucket.get_public_url(detection.bucket_key) + telemetry_client.capture(token_payload.sub, event="detections-fetch-unlabeled") if UserRole.ADMIN in token_payload.scopes: - all_unck_detections = await detections.fetch_all( - filter_pair=("is_wildfire", None), inequality_pair=("created_at", ">=", from_date) + # Custom SQL query to fetch detections along with corresponding organization_id + query = await session.exec( + select(Detection, Camera.organization_id) # type: ignore[attr-defined] + .join(Camera, Detection.camera_id == Camera.id) # type: ignore[arg-type] + .where(Detection.is_wildfire.is_(None)) # type: ignore[union-attr] + .where(Detection.created_at >= from_date) + .limit(limit) + .offset(offset) ) + results = query.all() + unlabeled_detections = [Detection(**detection.__dict__) for detection, _ in results] + urls = [ + s3_service.get_bucket(s3_service.resolve_bucket_name(org_id)).get_public_url(det.bucket_key) + for det, org_id in results + ] else: - org_cams = await cameras.fetch_all(filter_pair=("organization_id", token_payload.organization_id)) - all_unck_detections = await detections.fetch_all( - filter_pair=("is_wildfire", None), - in_pair=("camera_id", [camera.id for camera in org_cams]), - inequality_pair=("created_at", ">=", from_date), + query = await session.exec( + select(Detection) # type: ignore[attr-defined] + .join(Camera, Detection.camera_id == Camera.id) # type: ignore[arg-type] + .where(Detection.is_wildfire.is_(None)) # type: ignore[union-attr] + .where(Detection.created_at >= from_date) + .where(Camera.organization_id == token_payload.organization_id) + .limit(limit) + .offset(offset) ) + results = query.all() + unlabeled_detections = [Detection(**detection.__dict__) for detection in results] + bucket = s3_service.get_bucket(s3_service.resolve_bucket_name(token_payload.organization_id)) + urls = [bucket.get_public_url(detection.bucket_key) for detection in unlabeled_detections] - urls = (get_url(detection) for detection in all_unck_detections) - - return [DetectionWithUrl(**detection.model_dump(), url=url) for detection, url in zip(all_unck_detections, urls)] + return [DetectionWithUrl(**detection.model_dump(), url=url) for detection, url in zip(unlabeled_detections, urls)] @router.patch("/{detection_id}/label", status_code=status.HTTP_200_OK, summary="Label the nature of the detection") diff --git a/src/app/api/dependencies.py b/src/app/api/dependencies.py index 8739976a..0f393e38 100644 --- a/src/app/api/dependencies.py +++ b/src/app/api/dependencies.py @@ -115,7 +115,7 @@ async def get_current_user( async def dispatch_webhook(url: str, payload: BaseModel) -> None: - async with AsyncClient() as client: + async with AsyncClient(timeout=5) as client: try: response = await client.post(url, json=payload.model_dump_json()) response.raise_for_status() diff --git a/src/app/crud/base.py b/src/app/crud/base.py index 6e61dd20..1e2399fc 100644 --- a/src/app/crud/base.py +++ b/src/app/crud/base.py @@ -63,6 +63,8 @@ async def fetch_all( filter_pair: Union[Tuple[str, Any], None] = None, in_pair: Union[Tuple[str, List], None] = None, inequality_pair: Optional[Tuple[str, str, Any]] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, ) -> List[ModelType]: statement = select(self.model) # type: ignore[var-annotated] if isinstance(filter_pair, tuple): @@ -84,6 +86,12 @@ async def fetch_all( else: raise ValueError(f"Unsupported inequality operator: {op}") + if offset is not None: + statement = statement.offset(offset) + + if limit is not None: + statement = statement.limit(limit) + result = await self.session.exec(statement=statement) return [r for r in result]