Skip to content

Commit

Permalink
Refactored into data structures and tasks + generalized visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
brukew committed Nov 24, 2024
1 parent 83c31b7 commit 269f0a6
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 231 deletions.
183 changes: 41 additions & 142 deletions src/senselab/video/data_structures/pose.py
Original file line number Diff line number Diff line change
@@ -1,168 +1,67 @@
"""Data structures relevant for pose estimation."""

from typing import Any, Dict, List
from typing import Dict, List

import cv2
import mediapipe as mp
import numpy as np
import torch
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.vision import PoseLandmarkerResult
from pydantic import BaseModel, ConfigDict, field_validator

# class Landmark

class PoseSkeleton(BaseModel):
"""Data structure for estimated poses of multiple individuals in an image.

class IndividualPose(BaseModel):
"""Data structure for estimated pose of single individual in an image.
Attributes:
image: object representing the original image (torch.Tensor)
normalized_landmarks: list of dictionaries for each person's body landmarks with normalized
image coordinates (x, y, z).
world_landmarks: list of dictionaries for each person's body landmarks with real-world
coordinates (x, y, z).
detection_result: output of MediaPipe pose estimation
individual_index: index of the detected individual.
normalized_landmarks: Dictionary of body landmarks with normalized image coordinates and visibility (x, y, z, c).
world_landmarks: Dictionary of body landmarks with real-world coordinates and visibility (x, y, z, c).
"""

model_config = ConfigDict(arbitrary_types_allowed=True)
image: torch.Tensor
normalized_landmarks: List[Dict[str, List[float]]] # List of dictionaries for each person
world_landmarks: List[Dict[str, List[float]]] # List of dictionaries for each person
detection_result: Any
individual_index: int
normalized_landmarks: Dict[str, List[float]]
world_landmarks: Dict[str, List[float]]

@field_validator("normalized_landmarks", "world_landmarks", mode="before")
def validate_landmarks(cls, v: List[Dict[str, List[float]]]) -> List[Dict[str, List[float]]]:
"""Validate that landmarks contain at least 3 coordinates."""
for person_landmarks in v:
for coords in person_landmarks.values():
if len(coords) < 3:
raise ValueError("Each landmark must have at least 3 coordinates (x, y, z).")
def validate_landmarks(cls, v: Dict[str, List[float]]) -> Dict[str, List[float]]:
    """Check that every landmark entry holds exactly four values.

    Args:
        v: Mapping from landmark name to its [x, y, z, visibility] values.

    Returns:
        The same mapping, unchanged, when every entry has four values.

    Raises:
        ValueError: If any entry does not contain exactly four values.
    """
    has_malformed_entry = any(len(entry) != 4 for entry in v.values())
    if has_malformed_entry:
        raise ValueError("Each landmark must have exactly 4 coordinates (x, y, z, visibility)")
    return v

def to_numpy(self) -> np.ndarray:
    """Return the stored image as a NumPy array.

    Returns:
        The result of moving ``self.image`` to the CPU and calling ``.numpy()`` on it.
    """
    cpu_image = self.image.cpu()
    return cpu_image.numpy()

def get_landmark_coordinates(self, landmark: str, person_index: int = 0, world: bool = False) -> List[float]:
"""Returns the coordinates of a specified landmark for a given individual in the image.
def get_landmark_coordinates(self, landmark: str, world: bool = False) -> List[float]:
"""Returns coordinates for specified landmark.
Args:
person_index (int): Index of the individual in the detection results. Defaults to 0.
landmark (str): Name of the landmark (e.g., "landmark_0", "landmark_1").
world (bool): If True, retrieves world coordinates. Otherwise, retrieves normalized coordinates.
landmark: Name of the landmark (e.g., "landmark_0")
world: If True, returns world coordinates instead of normalized
Returns:
List[float]: Coordinates of the landmark in the form [x, y, z, visibility].
Raises:
ValueError: If the landmark does not exist or the person index is out of bounds.
[x, y, z, visibility] coordinates
"""
landmarks = self.world_landmarks if world else self.normalized_landmarks
if person_index >= len(landmarks):
raise ValueError(
f"Person index {person_index} is invalid. Image contains {len(landmarks)} people. "
f"Valid indices are {f'0 to {len(landmarks)-1}' if len(landmarks) > 0 else 'none'}"
)

if landmark not in landmarks[person_index]:
raise ValueError(
f"Landmark '{landmark}' not found. Available landmarks: {sorted(landmarks[person_index].keys())}"
)

return landmarks[person_index][landmark]
if landmark not in landmarks:
raise ValueError(f"Landmark '{landmark}' not found. Available landmarks: {sorted(landmarks.keys())}")
return landmarks[landmark]

