Skip to content
This repository has been archived by the owner on Dec 8, 2024. It is now read-only.

Commit

Permalink
ref: Moved debug tracker to new trackers module
Browse files Browse the repository at this point in the history
  • Loading branch information
MitchellJC committed Sep 10, 2024
1 parent 69d5227 commit 4d1babc
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 48 deletions.
49 changes: 1 addition & 48 deletions client/models/pose_detection/routines.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,21 @@
"""Routines that can be integrated into a main control flow."""

from importlib import resources
from typing import Callable, Mapping

import cv2
import mediapipe as mp
from mediapipe.framework import calculator_pb2
from mediapipe.python._framework_bindings.packet import Packet
from mediapipe.tasks.python.core.base_options import BaseOptions
from mediapipe.tasks.python.vision import RunningMode
from mediapipe.tasks.python.vision.core.vision_task_running_mode import (
VisionTaskRunningMode,
)
from mediapipe.tasks.python.vision.pose_landmarker import (
PoseLandmarker,
PoseLandmarkerOptions,
)

from models.pose_detection.landmarking import AnnotatedImage, display_landmarking
from models.pose_detection.trackers import DebugPostureTracker

POSE_LANDMARKER_FILE = resources.files("models.resources").joinpath(
"pose_landmarker_lite.task"
)


class DebugPostureTracker(PoseLandmarker):
"""Handles routines for a Debugging Posture Tracker.
Attributes:
annotated_image: Mutable container for an image which may be mutated asynchronously.
"""

def __init__(
self,
graph_config: calculator_pb2.CalculatorGraphConfig,
running_mode: VisionTaskRunningMode,
packet_callback: Callable[[Mapping[str, Packet]], None],
) -> None:
super().__init__(graph_config, running_mode, packet_callback)
self.annotated_image = AnnotatedImage()
self._video_capture = cv2.VideoCapture(0)

def track_posture(self) -> None:
"""Get frame from video capture device and process with pose model, then posture
algorithm. Print debugging info and display landmark annotated frame.
"""
success, frame = self._video_capture.read()
if not success:
return

frame_timestamp_ms = self._video_capture.get(cv2.CAP_PROP_POS_MSEC)
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
self.detect_async(mp_image, int(frame_timestamp_ms))

cv2.imshow("input", frame)
if self.annotated_image.data is not None:
cv2.imshow("output", self.annotated_image.data)

def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback) -> None:
self._video_capture.release()
cv2.destroyAllWindows()
super().__exit__(unused_exc_type, unused_exc_value, unused_traceback)


def create_debug_posture_tracker() -> DebugPostureTracker:
"""Handles config of livestreamed input and model loading.
Expand Down
53 changes: 53 additions & 0 deletions client/models/pose_detection/trackers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Callable, Mapping

import cv2
import mediapipe as mp
from mediapipe.framework import calculator_pb2
from mediapipe.python._framework_bindings.packet import Packet
from mediapipe.tasks.python.vision.core.vision_task_running_mode import (
VisionTaskRunningMode,
)
from mediapipe.tasks.python.vision.pose_landmarker import (
PoseLandmarker,
)

from models.pose_detection.landmarking import AnnotatedImage


class DebugPostureTracker(PoseLandmarker):
"""Handles routines for a Debugging Posture Tracker.
Attributes:
annotated_image: Mutable container for an image which may be mutated asynchronously.
"""

def __init__(
self,
graph_config: calculator_pb2.CalculatorGraphConfig,
running_mode: VisionTaskRunningMode,
packet_callback: Callable[[Mapping[str, Packet]], None],
) -> None:
super().__init__(graph_config, running_mode, packet_callback)
self.annotated_image = AnnotatedImage()
self._video_capture = cv2.VideoCapture(0)

def track_posture(self) -> None:
"""Get frame from video capture device and process with pose model, then posture
algorithm. Print debugging info and display landmark annotated frame.
"""
success, frame = self._video_capture.read()
if not success:
return

frame_timestamp_ms = self._video_capture.get(cv2.CAP_PROP_POS_MSEC)
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
self.detect_async(mp_image, int(frame_timestamp_ms))

cv2.imshow("input", frame)
if self.annotated_image.data is not None:
cv2.imshow("output", self.annotated_image.data)

def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback) -> None:
self._video_capture.release()
cv2.destroyAllWindows()
super().__exit__(unused_exc_type, unused_exc_value, unused_traceback)

0 comments on commit 4d1babc

Please sign in to comment.