From 269f0a620a67ef38d3013a7896ce37a9fb4f7149 Mon Sep 17 00:00:00 2001 From: brukew Date: Sun, 24 Nov 2024 16:10:48 -0500 Subject: [PATCH] Refactored into data structures and tasks + generalized visualization --- src/senselab/video/data_structures/pose.py | 183 ++++------------ .../tasks/pose_estimation/pose_estimation.py | 122 +++++++++++ src/tests/video/data_structures/pose_test.py | 207 ++++++++++-------- 3 files changed, 281 insertions(+), 231 deletions(-) create mode 100644 src/senselab/video/tasks/pose_estimation/pose_estimation.py diff --git a/src/senselab/video/data_structures/pose.py b/src/senselab/video/data_structures/pose.py index fa7fafa0..4c391e5a 100644 --- a/src/senselab/video/data_structures/pose.py +++ b/src/senselab/video/data_structures/pose.py @@ -1,168 +1,67 @@ """Data structures relevant for pose estimation.""" -from typing import Any, Dict, List +from typing import Dict, List -import cv2 -import mediapipe as mp import numpy as np -import torch -from mediapipe import solutions -from mediapipe.framework.formats import landmark_pb2 -from mediapipe.tasks import python -from mediapipe.tasks.python import vision -from mediapipe.tasks.python.vision import PoseLandmarkerResult from pydantic import BaseModel, ConfigDict, field_validator +# class Landmark -class PoseSkeleton(BaseModel): - """Data structure for estimated poses of multiple individuals in an image. + +class IndividualPose(BaseModel): + """Data structure for estimated pose of single individual in an image. Attributes: - image: object representing the original image (torch.Tensor) - normalized_landmarks: list of dictionaries for each person's body landmarks with normalized - image coordinates (x, y, z). - world_landmarks: list of dictionaries for each person's body landmarks with real-world - coordinates (x, y, z). - detection_result: output of MediaPipe pose estimation + individual_index: index of individual detected. 
+ normalized_landmarks: Dictionary of body landmarks with normalized image coordinates and visibility (x, y, z, c). + world_landmarks: Dictionary of body landmarks with real-world coordinates and visibility (x, y, z, c). """ - model_config = ConfigDict(arbitrary_types_allowed=True) - image: torch.Tensor - normalized_landmarks: List[Dict[str, List[float]]] # List of dictionaries for each person - world_landmarks: List[Dict[str, List[float]]] # List of dictionaries for each person - detection_result: Any + individual_index: int + normalized_landmarks: Dict[str, List[float]] + world_landmarks: Dict[str, List[float]] @field_validator("normalized_landmarks", "world_landmarks", mode="before") - def validate_landmarks(cls, v: List[Dict[str, List[float]]]) -> List[Dict[str, List[float]]]: - """Validate that landmarks contain at least 3 coordinates.""" - for person_landmarks in v: - for coords in person_landmarks.values(): - if len(coords) < 3: - raise ValueError("Each landmark must have at least 3 coordinates (x, y, z).") + def validate_landmarks(cls, v: Dict[str, List[float]]) -> Dict[str, List[float]]: + """Validate that landmarks contain exactly 4 coordinates (x,y,z,visibility).""" + for coords in v.values(): + if len(coords) != 4: + raise ValueError("Each landmark must have exactly 4 coordinates (x, y, z, visibility)") return v - def to_numpy(self) -> np.ndarray: - """Converts image to numpy array. - - Returns: - numpy array of image that was initialized in class - """ - return self.image.cpu().numpy() - - def get_landmark_coordinates(self, landmark: str, person_index: int = 0, world: bool = False) -> List[float]: - """Returns the coordinates of a specified landmark for a given individual in the image. + def get_landmark_coordinates(self, landmark: str, world: bool = False) -> List[float]: + """Returns coordinates for specified landmark. Args: - person_index (int): Index of the individual in the detection results. Defaults to 0. 
- landmark (str): Name of the landmark (e.g., "landmark_0", "landmark_1"). - world (bool): If True, retrieves world coordinates. Otherwise, retrieves normalized coordinates. - + landmark: Name of the landmark (e.g., "landmark_0") + world: If True, returns world coordinates instead of normalized Returns: - List[float]: Coordinates of the landmark in the form [x, y, z, visibility]. - - Raises: - ValueError: If the landmark does not exist or the person index is out of bounds. + [x, y, z, visibility] coordinates """ landmarks = self.world_landmarks if world else self.normalized_landmarks - if person_index >= len(landmarks): - raise ValueError( - f"Person index {person_index} is invalid. Image contains {len(landmarks)} people. " - f"Valid indices are {f'0 to {len(landmarks)-1}' if len(landmarks) > 0 else 'none'}" - ) - - if landmark not in landmarks[person_index]: - raise ValueError( - f"Landmark '{landmark}' not found. Available landmarks: {sorted(landmarks[person_index].keys())}" - ) - - return landmarks[person_index][landmark] + if landmark not in landmarks: + raise ValueError(f"Landmark '{landmark}' not found. Available landmarks: {sorted(landmarks.keys())}") + return landmarks[landmark] - def visualize_pose(self) -> None: - """Visualizes pose landmarks on the image and saves the annotated image. - Saves the annotated image as "pose_estimation_output.png" in the current directory. - """ - annotated_image = draw_landmarks_on_image(self.to_numpy(), self.detection_result) - # Save the annotated image - annotated_image_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR) - cv2.imwrite("pose_estimation_output.png", annotated_image_bgr) - print("Image saved as pose_estimation_output.png") +class ImagePose(BaseModel): + """Data structure for estimated poses of multiple individuals in an image. + Attributes: + image: numpy array representing the original image + individuals: list of IndividualPose objects for each individual with an estimated pose. 
+ """ -def draw_landmarks_on_image(rgb_image: np.ndarray, detection_result: PoseLandmarkerResult) -> np.ndarray: - """Draws pose landmarks on the input RGB image. + model_config = ConfigDict(arbitrary_types_allowed=True) - Args: - rgb_image: The input image in RGB format - detection_result: The detection result containing pose landmarks + image: np.ndarray + individuals: List[IndividualPose] - Returns: - Annotated image with pose landmarks drawn - """ - annotated_image = np.copy(rgb_image) - for person_landmarks in detection_result.pose_landmarks: - pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList() - pose_landmarks_proto.landmark.extend( - [landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in person_landmarks] - ) - solutions.drawing_utils.draw_landmarks( - annotated_image, - pose_landmarks_proto, - solutions.pose.POSE_CONNECTIONS, - solutions.drawing_styles.get_default_pose_landmarks_style(), - ) - return annotated_image - - -def estimate_pose_with_mediapipe(image_path: str, num_of_individuals: int = 1) -> PoseSkeleton: - """Estimates pose landmarks for individuals in the provided image using MediaPipe. - - Args: - image_path (str): Path to the input image file. - num_of_individuals (int): Maximum number of individuals to detect. Defaults to 1. - - Returns: - PoseSkeleton object - - Raises: - FileNotFoundError: If the image file does not exist. 
- """ - # MediaPipe Pose Estimation config - base_options = python.BaseOptions( - model_asset_path="src/senselab/video/tasks/pose_estimation/models/pose_landmarker.task" - ) - options = vision.PoseLandmarkerOptions( - base_options=base_options, output_segmentation_masks=True, num_poses=num_of_individuals - ) - detector = vision.PoseLandmarker.create_from_options(options) - - # Load the input image - image = mp.Image.create_from_file(image_path) - detection_result = detector.detect(image) - - normalized_landmarks_list = [] - world_landmarks_list = [] - - for person_landmarks in detection_result.pose_landmarks: - # Store normalized landmarks (3D) for each person - person_normalized_landmarks = {} - for idx, landmark in enumerate(person_landmarks): - person_normalized_landmarks[f"landmark_{idx}"] = [landmark.x, landmark.y, landmark.z, landmark.visibility] - normalized_landmarks_list.append(person_normalized_landmarks) - - for person_landmarks in detection_result.pose_world_landmarks: - # Store world landmarks (3D) for each person - person_world_landmarks = {} - for idx, landmark in enumerate(person_landmarks): - person_world_landmarks[f"landmark_{idx}"] = [landmark.x, landmark.y, landmark.z, landmark.visibility] - world_landmarks_list.append(person_world_landmarks) - - image_tensor = torch.from_numpy(image.numpy_view().copy()) - - # Return PoseSkeleton with all detected individuals' landmarks - return PoseSkeleton( - image=image_tensor, - normalized_landmarks=normalized_landmarks_list, - world_landmarks=world_landmarks_list, - detection_result=detection_result, - ) + def get_individual(self, individual_index: int) -> IndividualPose: + """Returns IndividualPose object for specified individual.""" + if individual_index >= len(self.individuals) or individual_index < 0: + raise ValueError( + f"Individual index {individual_index} is invalid. {len(self.individuals)} poses were estimated. 
" + f"Valid indices are {f'0 to {len(self.individuals)-1}' if len(self.individuals) > 0 else 'none'}" + ) + return self.individuals[individual_index] diff --git a/src/senselab/video/tasks/pose_estimation/pose_estimation.py b/src/senselab/video/tasks/pose_estimation/pose_estimation.py new file mode 100644 index 00000000..dab000e4 --- /dev/null +++ b/src/senselab/video/tasks/pose_estimation/pose_estimation.py @@ -0,0 +1,122 @@ +"""This module implements the Pose Estimation task and supporting utilities.""" + +import os + +import cv2 +import mediapipe as mp +import numpy as np +from mediapipe import solutions +from mediapipe.framework.formats import landmark_pb2 +from mediapipe.tasks import python +from mediapipe.tasks.python import vision + +from senselab.video.data_structures.pose import ImagePose, IndividualPose + + +def visualize_pose(image: ImagePose) -> np.ndarray: + """Visualizes pose landmarks on the input image. + + Args: + image: ImagePose object containing image and detected poses + + Returns: + Annotated image with pose landmarks drawn + + Note: + Saves the annotated image as 'pose_estimation_output.png' + """ + # Convert to RGB if needed and create copy + annotated_image = np.copy(image.image) + if len(annotated_image.shape) == 2: # Grayscale + annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_GRAY2RGB) + + for individual in image.individuals: + pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList() + landmarks = [ + landmark_pb2.NormalizedLandmark(x=coords[0], y=coords[1], z=coords[2]) + for coords in individual.normalized_landmarks.values() + ] + pose_landmarks_proto.landmark.extend(landmarks) + + solutions.drawing_utils.draw_landmarks( + annotated_image, + pose_landmarks_proto, + solutions.pose.POSE_CONNECTIONS, + solutions.drawing_styles.get_default_pose_landmarks_style(), + ) + + # Save the annotated image + annotated_image_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR) + cv2.imwrite("pose_estimation_output.png", 
annotated_image_bgr) + print("Image saved as pose_estimation_output.png") + + return annotated_image + + +def estimate_pose_with_mediapipe( + image_path: str, + num_of_individuals: int = 1, + model_path: str = "src/senselab/video/tasks/pose_estimation/models/pose_landmarker.task", +) -> ImagePose: + """Estimates pose landmarks for individuals in the provided image using MediaPipe. + + Args: + image_path: Path to the input image file + num_of_individuals: Maximum number of individuals to detect. Defaults to 1 + model_path: Path to the MediaPipe pose landmarker model file + + Returns: + ImagePose object containing detected poses + + Raises: + FileNotFoundError: If image_path or model_path doesn't exist + RuntimeError: If pose detection fails + """ + # Validate file paths + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image not found at: {image_path}") + if not os.path.exists(model_path): + raise FileNotFoundError(f"Model not found at: {model_path}") + + # Initialize detector + base_options = python.BaseOptions(model_asset_path=model_path) + options = vision.PoseLandmarkerOptions( + base_options=base_options, output_segmentation_masks=True, num_poses=num_of_individuals + ) + detector = vision.PoseLandmarker.create_from_options(options) + + # Load and process image + image = mp.Image.create_from_file(image_path) + detection_result = detector.detect(image) + + if not detection_result: + raise RuntimeError("Pose detection failed") + + # Create IndividualPose objects + individuals = [] + for idx, (norm_landmarks, world_landmarks) in enumerate( + zip(detection_result.pose_landmarks, detection_result.pose_world_landmarks) + ): + norm_dict = {f"landmark_{i}": [lm.x, lm.y, lm.z, lm.visibility] for i, lm in enumerate(norm_landmarks)} + world_dict = {f"landmark_{i}": [lm.x, lm.y, lm.z, lm.visibility] for i, lm in enumerate(world_landmarks)} + + individual = IndividualPose(individual_index=idx, normalized_landmarks=norm_dict, world_landmarks=world_dict) + 
individuals.append(individual) + + # Create and return ImagePose + image_array = image.numpy_view().copy() + return ImagePose(image=image_array, individuals=individuals, detection_result=detection_result) + + +if __name__ == "__main__": + import os + + # Example usage + image_path = "src/tests/data_for_testing/pose_data/three_people.jpg" + try: + pose_result = estimate_pose_with_mediapipe(image_path, num_of_individuals=2) + annotated = visualize_pose(pose_result) + print(f"Detected {len(pose_result.individuals)} individuals") + print("Example individual pose data:", pose_result.get_individual(0)) + except Exception as e: + print(f"Error processing image: {e}") diff --git a/src/tests/video/data_structures/pose_test.py b/src/tests/video/data_structures/pose_test.py index c1f648cb..5dd19474 100644 --- a/src/tests/video/data_structures/pose_test.py +++ b/src/tests/video/data_structures/pose_test.py @@ -3,11 +3,11 @@ import os from typing import Generator -import cv2 +import numpy as np import pytest -import torch -from senselab.video.data_structures.pose import PoseSkeleton, estimate_pose_with_mediapipe +from senselab.video.data_structures.pose import ImagePose, IndividualPose +from senselab.video.tasks.pose_estimation.pose_estimation import estimate_pose_with_mediapipe TEST_IMAGES_DIR = "src/tests/data_for_testing/pose_data/" VALID_SINGLE_PERSON_IMAGE = os.path.join(TEST_IMAGES_DIR, "single_person.jpg") @@ -16,96 +16,125 @@ INVALID_IMAGE_PATH = "invalid/path/to/image.jpg" -def test_get_landmark_coordinates() -> None: - """Tests basic landmark retrieval for first person.""" - result = estimate_pose_with_mediapipe(VALID_SINGLE_PERSON_IMAGE) - - # Test valid landmark retrieval - coords = result.get_landmark_coordinates(landmark="landmark_0", person_index=0) - assert len(coords) == 4, "Should return [x, y, z, visibility]" - assert all(isinstance(x, float) for x in coords), "All coordinates should be floats" - - # Test normalized vs world coordinates - norm_coords = 
result.get_landmark_coordinates(landmark="landmark_0", person_index=0, world=False) - world_coords = result.get_landmark_coordinates(landmark="landmark_0", person_index=0, world=True) - assert norm_coords != world_coords, "World and normalized coordinates should differ" - - -def test_invalid_person_index() -> None: - """Tests error handling for invalid person indices.""" - result = estimate_pose_with_mediapipe(VALID_SINGLE_PERSON_IMAGE) - - with pytest.raises(ValueError) as exc_info: - result.get_landmark_coordinates(person_index=5, landmark="landmark_0") - assert "Person index 5 is invalid" in str(exc_info.value) - assert "Image contains 1" in str(exc_info.value) - - # Test with no people - result_empty = estimate_pose_with_mediapipe(NO_PEOPLE_IMAGE) - with pytest.raises(ValueError) as exc_info: - result_empty.get_landmark_coordinates(person_index=0, landmark="landmark_0") - assert "Image contains 0 people" in str(exc_info.value) - assert "Valid indices are none" in str(exc_info.value) - - -def test_invalid_landmark_name() -> None: - """Tests error handling for invalid landmark names.""" - result = estimate_pose_with_mediapipe(VALID_SINGLE_PERSON_IMAGE) - - with pytest.raises(ValueError) as exc_info: - result.get_landmark_coordinates(person_index=0, landmark="nonexistent_landmark") - error_msg = str(exc_info.value) - assert "Landmark 'nonexistent_landmark' not found" in error_msg - assert "Available landmarks:" in error_msg - - -def test_valid_image_single_person() -> None: - """Tests pose estimation on an image with a single person.""" - result = estimate_pose_with_mediapipe(VALID_SINGLE_PERSON_IMAGE) - assert isinstance(result, PoseSkeleton), "Result should be an instance of PoseSkeleton" - assert len(result.normalized_landmarks) == 1, "There should be one detected person" - assert len(result.world_landmarks) == 1, "There should be one detected person" - assert ( - result.image.shape == torch.from_numpy(cv2.imread(VALID_SINGLE_PERSON_IMAGE)).shape - ), "Input 
and output image shapes should match" - - -def test_valid_image_multiple_people() -> None: - """Tests pose estimation on an image with multiple people.""" - result = estimate_pose_with_mediapipe(VALID_MULTIPLE_PEOPLE_IMAGE, 3) - assert isinstance(result, PoseSkeleton), "Result should be an instance of PoseSkeleton" - assert len(result.normalized_landmarks) > 1, "There should be multiple detected people" - assert len(result.world_landmarks) > 1, "There should be multiple detected people" - - -def test_no_people_in_image() -> None: - """Tests pose estimation on an image with no people.""" - result = estimate_pose_with_mediapipe(NO_PEOPLE_IMAGE) - assert isinstance(result, PoseSkeleton), "Result should be an instance of PoseSkeleton" - assert len(result.normalized_landmarks) == 0, "No landmarks should be detected" - assert len(result.world_landmarks) == 0, "No landmarks should be detected" - - -def test_invalid_image_path() -> None: - """Tests pose estimation on an invalid image path.""" - with pytest.raises(Exception): - estimate_pose_with_mediapipe(INVALID_IMAGE_PATH) - - -def test_visualization_single_person() -> None: - """Tests visualization and saving of annotated images.""" - result = estimate_pose_with_mediapipe(VALID_SINGLE_PERSON_IMAGE) - result.visualize_pose() - assert os.path.exists("pose_estimation_output.png"), "Annotated image should be saved" +class TestIndividualPose: + """Test suite for IndividualPose class.""" + + @pytest.fixture + def sample_individual(self) -> IndividualPose: + """Create a sample IndividualPose for testing.""" + return IndividualPose( + individual_index=0, + normalized_landmarks={"landmark_0": [0.5, 0.5, 0.0, 1.0], "landmark_1": [0.6, 0.6, 0.1, 0.9]}, + world_landmarks={"landmark_0": [0.5, 1.5, 0.0, 1.0], "landmark_1": [0.6, 1.6, 0.1, 0.9]}, + ) + + def test_landmark_validation(self) -> None: + """Test landmark validation during initialization.""" + # Test valid initialization + valid_pose = IndividualPose( + individual_index=0, 
+ normalized_landmarks={"landmark_0": [0.1, 0.2, 0.3, 0.4]}, + world_landmarks={"landmark_0": [0.1, 0.2, 0.3, 0.4]}, + ) + assert valid_pose is not None + + # Test invalid number of coordinates + with pytest.raises(ValueError) as exc_info: + IndividualPose( + individual_index=0, + normalized_landmarks={"landmark_0": [0.1, 0.2, 0.3]}, + world_landmarks={"landmark_0": [0.1, 0.2, 0.3, 0.4]}, + ) + assert "Each landmark must have exactly 4 coordinates" in str(exc_info.value) + + def test_get_landmark_coordinates(self, sample_individual: IndividualPose) -> None: + """Test landmark coordinate retrieval.""" + # Test normalized coordinates + coords = sample_individual.get_landmark_coordinates("landmark_0", world=False) + assert len(coords) == 4 + assert coords == [0.5, 0.5, 0.0, 1.0] + + # Test world coordinates + world_coords = sample_individual.get_landmark_coordinates("landmark_0", world=True) + assert len(world_coords) == 4 + assert world_coords == [0.5, 1.5, 0.0, 1.0] + + # Test invalid landmark + with pytest.raises(ValueError) as exc_info: + sample_individual.get_landmark_coordinates("nonexistent_landmark") + assert "Landmark 'nonexistent_landmark' not found" in str(exc_info.value) + assert "Available landmarks:" in str(exc_info.value) + + +class TestImagePose: + """Test suite for ImagePose class.""" + + @pytest.fixture + def sample_image_pose(self) -> ImagePose: + """Create a sample ImagePose for testing.""" + individual = IndividualPose( + individual_index=0, + normalized_landmarks={"landmark_0": [0.5, 0.5, 0.0, 1.0]}, + world_landmarks={"landmark_0": [0.5, 1.5, 0.0, 1.0]}, + ) + return ImagePose(image=np.zeros((100, 100, 3)), individuals=[individual]) + + def test_get_individual(self, sample_image_pose: ImagePose) -> None: + """Test individual retrieval.""" + # Test valid index + individual = sample_image_pose.get_individual(0) + assert isinstance(individual, IndividualPose) + assert individual.individual_index == 0 + + # Test invalid index + with 
pytest.raises(ValueError) as exc_info: + sample_image_pose.get_individual(1) + assert "Individual index 1 is invalid" in str(exc_info.value) + + def test_empty_image_pose(self) -> None: + """Test ImagePose with no individuals.""" + empty_pose = ImagePose(image=np.zeros((100, 100, 3)), individuals=[]) + with pytest.raises(ValueError) as exc_info: + empty_pose.get_individual(0) + assert "Valid indices are none" in str(exc_info.value) + + +class TestIntegration: + """Integration tests with MediaPipe.""" + + def test_mediapipe_single_person(self) -> None: + """Test full pipeline with single person image.""" + result = estimate_pose_with_mediapipe(VALID_SINGLE_PERSON_IMAGE) + assert isinstance(result, ImagePose) + assert len(result.individuals) == 1 + + individual = result.get_individual(0) + assert isinstance(individual, IndividualPose) + assert individual.individual_index == 0 + assert len(individual.normalized_landmarks) > 0 + assert len(individual.world_landmarks) > 0 + + def test_mediapipe_multiple_people(self) -> None: + """Test full pipeline with multiple people image.""" + result = estimate_pose_with_mediapipe(VALID_MULTIPLE_PEOPLE_IMAGE, num_of_individuals=3) + assert isinstance(result, ImagePose) + assert len(result.individuals) == 3 + + def test_mediapipe_no_people(self) -> None: + """Test full pipeline with image containing no people.""" + result = estimate_pose_with_mediapipe(NO_PEOPLE_IMAGE) + assert isinstance(result, ImagePose) + assert len(result.individuals) == 0 + + def test_invalid_image_path(self) -> None: + """Test error handling for invalid image path.""" + with pytest.raises(FileNotFoundError): + estimate_pose_with_mediapipe(INVALID_IMAGE_PATH) @pytest.fixture(autouse=True) def cleanup() -> Generator[None, None, None]: - """Clean up any generated files after tests. - - Yields: - None - """ + """Clean up any generated files after tests.""" yield if os.path.exists("pose_estimation_output.png"): os.remove("pose_estimation_output.png")