Skip to content
This repository has been archived by the owner on Dec 8, 2024. It is now read-only.

Commit

Permalink
Merge pull request #53 from LimaoC/mitch-modular-capture
Browse files Browse the repository at this point in the history
Added flexibility with FrameCapturer for PostureProcess
  • Loading branch information
MitchellJC authored Sep 13, 2024
2 parents 0034308 + ecb8fd3 commit 9de1b4d
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions client/models/pose_detection/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import multiprocessing as multp
import multiprocessing.connection as connection
from importlib import resources
from typing import Callable, Mapping
from typing import Callable, Mapping, Optional, Type
from datetime import datetime

import cv2
Expand All @@ -27,6 +27,7 @@
from models.pose_detection.landmarking import AnnotatedImage, display_landmarking
from models.pose_detection.camera import is_camera_aligned
from models.pose_detection.classification import posture_classify
from models.pose_detection.frame_capturer import FrameCapturer, OpenCVCapturer

POSE_LANDMARKER_FILE = resources.files("models.resources").joinpath(
"pose_landmarker_lite.task"
Expand All @@ -44,12 +45,17 @@ class PostureProcess:
API.
"""

def __init__(self) -> None:
def __init__(self, frame_capturer: Type[FrameCapturer] = OpenCVCapturer) -> None:
"""Create a new process which loads the MediaPipe Pose model and runs periodic posture
tracking. This initializer blocks until the model is loaded.
Args:
frame_capturer: Class reference to capturer for child process to construct
"""
self._parent_con, child_con = multp.Pipe()
self._process = multp.Process(target=_run_posture, args=(child_con,))

args = (child_con, frame_capturer)
self._process = multp.Process(target=_run_posture, args=args)
self._process.start()

# Blocks until something is recieved from child
Expand Down Expand Up @@ -79,6 +85,7 @@ class PostureTracker(PoseLandmarker):
Attributes:
user_id: Id for the user currently being tracked.
frame_capturer: Captures frames to be tracked by model.
"""

def __init__(
Expand All @@ -88,6 +95,7 @@ def __init__(
packet_callback: Callable[[Mapping[str, Packet]], None],
) -> None:
super().__init__(graph_config, running_mode, packet_callback)
self.frame_capturer: Optional[FrameCapturer] = None

self._user_id = NO_USER

Expand All @@ -96,8 +104,6 @@ def __init__(
self._start_time = time.time()
self._period_start = datetime.now()

self._video_capture = cv2.VideoCapture(0)

@property
def user_id(self) -> int:
"""Currently tracked user. Data will be associated with this user in the database."""
Expand All @@ -112,12 +118,15 @@ def track_posture(self) -> None:
"""Get frame from video capture device and process with pose model, then posture
algorithm. Saves posture data to database periodically.
"""
if self.frame_capturer is None:
raise ValueError(
"Please set a frame capturer before trying to track posture"
)

if self.user_id == NO_USER:
return

success, frame = self._video_capture.read()
if not success:
return
frame, _ = self.frame_capturer.get_frame()

mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
result = self.detect(mp_image)
Expand Down Expand Up @@ -151,10 +160,6 @@ def _new_period(self) -> None:
self._start_time = time.time()
self._period_start = datetime.now()

def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback) -> None:
self._video_capture.release()
super().__exit__(unused_exc_type, unused_exc_value, unused_traceback)


class DebugPostureTracker(PoseLandmarker):
"""Handles routines for a Debugging Posture Tracker.
Expand Down Expand Up @@ -195,9 +200,12 @@ def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback) -> None:
super().__exit__(unused_exc_type, unused_exc_value, unused_traceback)


def create_posture_tracker() -> PostureTracker:
def create_posture_tracker(frame_capturer: FrameCapturer) -> PostureTracker:
"""Handles config of single image frame input and model loading.
Args:
frame_capturer: Interface for posture tracker to get frames for to feed into posture model.
Returns:
Tracker object which acts as context manager.
"""
Expand All @@ -207,6 +215,7 @@ def create_posture_tracker() -> PostureTracker:
)

tracker = PostureTracker.create_from_options(options)
tracker.frame_capturer = frame_capturer
return tracker


Expand All @@ -230,8 +239,12 @@ def create_debug_posture_tracker() -> DebugPostureTracker:
return tracker


def _run_posture(con: connection.Connection) -> None:
with create_posture_tracker() as tracker:
def _run_posture(
con: connection.Connection, frame_capturer: Type[FrameCapturer]
) -> None:
# Instantiate frame capturer in subprocess to avoid pickling errors.
frame_capturer_obj = frame_capturer()
with create_posture_tracker(frame_capturer_obj) as tracker:
con.send(True)
while True:
# Handle message from parent
Expand Down

0 comments on commit 9de1b4d

Please sign in to comment.