diff --git a/client/models/pose_detection/routines.py b/client/models/pose_detection/routines.py index cc6b283..8f9f998 100644 --- a/client/models/pose_detection/routines.py +++ b/client/models/pose_detection/routines.py @@ -6,7 +6,7 @@ import multiprocessing as multp import multiprocessing.connection as connection from importlib import resources -from typing import Callable, Mapping +from typing import Callable, Mapping, Optional, Type from datetime import datetime import cv2 @@ -27,6 +27,7 @@ from models.pose_detection.landmarking import AnnotatedImage, display_landmarking from models.pose_detection.camera import is_camera_aligned from models.pose_detection.classification import posture_classify +from models.pose_detection.frame_capturer import FrameCapturer, OpenCVCapturer POSE_LANDMARKER_FILE = resources.files("models.resources").joinpath( "pose_landmarker_lite.task" @@ -44,12 +45,17 @@ class PostureProcess: API. """ - def __init__(self) -> None: + def __init__(self, frame_capturer: Type[FrameCapturer] = OpenCVCapturer) -> None: """Create a new process which loads the MediaPipe Pose model and runs periodic posture tracking. This initializer blocks until the model is loaded. + + Args: + frame_capturer: Class reference to capturer for child process to construct """ self._parent_con, child_con = multp.Pipe() - self._process = multp.Process(target=_run_posture, args=(child_con,)) + + args = (child_con, frame_capturer) + self._process = multp.Process(target=_run_posture, args=args) self._process.start() # Blocks until something is recieved from child @@ -79,6 +85,7 @@ class PostureTracker(PoseLandmarker): Attributes: user_id: Id for the user currently being tracked. + frame_capturer: Captures frames to be tracked by model. """ def __init__( @@ -88,6 +95,7 @@ def __init__( packet_callback: Callable[[Mapping[str, Packet]], None], ) -> None: super().__init__(graph_config, running_mode, packet_callback) + self.frame_capturer: Optional[FrameCapturer] = None self._user_id = NO_USER @@ -96,8 +104,6 @@ def __init__( self._start_time = time.time() self._period_start = datetime.now() - self._video_capture = cv2.VideoCapture(0) - @property def user_id(self) -> int: """Currently tracked user. Data will be associated with this user in the database.""" @@ -112,12 +118,15 @@ def track_posture(self) -> None: """Get frame from video capture device and process with pose model, then posture algorithm. Saves posture data to database periodically. """ + if self.frame_capturer is None: + raise ValueError( + "Please set a frame capturer before trying to track posture" + ) + if self.user_id == NO_USER: return - success, frame = self._video_capture.read() - if not success: - return + frame, _ = self.frame_capturer.get_frame() mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame) result = self.detect(mp_image) @@ -151,10 +160,6 @@ def _new_period(self) -> None: self._start_time = time.time() self._period_start = datetime.now() - def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback) -> None: - self._video_capture.release() - super().__exit__(unused_exc_type, unused_exc_value, unused_traceback) - class DebugPostureTracker(PoseLandmarker): """Handles routines for a Debugging Posture Tracker. @@ -195,9 +200,12 @@ def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback) -> None: super().__exit__(unused_exc_type, unused_exc_value, unused_traceback) -def create_posture_tracker() -> PostureTracker: +def create_posture_tracker(frame_capturer: FrameCapturer) -> PostureTracker: """Handles config of single image frame input and model loading. + Args: + frame_capturer: Interface for posture tracker to get frames for to feed into posture model. + Returns: Tracker object which acts as context manager. """ @@ -207,6 +215,7 @@ def create_posture_tracker() -> PostureTracker: ) tracker = PostureTracker.create_from_options(options) + tracker.frame_capturer = frame_capturer return tracker @@ -230,8 +239,12 @@ def create_debug_posture_tracker() -> DebugPostureTracker: return tracker -def _run_posture(con: connection.Connection) -> None: - with create_posture_tracker() as tracker: +def _run_posture( + con: connection.Connection, frame_capturer: Type[FrameCapturer] +) -> None: + # Instantiate frame capturer in subprocess to avoid pickling errors. + frame_capturer_obj = frame_capturer() + with create_posture_tracker(frame_capturer_obj) as tracker: con.send(True) while True: # Handle message from parent