From 7d7c400f0fb244c84135cb40e401bd3a48c14277 Mon Sep 17 00:00:00 2001 From: MitchellJC <81349046+MitchellJC@users.noreply.github.com> Date: Fri, 13 Sep 2024 13:34:25 +1000 Subject: [PATCH 1/2] feat: made PostureProcess accept modular FrameCapturer type --- client/models/pose_detection/routines.py | 35 ++++++++++++++---------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/client/models/pose_detection/routines.py b/client/models/pose_detection/routines.py index cc6b283..cd0a9e7 100644 --- a/client/models/pose_detection/routines.py +++ b/client/models/pose_detection/routines.py @@ -6,7 +6,7 @@ import multiprocessing as multp import multiprocessing.connection as connection from importlib import resources -from typing import Callable, Mapping +from typing import Callable, Mapping, Optional, Type from datetime import datetime import cv2 @@ -27,6 +27,7 @@ from models.pose_detection.landmarking import AnnotatedImage, display_landmarking from models.pose_detection.camera import is_camera_aligned from models.pose_detection.classification import posture_classify +from models.pose_detection.frame_capturer import FrameCapturer, OpenCVCapturer POSE_LANDMARKER_FILE = resources.files("models.resources").joinpath( "pose_landmarker_lite.task" @@ -44,12 +45,14 @@ class PostureProcess: API. """ - def __init__(self) -> None: + def __init__(self, frame_capturer: Type[FrameCapturer] = OpenCVCapturer) -> None: """Create a new process which loads the MediaPipe Pose model and runs periodic posture tracking. This initializer blocks until the model is loaded. """ self._parent_con, child_con = multp.Pipe() - self._process = multp.Process(target=_run_posture, args=(child_con,)) + + args = (child_con, frame_capturer) + self._process = multp.Process(target=_run_posture, args=args) self._process.start() # Blocks until something is recieved from child @@ -96,7 +99,7 @@ def __init__( self._start_time = time.time() self._period_start = datetime.now() - self._video_capture = cv2.VideoCapture(0) + self.frame_capturer: Optional[FrameCapturer] = None @property def user_id(self) -> int: @@ -112,12 +115,15 @@ def track_posture(self) -> None: """Get frame from video capture device and process with pose model, then posture algorithm. Saves posture data to database periodically. """ + if self.frame_capturer is None: + raise ValueError( + "Please set a frame capturer before trying to track posture" + ) + if self.user_id == NO_USER: return - success, frame = self._video_capture.read() - if not success: - return + frame, _ = self.frame_capturer.get_frame() mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame) result = self.detect(mp_image) @@ -151,10 +157,6 @@ def _new_period(self) -> None: self._start_time = time.time() self._period_start = datetime.now() - def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback) -> None: - self._video_capture.release() - super().__exit__(unused_exc_type, unused_exc_value, unused_traceback) - class DebugPostureTracker(PoseLandmarker): """Handles routines for a Debugging Posture Tracker. @@ -195,7 +197,7 @@ def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback) -> None: super().__exit__(unused_exc_type, unused_exc_value, unused_traceback) -def create_posture_tracker() -> PostureTracker: +def create_posture_tracker(frame_capturer: FrameCapturer) -> PostureTracker: """Handles config of single image frame input and model loading. Returns: @@ -207,6 +209,7 @@ def create_posture_tracker() -> PostureTracker: ) tracker = PostureTracker.create_from_options(options) + tracker.frame_capturer = frame_capturer return tracker @@ -230,8 +233,12 @@ def create_debug_posture_tracker() -> DebugPostureTracker: return tracker -def _run_posture(con: connection.Connection) -> None: - with create_posture_tracker() as tracker: +def _run_posture( + con: connection.Connection, frame_capturer: Type[FrameCapturer] +) -> None: + # Instantiate frame capturer in subprocess to avoid pickling errors. + frame_capturer_obj = frame_capturer() + with create_posture_tracker(frame_capturer_obj) as tracker: con.send(True) while True: # Handle message from parent From ecb8fd38841556855d439af87b256800cd52d64a Mon Sep 17 00:00:00 2001 From: MitchellJC <81349046+MitchellJC@users.noreply.github.com> Date: Fri, 13 Sep 2024 13:46:15 +1000 Subject: [PATCH 2/2] doc: Added docstrings for new FrameCapturer object in pose_detection routines --- client/models/pose_detection/routines.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/client/models/pose_detection/routines.py b/client/models/pose_detection/routines.py index cd0a9e7..8f9f998 100644 --- a/client/models/pose_detection/routines.py +++ b/client/models/pose_detection/routines.py @@ -48,6 +48,9 @@ class PostureProcess: def __init__(self, frame_capturer: Type[FrameCapturer] = OpenCVCapturer) -> None: """Create a new process which loads the MediaPipe Pose model and runs periodic posture tracking. This initializer blocks until the model is loaded. + + Args: + frame_capturer: Class reference to capturer for child process to construct """ self._parent_con, child_con = multp.Pipe() @@ -82,6 +85,7 @@ class PostureTracker(PoseLandmarker): Attributes: user_id: Id for the user currently being tracked. + frame_capturer: Captures frames to be tracked by model. """ def __init__( @@ -91,6 +95,7 @@ def __init__( packet_callback: Callable[[Mapping[str, Packet]], None], ) -> None: super().__init__(graph_config, running_mode, packet_callback) + self.frame_capturer: Optional[FrameCapturer] = None self._user_id = NO_USER @@ -99,8 +104,6 @@ def __init__( self._start_time = time.time() self._period_start = datetime.now() - self.frame_capturer: Optional[FrameCapturer] = None - @property def user_id(self) -> int: """Currently tracked user. Data will be associated with this user in the database.""" @@ -200,6 +203,9 @@ def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback) -> None: def create_posture_tracker(frame_capturer: FrameCapturer) -> PostureTracker: """Handles config of single image frame input and model loading. + Args: + frame_capturer: Interface for posture tracker to get frames for to feed into posture model. + Returns: Tracker object which acts as context manager. """