Refactored into data structures and tasks + generalized visualization
Showing 3 changed files with 281 additions and 231 deletions.
src/senselab/video/data_structures/pose.py
@@ -1,168 +1,67 @@
 """Data structures relevant for pose estimation."""
 
-from typing import Any, Dict, List
+from typing import Dict, List
 
-import cv2
-import mediapipe as mp
 import numpy as np
-import torch
-from mediapipe import solutions
-from mediapipe.framework.formats import landmark_pb2
-from mediapipe.tasks import python
-from mediapipe.tasks.python import vision
-from mediapipe.tasks.python.vision import PoseLandmarkerResult
 from pydantic import BaseModel, ConfigDict, field_validator
 
-# class Landmark
-
 
-class PoseSkeleton(BaseModel):
-    """Data structure for estimated poses of multiple individuals in an image.
-
+class IndividualPose(BaseModel):
+    """Data structure for estimated pose of single individual in an image.
     Attributes:
-        image: object representing the original image (torch.Tensor)
-        normalized_landmarks: list of dictionaries for each person's body landmarks with normalized
-            image coordinates (x, y, z).
-        world_landmarks: list of dictionaries for each person's body landmarks with real-world
-            coordinates (x, y, z).
-        detection_result: output of MediaPipe pose estimation
+        individual_index: index of individual detected.
+        normalized_landmarks: Dictionary of body landmarks with normalized image coordinates and visibility (x, y, z, c).
+        world_landmarks: Dictionary of body landmarks with real-world coordinates and visibility (x, y, z, c).
     """
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
-    image: torch.Tensor
-    normalized_landmarks: List[Dict[str, List[float]]]  # List of dictionaries for each person
-    world_landmarks: List[Dict[str, List[float]]]  # List of dictionaries for each person
-    detection_result: Any
+    individual_index: int
+    normalized_landmarks: Dict[str, List[float]]
+    world_landmarks: Dict[str, List[float]]
 
     @field_validator("normalized_landmarks", "world_landmarks", mode="before")
-    def validate_landmarks(cls, v: List[Dict[str, List[float]]]) -> List[Dict[str, List[float]]]:
-        """Validate that landmarks contain at least 3 coordinates."""
-        for person_landmarks in v:
-            for coords in person_landmarks.values():
-                if len(coords) < 3:
-                    raise ValueError("Each landmark must have at least 3 coordinates (x, y, z).")
+    def validate_landmarks(cls, v: Dict[str, List[float]]) -> Dict[str, List[float]]:
+        """Validate that landmarks contain exactly 4 coordinates (x,y,z,visibility)."""
+        for coords in v.values():
+            if len(coords) != 4:
+                raise ValueError("Each landmark must have exactly 4 coordinates (x, y, z, visibility)")
         return v
 
-    def to_numpy(self) -> np.ndarray:
-        """Converts image to numpy array.
-        Returns:
-            numpy array of image that was initialized in class
-        """
-        return self.image.cpu().numpy()
-
-    def get_landmark_coordinates(self, landmark: str, person_index: int = 0, world: bool = False) -> List[float]:
-        """Returns the coordinates of a specified landmark for a given individual in the image.
+    def get_landmark_coordinates(self, landmark: str, world: bool = False) -> List[float]:
+        """Returns coordinates for specified landmark.
         Args:
-            person_index (int): Index of the individual in the detection results. Defaults to 0.
-            landmark (str): Name of the landmark (e.g., "landmark_0", "landmark_1").
-            world (bool): If True, retrieves world coordinates. Otherwise, retrieves normalized coordinates.
+            landmark: Name of the landmark (e.g., "landmark_0")
+            world: If True, returns world coordinates instead of normalized
         Returns:
-            List[float]: Coordinates of the landmark in the form [x, y, z, visibility].
-        Raises:
-            ValueError: If the landmark does not exist or the person index is out of bounds.
+            [x, y, z, visibility] coordinates
         """
         landmarks = self.world_landmarks if world else self.normalized_landmarks
-        if person_index >= len(landmarks):
-            raise ValueError(
-                f"Person index {person_index} is invalid. Image contains {len(landmarks)} people. "
-                f"Valid indices are {f'0 to {len(landmarks)-1}' if len(landmarks) > 0 else 'none'}"
-            )
-
-        if landmark not in landmarks[person_index]:
-            raise ValueError(
-                f"Landmark '{landmark}' not found. Available landmarks: {sorted(landmarks[person_index].keys())}"
-            )
-
-        return landmarks[person_index][landmark]
+        if landmark not in landmarks:
+            raise ValueError(f"Landmark '{landmark}' not found. Available landmarks: {sorted(landmarks.keys())}")
+        return landmarks[landmark]
 
-    def visualize_pose(self) -> None:
-        """Visualizes pose landmarks on the image and saves the annotated image.
-
-        Saves the annotated image as "pose_estimation_output.png" in the current directory.
-        """
-        annotated_image = draw_landmarks_on_image(self.to_numpy(), self.detection_result)
-        # Save the annotated image
-        annotated_image_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
-        cv2.imwrite("pose_estimation_output.png", annotated_image_bgr)
-        print("Image saved as pose_estimation_output.png")
+
+class ImagePose(BaseModel):
+    """Data structure for estimated poses of multiple individuals in an image.
+    Attributes:
+        image: numpy array representing the original image
+        individuals: list of IndividualPose objects for each individual with an estimated pose.
+    """
 
-def draw_landmarks_on_image(rgb_image: np.ndarray, detection_result: PoseLandmarkerResult) -> np.ndarray:
-    """Draws pose landmarks on the input RGB image.
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
-    Args:
-        rgb_image: The input image in RGB format
-        detection_result: The detection result containing pose landmarks
+    image: np.ndarray
+    individuals: List[IndividualPose]
 
-    Returns:
-        Annotated image with pose landmarks drawn
-    """
-    annotated_image = np.copy(rgb_image)
-    for person_landmarks in detection_result.pose_landmarks:
-        pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
-        pose_landmarks_proto.landmark.extend(
-            [landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in person_landmarks]
-        )
-        solutions.drawing_utils.draw_landmarks(
-            annotated_image,
-            pose_landmarks_proto,
-            solutions.pose.POSE_CONNECTIONS,
-            solutions.drawing_styles.get_default_pose_landmarks_style(),
-        )
-    return annotated_image
-
-
-def estimate_pose_with_mediapipe(image_path: str, num_of_individuals: int = 1) -> PoseSkeleton:
-    """Estimates pose landmarks for individuals in the provided image using MediaPipe.
-    Args:
-        image_path (str): Path to the input image file.
-        num_of_individuals (int): Maximum number of individuals to detect. Defaults to 1.
-    Returns:
-        PoseSkeleton object
-    Raises:
-        FileNotFoundError: If the image file does not exist.
-    """
-    # MediaPipe Pose Estimation config
-    base_options = python.BaseOptions(
-        model_asset_path="src/senselab/video/tasks/pose_estimation/models/pose_landmarker.task"
-    )
-    options = vision.PoseLandmarkerOptions(
-        base_options=base_options, output_segmentation_masks=True, num_poses=num_of_individuals
-    )
-    detector = vision.PoseLandmarker.create_from_options(options)
-
-    # Load the input image
-    image = mp.Image.create_from_file(image_path)
-    detection_result = detector.detect(image)
-
-    normalized_landmarks_list = []
-    world_landmarks_list = []
-
-    for person_landmarks in detection_result.pose_landmarks:
-        # Store normalized landmarks (3D) for each person
-        person_normalized_landmarks = {}
-        for idx, landmark in enumerate(person_landmarks):
-            person_normalized_landmarks[f"landmark_{idx}"] = [landmark.x, landmark.y, landmark.z, landmark.visibility]
-        normalized_landmarks_list.append(person_normalized_landmarks)
-
-    for person_landmarks in detection_result.pose_world_landmarks:
-        # Store world landmarks (3D) for each person
-        person_world_landmarks = {}
-        for idx, landmark in enumerate(person_landmarks):
-            person_world_landmarks[f"landmark_{idx}"] = [landmark.x, landmark.y, landmark.z, landmark.visibility]
-        world_landmarks_list.append(person_world_landmarks)
-
-    image_tensor = torch.from_numpy(image.numpy_view().copy())
-
-    # Return PoseSkeleton with all detected individuals' landmarks
-    return PoseSkeleton(
-        image=image_tensor,
-        normalized_landmarks=normalized_landmarks_list,
-        world_landmarks=world_landmarks_list,
-        detection_result=detection_result,
-    )
+    def get_individual(self, individual_index: int) -> IndividualPose:
+        """Returns IndividualPose object for specified individual."""
+        if individual_index >= len(self.individuals) or individual_index < 0:
+            raise ValueError(
+                f"Individual index {individual_index} is invalid. {len(self.individuals)} poses were estimated. "
+                f"Valid indices are {f'0 to {len(self.individuals)-1}' if len(self.individuals) > 0 else 'none'}"
+            )
+        return self.individuals[individual_index]
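
For context, here is a minimal usage sketch of the refactored data structures. It is not part of the commit: the landmark values are invented example data, and it assumes the module path senselab.video.data_structures.pose used by the import in the new task file below.

# Minimal usage sketch (illustration only, not part of the commit).
# Landmark values below are invented example data.
import numpy as np

from senselab.video.data_structures.pose import ImagePose, IndividualPose

# Each landmark maps to [x, y, z, visibility]; the validator rejects anything but 4 values.
person = IndividualPose(
    individual_index=0,
    normalized_landmarks={"landmark_0": [0.51, 0.42, -0.12, 0.98]},
    world_landmarks={"landmark_0": [0.12, -0.34, 0.05, 0.98]},
)
print(person.get_landmark_coordinates("landmark_0"))              # normalized image coordinates
print(person.get_landmark_coordinates("landmark_0", world=True))  # real-world coordinates

# ImagePose bundles the original image with every detected individual.
frame = ImagePose(image=np.zeros((480, 640, 3), dtype=np.uint8), individuals=[person])
assert frame.get_individual(0) == person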
122 changes: 122 additions & 0 deletions
src/senselab/video/tasks/pose_estimation/pose_estimation.py
@@ -0,0 +1,122 @@
"""This module implements the Pose Estimation task and supporting utilities."""

import os

import cv2
import mediapipe as mp
import numpy as np
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

from senselab.video.data_structures.pose import ImagePose, IndividualPose


def visualize_pose(image: ImagePose) -> np.ndarray:
    """Visualizes pose landmarks on the input image.
    Args:
        image: ImagePose object containing image and detected poses
    Returns:
        Annotated image with pose landmarks drawn
    Note:
        Saves the annotated image as 'pose_estimation_output.png'
    """
    # Convert to RGB if needed and create copy
    annotated_image = np.copy(image.image)
    if len(annotated_image.shape) == 2:  # Grayscale
        annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_GRAY2RGB)

    for individual in image.individuals:
        pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        landmarks = [
            landmark_pb2.NormalizedLandmark(x=coords[0], y=coords[1], z=coords[2])
            for coords in individual.normalized_landmarks.values()
        ]
        pose_landmarks_proto.landmark.extend(landmarks)

        solutions.drawing_utils.draw_landmarks(
            annotated_image,
            pose_landmarks_proto,
            solutions.pose.POSE_CONNECTIONS,
            solutions.drawing_styles.get_default_pose_landmarks_style(),
        )

    # Save the annotated image
    annotated_image_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite("pose_estimation_output.png", annotated_image_bgr)
    print("Image saved as pose_estimation_output.png")

    return annotated_image


def estimate_pose_with_mediapipe(
    image_path: str,
    num_of_individuals: int = 1,
    model_path: str = "src/senselab/video/tasks/pose_estimation/models/pose_landmarker.task",
) -> ImagePose:
    """Estimates pose landmarks for individuals in the provided image using MediaPipe.
    Args:
        image_path: Path to the input image file
        num_of_individuals: Maximum number of individuals to detect. Defaults to 1
        model_path: Path to the MediaPipe pose landmarker model file
    Returns:
        ImagePose object containing detected poses
    Raises:
        FileNotFoundError: If image_path or model_path doesn't exist
        RuntimeError: If pose detection fails
    """
    # Validate file paths
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at: {image_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at: {model_path}")

    # Initialize detector
    base_options = python.BaseOptions(model_asset_path=model_path)
    options = vision.PoseLandmarkerOptions(
        base_options=base_options, output_segmentation_masks=True, num_poses=num_of_individuals
    )
    detector = vision.PoseLandmarker.create_from_options(options)

    # Load and process image
    image = mp.Image.create_from_file(image_path)
    detection_result = detector.detect(image)

    if not detection_result:
        raise RuntimeError("Pose detection failed")

    # Create IndividualPose objects
    individuals = []
    for idx, (norm_landmarks, world_landmarks) in enumerate(
        zip(detection_result.pose_landmarks, detection_result.pose_world_landmarks)
    ):
        norm_dict = {f"landmark_{i}": [lm.x, lm.y, lm.z, lm.visibility] for i, lm in enumerate(norm_landmarks)}
        world_dict = {f"landmark_{i}": [lm.x, lm.y, lm.z, lm.visibility] for i, lm in enumerate(world_landmarks)}

        individual = IndividualPose(individual_index=idx, normalized_landmarks=norm_dict, world_landmarks=world_dict)
        individuals.append(individual)

    # Create and return ImagePose
    image_array = image.numpy_view().copy()
    return ImagePose(image=image_array, individuals=individuals, detection_result=detection_result)


if __name__ == "__main__":
    import os

    # Example usage
    image_path = "src/tests/data_for_testing/pose_data/three_people.jpg"
    try:
        pose_result = estimate_pose_with_mediapipe(image_path, num_of_individuals=2)
        annotated = visualize_pose(pose_result)
        print(f"Detected {len(pose_result.individuals)} individuals")
        print("Example individual pose data:", pose_result.get_individual(0))
    except Exception as e:
        print(f"Error processing image: {e}")