def visualize_pose(self) -> None:
    """Draw the detected pose landmarks on the image and write the result to disk.

    The annotated image is written to "pose_estimation_output.png" (relative to the
    current working directory) and a confirmation message is printed.
    """
    rgb_frame = self.to_numpy()
    overlay_rgb = draw_landmarks_on_image(rgb_frame, self.detection_result)

    # The overlay is RGB; convert to BGR before handing it to cv2.imwrite.
    overlay_bgr = cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR)
    cv2.imwrite("pose_estimation_output.png", overlay_bgr)
    print("Image saved as pose_estimation_output.png")
class ImagePose(BaseModel):
"""Data structure for estimated poses of multiple individuals in an image.
Attributes:
image: numpy array representing the original image
individuals: list of IndividualPose objects for each individual with an estimated pose.
"""

def draw_landmarks_on_image(rgb_image: np.ndarray, detection_result: PoseLandmarkerResult) -> np.ndarray:
"""Draws pose landmarks on the input RGB image.
model_config = ConfigDict(arbitrary_types_allowed=True)

Args:
rgb_image: The input image in RGB format
detection_result: The detection result containing pose landmarks
image: np.ndarray
individuals: List[IndividualPose]

Returns:
Annotated image with pose landmarks drawn
"""
annotated_image = np.copy(rgb_image)
for person_landmarks in detection_result.pose_landmarks:
pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
pose_landmarks_proto.landmark.extend(
[landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in person_landmarks]
)
solutions.drawing_utils.draw_landmarks(
annotated_image,
pose_landmarks_proto,
solutions.pose.POSE_CONNECTIONS,
solutions.drawing_styles.get_default_pose_landmarks_style(),
)
return annotated_image


def estimate_pose_with_mediapipe(
    image_path: str,
    num_of_individuals: int = 1,
    model_path: str = "src/senselab/video/tasks/pose_estimation/models/pose_landmarker.task",
) -> PoseSkeleton:
    """Estimates pose landmarks for individuals in the provided image using MediaPipe.

    Args:
        image_path: Path to the input image file.
        num_of_individuals: Maximum number of individuals to detect. Defaults to 1.
        model_path: Path to the MediaPipe pose landmarker model file. Defaults to the
            model location used by this package.

    Returns:
        PoseSkeleton holding the image, the normalized and world landmarks of every
        detected individual, and the raw MediaPipe detection result.

    Raises:
        FileNotFoundError: If the image file or the model file does not exist.
    """
    import os  # local import: this module does not import os at file level

    # Fail early with the documented exception instead of relying on MediaPipe's own error.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at: {image_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at: {model_path}")

    # MediaPipe Pose Estimation config
    base_options = python.BaseOptions(model_asset_path=model_path)
    options = vision.PoseLandmarkerOptions(
        base_options=base_options, output_segmentation_masks=True, num_poses=num_of_individuals
    )
    detector = vision.PoseLandmarker.create_from_options(options)

    # Load the input image and run detection
    image = mp.Image.create_from_file(image_path)
    detection_result = detector.detect(image)

    # One dictionary per detected person, keyed "landmark_<index>" -> [x, y, z, visibility].
    normalized_landmarks_list = [
        {f"landmark_{idx}": [lm.x, lm.y, lm.z, lm.visibility] for idx, lm in enumerate(person_landmarks)}
        for person_landmarks in detection_result.pose_landmarks
    ]
    world_landmarks_list = [
        {f"landmark_{idx}": [lm.x, lm.y, lm.z, lm.visibility] for idx, lm in enumerate(person_landmarks)}
        for person_landmarks in detection_result.pose_world_landmarks
    ]

    # Copy the pixel data so the tensor does not alias MediaPipe's image buffer.
    image_tensor = torch.from_numpy(image.numpy_view().copy())

    # Return PoseSkeleton with all detected individuals' landmarks
    return PoseSkeleton(
        image=image_tensor,
        normalized_landmarks=normalized_landmarks_list,
        world_landmarks=world_landmarks_list,
        detection_result=detection_result,
    )
def get_individual(self, individual_index: int) -> IndividualPose:
    """Look up the pose of one detected individual by index.

    Args:
        individual_index: Zero-based position of the individual in ``self.individuals``.

    Returns:
        The IndividualPose stored at ``individual_index``.

    Raises:
        ValueError: If ``individual_index`` is negative or not smaller than the number
            of estimated poses.
    """
    pose_count = len(self.individuals)
    out_of_range = individual_index >= pose_count or individual_index < 0
    if not out_of_range:
        return self.individuals[individual_index]

    valid_indices = f"0 to {pose_count-1}" if pose_count > 0 else "none"
    raise ValueError(
        f"Individual index {individual_index} is invalid. {pose_count} poses were estimated. "
        f"Valid indices are {valid_indices}"
    )
