Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(detections): add bboxes column to detection and a route to fetch unacknowledged ones #340

Merged
merged 65 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
d2eb60e
feat: first commit for Site table creation
Jun 17, 2024
d69085c
fix: error when launching backend
Jun 17, 2024
f4ec529
fix: sites tests
Jun 18, 2024
0f4026e
fix mypy
Jun 18, 2024
ceb525e
fix tests
Jun 18, 2024
b698b2c
fix: add site_id in user table
Jun 18, 2024
a4caaaa
feat: implement new security behavior using site_id
Jun 18, 2024
79a9867
feat : add CRUD endpoints for the site path + tests
Jun 19, 2024
40e6889
refactor: Site -> Organization
Jun 19, 2024
1495e21
feat: refactor security behavior
Jun 19, 2024
87d3ef9
feat: fix tests
Jun 21, 2024
ecf5ad2
fix e2e tests
Jun 21, 2024
0e647d8
fix client tests
Jun 21, 2024
65dec07
fix "role" column
Jun 21, 2024
ea9089f
fix: remove useless security in case of create_detection
Jun 24, 2024
76d46b7
fix: resolve first comments
Jul 3, 2024
926b0c3
fix error in detection endpoint
Jul 3, 2024
147804f
take feedback into account
Jul 11, 2024
185b2de
feat: add crud function to avoid for loop
Jul 11, 2024
1877874
feedback PR
Jul 11, 2024
4dda307
fix lint
Jul 12, 2024
12e4258
Merge branch 'main' into rs/add-site-table
Jul 13, 2024
10f9f7e
fix linting
Jul 13, 2024
0f15ca2
feat: add acknowledged endpoint
Jul 4, 2024
6c6fd43
feat: no need to have warning level of error
Jul 4, 2024
ecdf13c
feat: use is_wildfire instead of a new boolean
Jul 8, 2024
52b6abb
refactor fetch_all crud
Jul 11, 2024
522fc2b
fix: rm Exception
Jul 12, 2024
77473ec
fix mypy
Jul 12, 2024
357ae53
feat: start implemnting new payload from_date
Jul 13, 2024
9cb48cc
fir error date -> datetime
Jul 13, 2024
d1f3ac1
fix unlabeled endpoints
Jul 14, 2024
6c641b2
Merge branch 'main' into rs/add-acknowledged
Jul 15, 2024
be6574e
Add localization in Detection table (#342)
RonanMorgan Jul 19, 2024
4fc04d4
Send url with detection (#346)
RonanMorgan Jul 19, 2024
02c884b
clean up comments
Jul 19, 2024
2d3cde3
feat: refactor fecth_all function
Jul 19, 2024
c2f7b91
clean up comments
Jul 19, 2024
5da3d23
fix typing
Jul 19, 2024
43235d0
create DetectionWithUrl object
Jul 22, 2024
c9ab7c6
feat: use Form for regexp check
Jul 22, 2024
2fcf274
localization -> bboxes
Jul 22, 2024
c612268
fix: error with test client
Jul 22, 2024
f809a39
fix: error in client
Jul 22, 2024
acfc9c5
feat: don't forget to build client
Jul 22, 2024
1cff08a
revert(client): simplified imports
frgfm Aug 23, 2024
0063562
fix(client): update client management of detections
frgfm Aug 23, 2024
ddd18ab
Merge branch 'main' into rs/add-acknowledged
frgfm Aug 23, 2024
6cc2f52
fix(test): fix a dict unpacking
frgfm Aug 23, 2024
cc7d7d1
test(detections): updated test cases
frgfm Aug 23, 2024
78319c4
docs(client): update the docstring example
frgfm Aug 23, 2024
c6b12a8
test(e2e): add test case for unlabeled detection fetch
frgfm Aug 23, 2024
fc9b793
test(e2e): made e2e test more robust
frgfm Aug 23, 2024
b8cb19a
feat(detections): made bboxes non optional
frgfm Aug 23, 2024
bd9dfac
refactor(detection): improve bboxes forwarding
frgfm Aug 23, 2024
24f7030
test(e2e): fix bboxes forwarding
frgfm Aug 23, 2024
5b204bd
test(client): update bboxes tests
frgfm Aug 23, 2024
3413381
test(detections): improve testing cases
frgfm Aug 23, 2024
ceb9f93
fix(schemas): update the bboxes validator
frgfm Aug 23, 2024
54907dd
style(tests): remove unnecessary print
frgfm Aug 23, 2024
f82063f
refactor(detections): clean detection management
frgfm Aug 23, 2024
3910832
test(client): update test case for client
frgfm Aug 23, 2024
1b27e52
test(client): fix test cases
frgfm Aug 23, 2024
716e8ba
fix(client): refined string formatting
frgfm Aug 23, 2024
1a9491f
fix(detections): fix the string length
frgfm Aug 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ stop:
test:
poetry export -f requirements.txt --without-hashes --with test --output requirements.txt
docker compose -f docker-compose.dev.yml up -d --build --wait
- docker compose exec -T backend pytest --cov=app
- docker compose -f docker-compose.dev.yml exec -T backend pytest --cov=app
docker compose -f docker-compose.dev.yml down

build-client:
pip install -e client/.

# Run tests for the Python client
# the "-" are used to launch the next command even if a command fail
test-client:
test-client: build-client
poetry export -f requirements.txt --without-hashes --output requirements.txt
docker compose -f docker-compose.dev.yml up -d --build --wait
- cd client && pytest --cov=pyroclient tests/ && cd ..
Expand Down
59 changes: 56 additions & 3 deletions client/pyroclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Dict

from typing import Dict, List, Tuple
from urllib.parse import urljoin

import requests
Expand All @@ -29,6 +30,7 @@
"detections-create": "/detections",
"detections-label": "/detections/{det_id}/label",
"detections-fetch": "/detections",
"detections-fetch-unl": "/detections/unlabeled/fromdate",
"detections-url": "/detections/{det_id}/url",
#################
# ORGS
Expand All @@ -37,6 +39,32 @@
}


def _to_str(coord: float) -> str:
"""Format string conditionally"""
return f"{coord:.0f}" if coord == round(coord) else f"{coord:.3f}"


def _dump_bbox_to_json(
bboxes: List[Tuple[float, float, float, float, float]],
) -> str:
"""Performs a custom JSON dump for list of coordinates

Args:
bboxes: list of tuples where each tuple is a relative coordinate in order xmin, ymin, xmax, ymax, conf
Returns:
the JSON string dump with 3 decimal precision
"""
if any(coord > 1 or coord < 0 for bbox in bboxes for coord in bbox):
raise ValueError("coordinates are expected to be relative")

Check warning on line 58 in client/pyroclient/client.py

View check run for this annotation

Codecov / codecov/patch

client/pyroclient/client.py#L58

Added line #L58 was not covered by tests
if any(len(bbox) != 5 for bbox in bboxes):
raise ValueError("Each bbox is expected to be in format xmin, ymin, xmax, ymax, conf")

Check warning on line 60 in client/pyroclient/client.py

View check run for this annotation

Codecov / codecov/patch

client/pyroclient/client.py#L60

Added line #L60 was not covered by tests
box_strs = (
f"({_to_str(xmin)},{_to_str(ymin)},{_to_str(xmax)},{_to_str(ymax)},{_to_str(conf)})"
for xmin, ymin, xmax, ymax, conf in bboxes
)
return f"[{','.join(box_strs)}]"


class Client:
"""Isometric Python client for Pyronear wildfire detection API

Expand Down Expand Up @@ -92,25 +120,32 @@
self,
media: bytes,
azimuth: float,
bboxes: List[Tuple[float, float, float, float, float]],
) -> Response:
"""Notify the detection of a wildfire on the picture taken by a camera.

>>> from pyroclient import Client
>>> api_client = Client("MY_CAM_TOKEN")
>>> with open("path/to/my/file.ext", "rb") as f: data = f.read()
>>> response = api_client.create_detection(data, azimuth=124.2)
>>> response = api_client.create_detection(data, azimuth=124.2, bboxes=[(.1,.1,.5,.8,.5)])

Args:
media: byte data of the picture
azimuth: the azimuth of the camera when the picture was taken
bboxes: list of tuples where each tuple is a relative coordinate in order xmin, ymin, xmax, ymax, conf

Returns:
HTTP response
"""
if not isinstance(bboxes, (list, tuple)) or len(bboxes) == 0 or len(bboxes) > 5:
raise ValueError("bboxes must be a non-empty list of tuples with a maximum of 5 boxes")
return requests.post(
self.routes["detections-create"],
headers=self.headers,
data={"azimuth": azimuth},
data={
"azimuth": azimuth,
"bboxes": _dump_bbox_to_json(bboxes),
},
timeout=self.timeout,
files={"file": ("logo.png", media, "image/png")},
)
Expand Down Expand Up @@ -187,6 +222,24 @@
timeout=self.timeout,
)

def fetch_unlabeled_detections(self, from_date: str) -> Response:
"""List the detections accessible to the authenticated user

>>> from pyroclient import client
>>> api_client = Client("MY_USER_TOKEN")
>>> response = api_client.fetch_unacknowledged_detections("2023-07-04T00:00:00")

Returns:
HTTP response
"""
params = {"from_date": from_date}
return requests.get(
self.routes["detections-fetch-unl"],
headers=self.headers,
params=params,
timeout=self.timeout,
)

# ORGANIZATIONS

def fetch_organizations(self) -> Response:
Expand Down
2 changes: 2 additions & 0 deletions client/pyroclient/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from typing import Union

__all__ = ["HTTPRequestError"]


class HTTPRequestError(Exception):
def __init__(self, status_code: int, response_message: Union[str, None] = None) -> None:
Expand Down
19 changes: 14 additions & 5 deletions client/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,31 @@ def test_client_constructor(token, host, timeout, expected_error):
def test_cam_workflow(cam_token, mock_img):
cam_client = Client(cam_token, "http://localhost:5050", timeout=10)
assert cam_client.heartbeat().status_code == 200
response = cam_client.create_detection(mock_img, 123.2)
assert response.status_code == 201, print(response.__dict__)
# Check that adding bboxes works
with pytest.raises(ValueError, match="bboxes must be a non-empty list of tuples"):
cam_client.create_detection(mock_img, 123.2, None)
with pytest.raises(ValueError, match="bboxes must be a non-empty list of tuples"):
cam_client.create_detection(mock_img, 123.2, [])
response = cam_client.create_detection(mock_img, 123.2, [(0, 0, 1.0, 0.9, 0.5)])
assert response.status_code == 201, response.__dict__
response = cam_client.create_detection(mock_img, 123.2, [(0, 0, 1.0, 0.9, 0.5), (0.2, 0.2, 0.7, 0.7, 0.8)])
assert response.status_code == 201, response.__dict__
return response.json()["id"]


def test_agent_workflow(test_cam_workflow, agent_token):
# Agent workflow
agent_client = Client(agent_token, "http://localhost:5050", timeout=10)
response = agent_client.label_detection(test_cam_workflow, True)
assert response.status_code == 200, print(response.__dict__)
assert response.status_code == 200, response.__dict__


def test_user_workflow(test_cam_workflow, user_token):
# User workflow
user_client = Client(user_token, "http://localhost:5050", timeout=10)
response = user_client.get_detection_url(test_cam_workflow)
assert response.status_code == 200, print(response.__dict__)
assert response.status_code == 200, response.__dict__
response = user_client.fetch_detections()
assert response.status_code == 200, print(response.__dict__)
assert response.status_code == 200, response.__dict__
response = user_client.fetch_unlabeled_detections("2018-06-06T00:00:00")
assert response.status_code == 200, response.__dict__
2 changes: 2 additions & 0 deletions scripts/localstack/setup-s3.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
#!/usr/bin/env bash
awslocal s3 mb s3://bucket
echo -n "" > my_file
awslocal s3 cp my_file s3://bucket/my_file
11 changes: 8 additions & 3 deletions scripts/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,18 @@ def main(args):
# Take a picture
file_bytes = requests.get("https://pyronear.org/img/logo.png", timeout=5).content
# Create a detection
detection_id = requests.post(
response = requests.post(
f"{args.endpoint}/detections",
headers=cam_auth,
data={"azimuth": 45.6},
data={"azimuth": 45.6, "bboxes": "[(0.1,0.1,0.8,0.8,0.5)]"},
files={"file": ("logo.png", file_bytes, "image/png")},
timeout=5,
).json()["id"]
)
assert response.status_code == 201, response.text
detection_id = response.json()["id"]

# Fetch unlabeled detections
api_request("get", f"{args.endpoint}/detections/unlabeled/fromdate?from_date=2018-06-06T00:00:00", agent_auth)

# Acknowledge it
api_request("patch", f"{args.endpoint}/detections/{detection_id}/label", agent_auth, {"is_wildfire": True})
Expand Down
71 changes: 64 additions & 7 deletions src/app/api/api_v1/endpoints/detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,27 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import asyncio
import hashlib
from datetime import datetime
from mimetypes import guess_extension
from typing import List, cast

import magic
from fastapi import APIRouter, Depends, File, Form, HTTPException, Path, Security, UploadFile, status
from fastapi import APIRouter, Depends, File, Form, HTTPException, Path, Query, Security, UploadFile, status

from app.api.dependencies import get_camera_crud, get_detection_crud, get_jwt
from app.core.config import settings
from app.crud import CameraCRUD, DetectionCRUD
from app.models import Camera, Detection, Role, UserRole
from app.schemas.detections import DetectionCreate, DetectionLabel, DetectionUrl
from app.schemas.detections import (
BOXES_PATTERN,
COMPILED_BOXES_PATTERN,
DetectionCreate,
DetectionLabel,
DetectionUrl,
DetectionWithUrl,
)
from app.schemas.login import TokenPayload
from app.services.storage import s3_bucket
from app.services.telemetry import telemetry_client
Expand All @@ -24,12 +33,27 @@

@router.post("/", status_code=status.HTTP_201_CREATED, summary="Register a new wildfire detection")
async def create_detection(
bboxes: str = Form(
...,
description="string representation of list of detection localizations, each represented as a tuple of relative coords (max 3 decimals) in order: xmin, ymin, xmax, ymax, conf",
pattern=BOXES_PATTERN,
min_length=2,
max_length=settings.MAX_BBOX_STR_LENGTH,
),
azimuth: float = Form(..., gt=0, lt=360, description="angle between north and direction in degrees"),
file: UploadFile = File(..., alias="file"),
detections: DetectionCRUD = Depends(get_detection_crud),
token_payload: TokenPayload = Security(get_jwt, scopes=[Role.CAMERA]),
) -> Detection:
telemetry_client.capture(f"camera|{token_payload.sub}", event="detections-create")

# Throw an error if the format is invalid and can't be captured by the regex
if any(box[0] >= box[2] or box[1] >= box[3] for box in COMPILED_BOXES_PATTERN.findall(bboxes)):
raise HTTPException(

Check warning on line 52 in src/app/api/api_v1/endpoints/detections.py

View check run for this annotation

Codecov / codecov/patch

src/app/api/api_v1/endpoints/detections.py#L52

Added line #L52 was not covered by tests
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="xmin & ymin are expected to be respectively smaller than xmax & ymax",
)

# Upload media
# Concatenate the first 8 chars (to avoid system interactions issues) of SHA256 hash with file extension
sha_hash = hashlib.sha256(file.file.read()).hexdigest()
Expand All @@ -40,7 +64,7 @@
# guess_extension will return none if this fails
extension = guess_extension(magic.from_buffer(file.file.read(), mime=True)) or ""
# Concatenate timestamp & hash
bucket_key = f"{datetime.utcnow().strftime('%Y%m%d%H%M%S')}-{sha_hash[:8]}{extension}"
bucket_key = f"{token_payload.sub}-{datetime.utcnow().strftime('%Y%m%d%H%M%S')}-{sha_hash[:8]}{extension}"
# Reset byte position of the file (cf. https://fastapi.tiangolo.com/tutorial/request-files/#uploadfile)
await file.seek(0)
# Failed upload
Expand All @@ -57,8 +81,10 @@
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Data was corrupted during upload",
)

return await detections.create(DetectionCreate(camera_id=token_payload.sub, bucket_key=bucket_key, azimuth=azimuth))
# Format the string
return await detections.create(
DetectionCreate(camera_id=token_payload.sub, bucket_key=bucket_key, azimuth=azimuth, bboxes=bboxes)
)


@router.get("/{detection_id}", status_code=status.HTTP_200_OK, summary="Fetch the information of a specific detection")
Expand Down Expand Up @@ -112,10 +138,41 @@
if UserRole.ADMIN in token_payload.scopes:
return [elt for elt in await detections.fetch_all()]

cameras_list = await cameras.fetch_all(("organization_id", token_payload.organization_id))
cameras_list = await cameras.fetch_all(filter_pair=("organization_id", token_payload.organization_id))
camera_ids = [camera.id for camera in cameras_list]

return await detections.get_in(camera_ids, "camera_id")
return await detections.fetch_all(in_pair=("camera_id", camera_ids))

Check warning on line 144 in src/app/api/api_v1/endpoints/detections.py

View check run for this annotation

Codecov / codecov/patch

src/app/api/api_v1/endpoints/detections.py#L144

Added line #L144 was not covered by tests


@router.get("/unlabeled/fromdate", status_code=status.HTTP_200_OK, summary="Fetch all the unlabeled detections")
async def fetch_unlabeled_detections(
from_date: datetime = Query(),
detections: DetectionCRUD = Depends(get_detection_crud),
cameras: CameraCRUD = Depends(get_camera_crud),
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
) -> List[DetectionWithUrl]:
telemetry_client.capture(token_payload.sub, event="unacknowledged-fetch")

async def get_url(detection: Detection) -> str:
return await s3_bucket.get_public_url(detection.bucket_key)

if UserRole.ADMIN in token_payload.scopes:
all_unck_detections = await detections.fetch_all(
filter_pair=("is_wildfire", None), inequality_pair=("created_at", ">=", from_date)
)
else:
org_cams = await cameras.fetch_all(filter_pair=("organization_id", token_payload.organization_id))
all_unck_detections = await detections.fetch_all(
filter_pair=("is_wildfire", None),
in_pair=("camera_id", [camera.id for camera in org_cams]),
inequality_pair=("created_at", ">=", from_date),
)

# Launch all get_url calls in parallel
url_tasks = [get_url(detection) for detection in all_unck_detections]
urls = await asyncio.gather(*url_tasks)

Check warning on line 173 in src/app/api/api_v1/endpoints/detections.py

View check run for this annotation

Codecov / codecov/patch

src/app/api/api_v1/endpoints/detections.py#L173

Added line #L173 was not covered by tests

return [DetectionWithUrl(**detection.model_dump(), url=url) for detection, url in zip(all_unck_detections, urls)]


@router.patch("/{detection_id}/label", status_code=status.HTTP_200_OK, summary="Label the nature of the detection")
Expand Down
7 changes: 7 additions & 0 deletions src/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def sqlachmey_uri(cls, v: str) -> str:
JWT_UNLIMITED: int = 60 * 24 * 365
JWT_ALGORITHM: str = "HS256"

# DB conversion
MAX_BOXES_PER_DETECTION: int = 5
DECIMALS_PER_COORD: int = 3
MAX_BBOX_STR_LENGTH: int = (
2 + MAX_BOXES_PER_DETECTION * (2 + 5 * (2 + DECIMALS_PER_COORD) + 4 * 2) + (MAX_BOXES_PER_DETECTION - 1) * 2
)

# Storage
S3_BUCKET_NAME: str = os.environ["S3_BUCKET_NAME"]
S3_ACCESS_KEY: str = os.environ["S3_ACCESS_KEY"]
Expand Down
29 changes: 26 additions & 3 deletions src/app/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import Any, Generic, List, Tuple, Type, TypeVar, Union, cast
from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar, Union, cast

from fastapi import HTTPException, status
from pydantic import BaseModel
Expand Down Expand Up @@ -58,11 +58,34 @@
)
return entry

async def fetch_all(self, filter_pair: Union[Tuple[str, Any], None] = None) -> List[ModelType]:
async def fetch_all(
self,
filter_pair: Union[Tuple[str, Any], None] = None,
in_pair: Union[Tuple[str, List], None] = None,
inequality_pair: Optional[Tuple[str, str, Any]] = None,
) -> List[ModelType]:
statement = select(self.model) # type: ignore[var-annotated]
if isinstance(filter_pair, tuple):
statement = statement.where(getattr(self.model, filter_pair[0]) == filter_pair[1])
return await self.session.exec(statement=statement)

if isinstance(in_pair, tuple):
statement = statement.where(getattr(self.model, in_pair[0]).in_(in_pair[1]))

if isinstance(inequality_pair, tuple):
field, op, value = inequality_pair
if op == ">=":
statement = statement.where(getattr(self.model, field) >= value)
elif op == ">":
statement = statement.where(getattr(self.model, field) > value)
elif op == "<=":
statement = statement.where(getattr(self.model, field) <= value)
elif op == "<":
statement = statement.where(getattr(self.model, field) < value)

Check warning on line 83 in src/app/crud/base.py

View check run for this annotation

Codecov / codecov/patch

src/app/crud/base.py#L78-L83

Added lines #L78 - L83 were not covered by tests
else:
raise ValueError(f"Unsupported inequality operator: {op}")

Check warning on line 85 in src/app/crud/base.py

View check run for this annotation

Codecov / codecov/patch

src/app/crud/base.py#L85

Added line #L85 was not covered by tests

result = await self.session.exec(statement=statement)
return [r for r in result]

async def update(self, entry_id: int, payload: UpdateSchemaType) -> ModelType:
access = cast(ModelType, await self.get(entry_id, strict=True))
Expand Down
Loading