diff --git a/environment.yml b/environment.yml index bc78942a2..2603b0c82 100644 --- a/environment.yml +++ b/environment.yml @@ -15,6 +15,7 @@ dependencies: - pptree - qpax - rod >= 0.3.3 + - trimesh - typing_extensions # python<3.12 # ==================================== # Optional dependencies from setup.cfg diff --git a/pyproject.toml b/pyproject.toml index f7ca141ec..1209baabe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "qpax", "rod >= 0.3.3", "typing_extensions ; python_version < '3.12'", + "trimesh", ] [project.optional-dependencies] @@ -67,7 +68,7 @@ testing = [ "idyntree >= 12.2.1", "pytest >=6.0", "pytest-icdiff", - "robot-descriptions", + "robot-descriptions" ] viz = [ "lxml", diff --git a/src/jaxsim/parsers/descriptions/__init__.py b/src/jaxsim/parsers/descriptions/__init__.py index 9e180c155..ff3bf631d 100644 --- a/src/jaxsim/parsers/descriptions/__init__.py +++ b/src/jaxsim/parsers/descriptions/__init__.py @@ -1,4 +1,10 @@ -from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision +from .collision import ( + BoxCollision, + CollidablePoint, + CollisionShape, + MeshCollision, + SphereCollision, +) from .joint import JointDescription, JointGenericAxis, JointType from .link import LinkDescription from .model import ModelDescription diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index 31ae17b97..7815af488 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -154,3 +154,22 @@ def __eq__(self, other: BoxCollision) -> bool: return False return hash(self) == hash(other) + + +@dataclasses.dataclass +class MeshCollision(CollisionShape): + center: jtp.VectorLike + + def __hash__(self) -> int: + return hash( + ( + hash(tuple(self.center.tolist())), + hash(self.collidable_points), + ) + ) + + def __eq__(self, other: MeshCollision) -> bool: + if not isinstance(other, MeshCollision): + return False + + return hash(self) == hash(other) diff --git a/src/jaxsim/parsers/rod/meshes.py b/src/jaxsim/parsers/rod/meshes.py new file mode 100644 index 000000000..9d1ada7b0 --- /dev/null +++ b/src/jaxsim/parsers/rod/meshes.py @@ -0,0 +1,104 @@ +import numpy as np +import trimesh + +VALID_AXIS = {"x": 0, "y": 1, "z": 2} + + +def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray: + """ + Extracts the vertices of a mesh as points. + """ + return mesh.vertices + + +def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray: + """ + Extracts N random points from the surface of a mesh. + + Args: + mesh: The mesh from which to extract points. + n: The number of points to extract. + + Returns: + The extracted points (N x 3 array). + """ + + return mesh.sample(n) + + +def extract_points_uniform_surface_sampling( + mesh: trimesh.Trimesh, n: int +) -> np.ndarray: + """ + Extracts N uniformly sampled points from the surface of a mesh. + + Args: + mesh: The mesh from which to extract points. + n: The number of points to extract. + + Returns: + The extracted points (N x 3 array). + """ + + return trimesh.sample.sample_surface_even(mesh=mesh, count=n)[0] + + +def extract_points_select_points_over_axis( + mesh: trimesh.Trimesh, axis: str, direction: str, n: int +) -> np.ndarray: + """ + Extracts N points from a mesh along a specified axis. The points are selected based on their position along the axis. + + Args: + mesh: The mesh from which to extract points. + axis: The axis along which to extract points. + direction: The direction along the axis from which to extract points. Valid values are "higher" and "lower". + n: The number of points to extract. + + Returns: + The extracted points (N x 3 array). + """ + + dirs = {"higher": np.s_[-n:], "lower": np.s_[:n]} + arr = mesh.vertices + + # Sort rows lexicographically first, then columnar. + arr.sort(axis=0) + sorted_arr = arr[dirs[direction]] + return sorted_arr + + +def extract_points_aap( + mesh: trimesh.Trimesh, + axis: str, + upper: float | None = None, + lower: float | None = None, +) -> np.ndarray: + """ + Extracts points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis. + + Args: + mesh: The mesh from which to extract points. + axis: The axis along which to extract points. + upper: The upper bound of the range. + lower: The lower bound of the range. + + Returns: + The extracted points (N x 3 array). + + Raises: + AssertionError: If the lower bound is greater than the upper bound. + """ + + # Check bounds. + upper = upper if upper is not None else np.inf + lower = lower if lower is not None else -np.inf + assert lower < upper, "Invalid bounds for axis-aligned plane" + + # Logic. + points = mesh.vertices[ + (mesh.vertices[:, VALID_AXIS[axis]] >= lower) + & (mesh.vertices[:, VALID_AXIS[axis]] <= upper) + ] + + return points diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 8d359c6df..fc23420ae 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -334,6 +334,18 @@ def extract_model_data( collisions.append(sphere_collision) + if collision.geometry.mesh is not None and int( + os.environ.get("JAXSIM_COLLISION_MESH_ENABLED", "0") + ): + logging.warning("Mesh collision support is still experimental.") + mesh_collision = utils.create_mesh_collision( + collision=collision, + link_description=links_dict[link.name], + method=utils.meshes.extract_points_vertices, + ) + + collisions.append(mesh_collision) + return SDFData( model_name=sdf_model.name, link_descriptions=links, diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index dd83b1dde..5950a0fa8 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -1,12 +1,21 @@ import os +import pathlib +from collections.abc import Callable +from typing import TypeVar import numpy as np import numpy.typing as npt import rod +import trimesh +from rod.utils.resolve_uris import resolve_local_uri import jaxsim.typing as jtp +from jaxsim import logging from jaxsim.math import Adjoint, Inertia from jaxsim.parsers import descriptions +from jaxsim.parsers.rod import meshes + +MeshMappingMethod = TypeVar("MeshMappingMethod", bound=Callable[..., npt.NDArray]) def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: @@ -202,3 +211,47 @@ def fibonacci_sphere(samples: int) -> npt.NDArray: return descriptions.SphereCollision( collidable_points=collidable_points, center=center_wrt_link ) + + +def create_mesh_collision( + collision: rod.Collision, + link_description: descriptions.LinkDescription, + method: MeshMappingMethod = None, +) -> descriptions.MeshCollision: + + file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri)) + _file_type = file.suffix.replace(".", "") + mesh = trimesh.load_mesh(file, file_type=_file_type) + + if mesh.is_empty: + raise RuntimeError(f"Failed to process '{file}' with trimesh") + + mesh.apply_scale(collision.geometry.mesh.scale) + logging.info( + msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{_file_type}'" + ) + + if method is None: + method = meshes.VertexExtraction() + logging.debug("Using default Vertex Extraction method for mesh wrapping") + else: + logging.debug(f"Using method {method} for mesh wrapping") + + points = method(mesh=mesh) + logging.debug(f"Extracted {len(points)} points from mesh") + + W_H_L = collision.pose.transform() if collision.pose is not None else np.eye(4) + + # Extract translation from transformation matrix + W_p_L = W_H_L[:3, 3] + mesh_points_wrt_link = points @ W_H_L[:3, :3].T + W_p_L + collidable_points = [ + descriptions.CollidablePoint( + parent_link=link_description, + position=point, + enabled=True, + ) + for point in mesh_points_wrt_link + ] + + return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L) diff --git a/tests/test_meshes.py b/tests/test_meshes.py new file mode 100644 index 000000000..58fcb9827 --- /dev/null +++ b/tests/test_meshes.py @@ -0,0 +1,100 @@ +import trimesh + +from jaxsim.parsers.rod import meshes + + +def test_mesh_wrapping_vertex_extraction(): + """ + Test the vertex extraction method on different meshes. + 1. A simple box + 2. A sphere + """ + + # Test 1: A simple box. + # First, create a box with origin at (0,0,0) and extents (3,3,3), + # i.e. points span from -1.5 to 1.5 on the axis. + mesh = trimesh.creation.box( + extents=[3.0, 3.0, 3.0], + ) + points = meshes.extract_points_vertices(mesh=mesh) + assert len(points) == len(mesh.vertices) + + # Test 2: A sphere. + # The sphere is centered at the origin and has a radius of 1.0. + mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) + points = meshes.extract_points_vertices(mesh=mesh) + assert len(points) == len(mesh.vertices) + + +def test_mesh_wrapping_aap(): + """ + Test the AAP wrapping method on different meshes. + 1. A simple box + 1.1: Remove all points above x=0.0 + 1.2: Remove all points below y=0.0 + 2. A sphere + """ + + # Test 1.1: Remove all points above x=0.0. + # The expected result is that the number of points is halved. + # First, create a box with origin at (0,0,0) and extents (3,3,3), + # i.e. points span from -1.5 to 1.5 on the axis. + mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0]) + points = meshes.extract_points_aap(mesh=mesh, axis="x", lower=0.0) + assert len(points) == len(mesh.vertices) // 2 + assert all(points[:, 0] > 0.0) + + # Test 1.2: Remove all points below y=0.0. + # The expected result is that the number of points is halved. + points = meshes.extract_points_aap(mesh=mesh, axis="y", upper=0.0) + assert len(points) == len(mesh.vertices) // 2 + assert all(points[:, 1] < 0.0) + + # Test 2: A sphere. + # The sphere is centered at the origin and has a radius of 1.0. + # Points are expected to be halved. + mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) + + # Remove all points above y=0.0. + points = meshes.extract_points_aap(mesh=mesh, axis="y", lower=0.0) + assert all(points[:, 1] >= 0.0) + assert len(points) < len(mesh.vertices) + + +def test_mesh_wrapping_points_over_axis(): + """ + Test the points over axis method on different meshes. + 1. A simple box + 1.1: Select 10 points from the lower end of the x-axis + 1.2: Select 10 points from the higher end of the y-axis + 2. A sphere + """ + + # Test 1.1: Remove 10 points from the lower end of the x-axis. + # First, create a box with origin at (0,0,0) and extents (3,3,3), + # i.e. points span from -1.5 to 1.5 on the axis. + mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0]) + points = meshes.extract_points_select_points_over_axis( + mesh=mesh, axis="x", direction="lower", n=4 + ) + assert len(points) == 4 + assert all(points[:, 0] < 0.0) + + # Test 1.2: Select 10 points from the higher end of the y-axis. + points = meshes.extract_points_select_points_over_axis( + mesh=mesh, axis="y", direction="higher", n=4 + ) + assert len(points) == 4 + assert all(points[:, 1] > 0.0) + + # Test 2: A sphere. + # The sphere is centered at the origin and has a radius of 1.0. + mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) + sphere_n_vertices = len(mesh.vertices) + + # Select 10 points from the higher end of the z-axis. + points = meshes.extract_points_select_points_over_axis( + mesh=mesh, axis="z", direction="higher", n=sphere_n_vertices // 2 + ) + assert len(points) == sphere_n_vertices // 2 + assert all(points[:, 2] >= 0.0)