122 changes: 122 additions & 0 deletions src/senselab/video/tasks/pose_estimation/pose_estimation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""This module implements the Pose Estimation task and supporting utilities."""

import os

import cv2
import mediapipe as mp
import numpy as np
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

from senselab.video.data_structures.pose import ImagePose, IndividualPose


def visualize_pose(image: ImagePose, output_path: str = "pose_estimation_output.png") -> np.ndarray:
    """Draws the pose landmarks of every detected individual on a copy of the image.

    Args:
        image: ImagePose object containing the image and the detected poses.
        output_path: File the annotated image is written to. Defaults to
            "pose_estimation_output.png" in the current working directory.

    Returns:
        Annotated image with pose landmarks drawn. The array held by ``image`` is
        not modified.

    Raises:
        OSError: If the annotated image could not be written to ``output_path``.
    """
    # Draw on a copy so the caller's image stays untouched.
    annotated_image = np.copy(image.image)
    if annotated_image.ndim == 2:  # Grayscale: expand to three channels before drawing
        annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_GRAY2RGB)

    for individual in image.individuals:
        # Landmarks are appended in dictionary order and POSE_CONNECTIONS is passed
        # alongside them, so this presumably relies on the dict keeping the detector's
        # landmark order (TODO confirm).
        pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        landmarks = [
            landmark_pb2.NormalizedLandmark(x=coords[0], y=coords[1], z=coords[2])
            for coords in individual.normalized_landmarks.values()
        ]
        pose_landmarks_proto.landmark.extend(landmarks)

        solutions.drawing_utils.draw_landmarks(
            annotated_image,
            pose_landmarks_proto,
            solutions.pose.POSE_CONNECTIONS,
            solutions.drawing_styles.get_default_pose_landmarks_style(),
        )

    # Convert RGB to BGR for saving; only report success if the write actually succeeded.
    annotated_image_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(output_path, annotated_image_bgr):
        raise OSError(f"Could not write annotated image to: {output_path}")
    print(f"Image saved as {output_path}")

    return annotated_image


def estimate_pose_with_mediapipe(
    image_path: str,
    num_of_individuals: int = 1,
    model_path: str = "src/senselab/video/tasks/pose_estimation/models/pose_landmarker.task",
) -> ImagePose:
    """Estimates pose landmarks for individuals in the provided image using MediaPipe.

    Args:
        image_path: Path to the input image file.
        num_of_individuals: Maximum number of individuals to detect. Defaults to 1.
        model_path: Path to the MediaPipe pose landmarker model file.

    Returns:
        ImagePose object containing the image and one IndividualPose per detected person.

    Raises:
        FileNotFoundError: If image_path or model_path doesn't exist.
        RuntimeError: If the detector returns no result.
    """
    # Validate file paths before touching MediaPipe so callers get a clear error.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at: {image_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at: {model_path}")

    # Segmentation masks are never read below and are not part of ImagePose,
    # so they are not requested from the detector.
    base_options = python.BaseOptions(model_asset_path=model_path)
    options = vision.PoseLandmarkerOptions(
        base_options=base_options, output_segmentation_masks=False, num_poses=num_of_individuals
    )

    # Use the landmarker as a context manager so it is closed when detection is done.
    with vision.PoseLandmarker.create_from_options(options) as detector:
        image = mp.Image.create_from_file(image_path)
        detection_result = detector.detect(image)

    if detection_result is None:
        raise RuntimeError("Pose detection failed")

    # Build one IndividualPose per detected person; normalized and world landmarks
    # are paired by position in the detection result.
    individuals = []
    for idx, (norm_landmarks, world_landmarks) in enumerate(
        zip(detection_result.pose_landmarks, detection_result.pose_world_landmarks)
    ):
        norm_dict = {f"landmark_{i}": [lm.x, lm.y, lm.z, lm.visibility] for i, lm in enumerate(norm_landmarks)}
        world_dict = {f"landmark_{i}": [lm.x, lm.y, lm.z, lm.visibility] for i, lm in enumerate(world_landmarks)}
        individuals.append(
            IndividualPose(individual_index=idx, normalized_landmarks=norm_dict, world_landmarks=world_dict)
        )

    # Copy the pixel data so the returned array does not alias MediaPipe's image buffer.
    image_array = image.numpy_view().copy()
    # ImagePose declares only `image` and `individuals`, so the raw detection result is not passed.
    return ImagePose(image=image_array, individuals=individuals)


if __name__ == "__main__":
    # Example usage: detect up to two people in a test image, save the annotated
    # picture, and print the first detected pose.
    image_path = "src/tests/data_for_testing/pose_data/three_people.jpg"
    try:
        pose_result = estimate_pose_with_mediapipe(image_path, num_of_individuals=2)
        visualize_pose(pose_result)
        print(f"Detected {len(pose_result.individuals)} individuals")
        # get_individual(0) raises when nothing was detected, so only print an example if one exists.
        if pose_result.individuals:
            print("Example individual pose data:", pose_result.get_individual(0))
    except Exception as e:  # demo entry point: report any failure as a message rather than a traceback
        print(f"Error processing image: {e}")
Loading

0 comments on commit 269f0a6

Please sign in to comment.