Skip to content

Commit

Permalink
New data model (#238)
Browse files Browse the repository at this point in the history
* adapt cam creation

* adapt api login

* fix engine variables

* fix engine variables

* fix heartbit

* fix send alert

* fix send alert

* missing azimuth

* style

* fix deps

* add missing bboxes

* add default bbox

* add default bbox

* fix mypy

* fix test

* fix test

* fix ci

* fix default bbox

* secret

* fix empty

* fix empty

* fix empty
  • Loading branch information
MateoLostanlen authored Feb 3, 2025
1 parent 15d5399 commit 8ae55e8
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 131 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ jobs:
- name: Run unittests
env:
API_URL: ${{ secrets.API_URL }}
API_LOGIN: ${{ secrets.API_LOGIN }}
API_PWD: ${{ secrets.API_PWD }}
LAT: 48.88
LON: 2.38
API_TOKEN: ${{ secrets.API_TOKEN }}
run: |
coverage run -m pytest tests/
coverage xml
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dynamic = ["version"]
dependencies = [
"onnxruntime==1.18.1",
"ncnn==1.0.20240410",
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@767be30a781b52b29d68579d543e3f45ac8c4713#egg=pyroclient&subdirectory=client",
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@main#egg=pyroclient&subdirectory=client",
"requests>=2.20.0,<3.0.0",
"tqdm>=4.62.0",
"huggingface_hub==0.23.1",
Expand Down
122 changes: 60 additions & 62 deletions pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ class Engine:
conf_thresh: confidence threshold to send an alert
api_url: url of the pyronear API
cam_creds: api credectials for each camera, the dictionary should be as the one in the example
latitude: device latitude
longitude: device longitude
alert_relaxation: number of consecutive positive detections required to send the first alert, and also
the number of consecutive negative detections before stopping the alert
frame_size: Resize frame to frame_size before sending it to the api in order to save bandwidth (H, W)
Expand All @@ -84,8 +82,6 @@ def __init__(
max_bbox_size: float = 0.4,
api_url: Optional[str] = None,
cam_creds: Optional[Dict[str, Dict[str, str]]] = None,
latitude: Optional[float] = None,
longitude: Optional[float] = None,
nb_consecutive_frames: int = 4,
frame_size: Optional[Tuple[int, int]] = None,
cache_backup_period: int = 60,
Expand All @@ -105,15 +101,13 @@ def __init__(
self.conf_thresh = conf_thresh

# API Setup
if isinstance(api_url, str):
assert isinstance(latitude, float) and isinstance(longitude, float) and isinstance(cam_creds, dict)
self.latitude = latitude
self.longitude = longitude
self.api_client = {}
self.api_client: dict[str, Any] = {}
if isinstance(api_url, str) and isinstance(cam_creds, dict):
# Instantiate clients for each camera
for _id, vals in cam_creds.items():
self.api_client[_id] = client.Client(api_url, vals["login"], vals["password"])
for _id, (camera_token, _) in cam_creds.items():
ip = _id.split("_")[0]
if ip not in self.api_client.keys():
self.api_client[ip] = client.Client(camera_token, api_url)

# Cache & relaxation
self.frame_saving_period = frame_saving_period
Expand All @@ -123,6 +117,7 @@ def __init__(
self.cache_backup_period = cache_backup_period
self.day_time_strategy = day_time_strategy
self.save_captured_frames = save_captured_frames
self.cam_creds = cam_creds

# Local backup
self._backup_size = backup_size
Expand Down Expand Up @@ -181,7 +176,7 @@ def _dump_cache(self) -> None:
"frame_path": str(self._cache.joinpath(f"pending_frame{idx}.jpg")),
"cam_id": info["cam_id"],
"ts": info["ts"],
"localization": info["localization"],
"bboxes": info["bboxes"],
}
)

Expand All @@ -204,7 +199,8 @@ def _load_cache(self) -> None:

def heartbeat(self, cam_id: str) -> Response:
"""Updates last ping of device"""
return self.api_client[cam_id].heartbeat()
ip = cam_id.split("_")[0]
return self.api_client[ip].heartbeat()

def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> int:
"""Updates the detection states"""
Expand Down Expand Up @@ -244,10 +240,27 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) ->
iou_match = [np.max(iou) > 0 for iou in ious]
output_predictions = preds[iou_match, :]

if len(output_predictions) == 0:
missing_bbox = combine_predictions
missing_bbox[:, -1] = 0

else:
# Add missing bboxes
ious = box_iou(combine_predictions[:, :4], output_predictions[:, :4])
missing_bbox = combine_predictions[ious[0] == 0, :]
if len(missing_bbox):
missing_bbox[:, -1] = 0
output_predictions = np.concatenate([output_predictions, missing_bbox])

# Limit bbox size for api
output_predictions = np.round(output_predictions, 3) # max 3 digit
output_predictions = output_predictions[:5, :] # max 5 bbox

# Add default bbox
if len(output_predictions) == 0:
output_predictions = np.zeros((1, 5))
output_predictions[:, 2:4] += 0.0001

self._states[cam_key]["last_predictions"].append(
(frame, preds, output_predictions.tolist(), datetime.now(timezone.utc).isoformat(), False)
)
Expand Down Expand Up @@ -295,12 +308,10 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
# Alert
if conf > self.conf_thresh and len(self.api_client) > 0 and isinstance(cam_id, str):
# Save the alert in cache to avoid connection issues
for idx, (frame, preds, localization, ts, is_staged) in enumerate(
self._states[cam_key]["last_predictions"]
):
for idx, (frame, preds, bboxes, ts, is_staged) in enumerate(self._states[cam_key]["last_predictions"]):
if not is_staged:
self._stage_alert(frame, cam_id, ts, localization)
self._states[cam_key]["last_predictions"][idx] = frame, preds, localization, ts, True
self._stage_alert(frame, cam_id, ts, bboxes)
self._states[cam_key]["last_predictions"][idx] = frame, preds, bboxes, ts, True

# Check if it's time to backup pending alerts
ts = datetime.now(timezone.utc)
Expand All @@ -310,7 +321,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:

return float(conf)

def _stage_alert(self, frame: Image.Image, cam_id: str, ts: int, localization: list) -> None:
def _stage_alert(self, frame: Image.Image, cam_id: str, ts: int, bboxes: list) -> None:
# Store information in the queue
self._alerts.append(
{
Expand All @@ -319,53 +330,40 @@ def _stage_alert(self, frame: Image.Image, cam_id: str, ts: int, localization: l
"ts": ts,
"media_id": None,
"alert_id": None,
"localization": localization,
"bboxes": bboxes,
}
)

def _process_alerts(self) -> None:
for _ in range(len(self._alerts)):
# try to upload the oldest element
frame_info = self._alerts[0]
cam_id = frame_info["cam_id"]
logging.info(f"Camera '{cam_id}' - Sending alert from {frame_info['ts']}...")

# Save alert on device
self._local_backup(frame_info["frame"], cam_id)

try:
# Media creation
if not isinstance(self._alerts[0]["media_id"], int):
self._alerts[0]["media_id"] = self.api_client[cam_id].create_media_from_device().json()["id"]
# Alert creation
if not isinstance(self._alerts[0]["alert_id"], int):
self._alerts[0]["alert_id"] = (
self.api_client[cam_id]
.send_alert_from_device(
lat=self.latitude,
lon=self.longitude,
media_id=self._alerts[0]["media_id"],
localization=self._alerts[0]["localization"],
)
.json()["id"]
)
# Media upload
stream = io.BytesIO()
frame_info["frame"].save(stream, format="JPEG", quality=self.jpeg_quality)
response = self.api_client[cam_id].upload_media(
self._alerts[0]["media_id"],
media_data=stream.getvalue(),
)
# Force a KeyError if the request failed
response.json()["id"]
# Clear
self._alerts.popleft()
logging.info(f"Camera '{cam_id}' - alert sent")
stream.seek(0) # "Rewind" the stream to the beginning so we can read its content
except (KeyError, ConnectionError) as e:
logging.warning(f"Camera '{cam_id}' - unable to upload cache")
logging.warning(e)
break
if self.cam_creds is not None:
for _ in range(len(self._alerts)):
# try to upload the oldest element
frame_info = self._alerts[0]
cam_id = frame_info["cam_id"]
logging.info(f"Camera '{cam_id}' - Sending alert from {frame_info['ts']}...")

# Save alert on device
self._local_backup(frame_info["frame"], cam_id)

try:
# Detection creation
stream = io.BytesIO()
frame_info["frame"].save(stream, format="JPEG", quality=self.jpeg_quality)
bboxes = self._alerts[0]["bboxes"]
bboxes = [tuple(bboxe) for bboxe in bboxes]
_, cam_azimuth = self.cam_creds[cam_id]
ip = cam_id.split("_")[0]
response = self.api_client[ip].create_detection(stream.getvalue(), cam_azimuth, bboxes)
# Force a KeyError if the request failed
response.json()["id"]
# Clear
self._alerts.popleft()
logging.info(f"Camera '{cam_id}' - alert sent")
stream.seek(0) # "Rewind" the stream to the beginning so we can read its content
except (KeyError, ConnectionError) as e:
logging.warning(f"Camera '{cam_id}' - unable to upload cache")
logging.warning(e)
break

def _local_backup(self, img: Image.Image, cam_id: Optional[str], is_alert: bool = True) -> None:
"""Save image on device
Expand Down
2 changes: 2 additions & 0 deletions pyroengine/sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ def __init__(
password: str,
cam_type: str = "ptz",
cam_poses: Optional[List[int]] = None,
cam_azimuths: Optional[List[int]] = None,
protocol: str = "https",
):
self.ip_address = ip_address
self.username = username
self.password = password
self.cam_type = cam_type
self.cam_poses = cam_poses if cam_poses is not None else []
self.cam_azimuths = cam_azimuths if cam_azimuths is not None else []
self.protocol = protocol

if len(self.cam_poses):
Expand Down
Loading

0 comments on commit 8ae55e8

Please sign in to comment.