Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mesh support #156

Merged
merged 44 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
3936638
Moved from jax.numpy to numpy in PlaneTerrain __eq__ magic to bypass …
lorycontixd Jul 5, 2024
d07e4d1
Set env var when parsing from `robot-descriptions`
flferretti May 13, 2024
5d16532
Initial version of mesh support
lorycontixd May 13, 2024
1663a1f
Format and lint
flferretti May 14, 2024
6e968fc
Use already existing env var to solve mesh URIs
flferretti May 17, 2024
2c5f78f
Skip loading empty meshes
flferretti May 20, 2024
990ea06
Added `networkx` as testing dependency
lorycontixd May 16, 2024
46f7164
Address to reviews:
lorycontixd May 20, 2024
2f25c0e
Added trimesh dependecy for conda-forge
lorycontixd May 21, 2024
b15ea55
Moved mesh parsing logic inside mesh collision function
lorycontixd May 21, 2024
644ce43
Implemented UniformSurfaceSampling for mesh point wrapping
lorycontixd May 21, 2024
b0ddec1
Address reviews
lorycontixd May 21, 2024
bcf5e48
Pre-commit
lorycontixd May 21, 2024
07b8402
Update `__eq__` magic and type hints
flferretti Jun 13, 2024
d7fd1b3
Removed unused lines in conftest
lorycontixd Jun 20, 2024
94c2f81
Fixed typo on logging message
lorycontixd Jun 20, 2024
cf00335
Removed unused import in conftest
lorycontixd Jun 20, 2024
ed51e85
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
5af51da
Implemented initial reviews from https://github.com/ami-iit/jaxsim/pu…
lorycontixd Jul 5, 2024
f9475b0
First draft of new mesh wrapping algorithms
lorycontixd Jul 17, 2024
d434e44
Implemented structure for new mesh wrapping algorithms
lorycontixd Jul 17, 2024
22da2cd
New mesh wrapping algorithms
lorycontixd Jul 17, 2024
8058b7c
New mesh wrapping algorithms with relative tests
lorycontixd Jul 18, 2024
2d07347
Renamed some parameters
lorycontixd Jul 18, 2024
72ce440
Restructured mesh mapping methods to follow inheritance
lorycontixd Jul 18, 2024
6acb19f
Run pre-commit
lorycontixd Jul 18, 2024
f8ecf24
Removed leftover parameters on create_mesh_collision
lorycontixd Jul 18, 2024
4f6cf0a
Removed wrong point selection & added logs
lorycontixd Jul 18, 2024
148fcfe
Added string magics on wrapping methods
lorycontixd Jul 18, 2024
98012e5
Added docstrings to mesh wrapping algorithms
lorycontixd Nov 14, 2024
d2a5e89
Added string magics on wrapping methods
lorycontixd Jul 18, 2024
fec016a
Implemented reviews
lorycontixd Nov 15, 2024
9238af5
Fixed error on array sorting and relative test
lorycontixd Nov 15, 2024
ad30c29
Addressed reviews
lorycontixd Nov 15, 2024
8d0380e
Added experimental feature warning for mesh parsing
lorycontixd Nov 15, 2024
8a29349
Updated variable names
lorycontixd Nov 15, 2024
e76b0b9
Added int casting on mesh_enabled flag
lorycontixd Nov 15, 2024
90451e8
Removed extra search paths in ergocub model building
lorycontixd Nov 15, 2024
495b799
Apply suggestions from code review
lorycontixd Nov 15, 2024
bb21a16
Precommit fix
lorycontixd Nov 15, 2024
1738a87
Removed unused dependency from pyprojecj
lorycontixd Nov 15, 2024
eff915d
Removed unused function
lorycontixd Nov 15, 2024
2d46a70
Fixed minor commenting format
lorycontixd Nov 15, 2024
fe2616c
Removed whitespaces
lorycontixd Nov 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies:
- pptree
- qpax
- rod >= 0.3.3
- trimesh
- typing_extensions # python<3.12
# ====================================
# Optional dependencies from setup.cfg
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ dependencies = [
"qpax",
"rod >= 0.3.3",
"typing_extensions ; python_version < '3.12'",
"trimesh",
]

[project.optional-dependencies]
Expand All @@ -67,7 +68,7 @@ testing = [
"idyntree >= 12.2.1",
"pytest >=6.0",
"pytest-icdiff",
"robot-descriptions",
"robot-descriptions"
]
viz = [
"lxml",
Expand Down
8 changes: 7 additions & 1 deletion src/jaxsim/parsers/descriptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision
from .collision import (
BoxCollision,
CollidablePoint,
CollisionShape,
MeshCollision,
SphereCollision,
)
from .joint import JointDescription, JointGenericAxis, JointType
from .link import LinkDescription
from .model import ModelDescription
19 changes: 19 additions & 0 deletions src/jaxsim/parsers/descriptions/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,22 @@ def __eq__(self, other: BoxCollision) -> bool:
return False

return hash(self) == hash(other)


@dataclasses.dataclass
class MeshCollision(CollisionShape):
center: jtp.VectorLike

def __hash__(self) -> int:
return hash(
(
hash(tuple(self.center.tolist())),
hash(self.collidable_points),
)
)

def __eq__(self, other: MeshCollision) -> bool:
if not isinstance(other, MeshCollision):
return False

return hash(self) == hash(other)
104 changes: 104 additions & 0 deletions src/jaxsim/parsers/rod/meshes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import numpy as np
import trimesh

VALID_AXIS = {"x": 0, "y": 1, "z": 2}


def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray:
"""
Extracts the vertices of a mesh as points.
"""
return mesh.vertices


def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray:
"""
Extracts N random points from the surface of a mesh.

Args:
mesh: The mesh from which to extract points.
n: The number of points to extract.

Returns:
The extracted points (N x 3 array).
"""

return mesh.sample(n)


def extract_points_uniform_surface_sampling(
mesh: trimesh.Trimesh, n: int
) -> np.ndarray:
"""
Extracts N uniformly sampled points from the surface of a mesh.

Args:
mesh: The mesh from which to extract points.
n: The number of points to extract.

Returns:
The extracted points (N x 3 array).
"""

return trimesh.sample.sample_surface_even(mesh=mesh, count=n)[0]


def extract_points_select_points_over_axis(
mesh: trimesh.Trimesh, axis: str, direction: str, n: int
) -> np.ndarray:
"""
Extracts N points from a mesh along a specified axis. The points are selected based on their position along the axis.

Args:
mesh: The mesh from which to extract points.
axis: The axis along which to extract points.
direction: The direction along the axis from which to extract points. Valid values are "higher" and "lower".
n: The number of points to extract.

Returns:
The extracted points (N x 3 array).
"""

dirs = {"higher": np.s_[-n:], "lower": np.s_[:n]}
arr = mesh.vertices

# Sort rows lexicographically first, then columnar.
arr.sort(axis=0)
sorted_arr = arr[dirs[direction]]
return sorted_arr


def extract_points_aap(
mesh: trimesh.Trimesh,
axis: str,
upper: float | None = None,
lower: float | None = None,
) -> np.ndarray:
"""
Extracts points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis.

Args:
mesh: The mesh from which to extract points.
axis: The axis along which to extract points.
upper: The upper bound of the range.
lower: The lower bound of the range.

Returns:
The extracted points (N x 3 array).

Raises:
AssertionError: If the lower bound is greater than the upper bound.
"""

# Check bounds.
upper = upper if upper is not None else np.inf
lower = lower if lower is not None else -np.inf
assert lower < upper, "Invalid bounds for axis-aligned plane"

# Logic.
points = mesh.vertices[
(mesh.vertices[:, VALID_AXIS[axis]] >= lower)
& (mesh.vertices[:, VALID_AXIS[axis]] <= upper)
]

return points
12 changes: 12 additions & 0 deletions src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,18 @@ def extract_model_data(

collisions.append(sphere_collision)

if collision.geometry.mesh is not None and int(
os.environ.get("JAXSIM_COLLISION_MESH_ENABLED", "0")
):
logging.warning("Mesh collision support is still experimental.")
mesh_collision = utils.create_mesh_collision(
collision=collision,
link_description=links_dict[link.name],
method=utils.meshes.extract_points_vertices,
)

collisions.append(mesh_collision)

return SDFData(
model_name=sdf_model.name,
link_descriptions=links,
Expand Down
53 changes: 53 additions & 0 deletions src/jaxsim/parsers/rod/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import os
import pathlib
from collections.abc import Callable
from typing import TypeVar

import numpy as np
import numpy.typing as npt
import rod
import trimesh
from rod.utils.resolve_uris import resolve_local_uri

import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.math import Adjoint, Inertia
from jaxsim.parsers import descriptions
from jaxsim.parsers.rod import meshes

MeshMappingMethod = TypeVar("MeshMappingMethod", bound=Callable[..., npt.NDArray])


def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
Expand Down Expand Up @@ -202,3 +211,47 @@ def fibonacci_sphere(samples: int) -> npt.NDArray:
return descriptions.SphereCollision(
collidable_points=collidable_points, center=center_wrt_link
)


def create_mesh_collision(
collision: rod.Collision,
link_description: descriptions.LinkDescription,
method: MeshMappingMethod = None,
) -> descriptions.MeshCollision:

file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri))
_file_type = file.suffix.replace(".", "")
mesh = trimesh.load_mesh(file, file_type=_file_type)

if mesh.is_empty:
raise RuntimeError(f"Failed to process '{file}' with trimesh")

mesh.apply_scale(collision.geometry.mesh.scale)
logging.info(
msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{_file_type}'"
)

if method is None:
method = meshes.VertexExtraction()
logging.debug("Using default Vertex Extraction method for mesh wrapping")
else:
logging.debug(f"Using method {method} for mesh wrapping")

points = method(mesh=mesh)
logging.debug(f"Extracted {len(points)} points from mesh")

W_H_L = collision.pose.transform() if collision.pose is not None else np.eye(4)

# Extract translation from transformation matrix
W_p_L = W_H_L[:3, 3]
mesh_points_wrt_link = points @ W_H_L[:3, :3].T + W_p_L
collidable_points = [
descriptions.CollidablePoint(
parent_link=link_description,
position=point,
enabled=True,
)
for point in mesh_points_wrt_link
]

return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L)
100 changes: 100 additions & 0 deletions tests/test_meshes.py
flferretti marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import trimesh

from jaxsim.parsers.rod import meshes


def test_mesh_wrapping_vertex_extraction():
"""
Test the vertex extraction method on different meshes.
1. A simple box
2. A sphere
"""

# Test 1: A simple box.
# First, create a box with origin at (0,0,0) and extents (3,3,3),
# i.e. points span from -1.5 to 1.5 on the axis.
mesh = trimesh.creation.box(
extents=[3.0, 3.0, 3.0],
)
points = meshes.extract_points_vertices(mesh=mesh)
assert len(points) == len(mesh.vertices)

# Test 2: A sphere.
# The sphere is centered at the origin and has a radius of 1.0.
mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)
points = meshes.extract_points_vertices(mesh=mesh)
assert len(points) == len(mesh.vertices)


def test_mesh_wrapping_aap():
"""
Test the AAP wrapping method on different meshes.
1. A simple box
1.1: Remove all points above x=0.0
1.2: Remove all points below y=0.0
2. A sphere
"""

# Test 1.1: Remove all points above x=0.0.
# The expected result is that the number of points is halved.
# First, create a box with origin at (0,0,0) and extents (3,3,3),
# i.e. points span from -1.5 to 1.5 on the axis.
mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0])
points = meshes.extract_points_aap(mesh=mesh, axis="x", lower=0.0)
assert len(points) == len(mesh.vertices) // 2
assert all(points[:, 0] > 0.0)

# Test 1.2: Remove all points below y=0.0.
# The expected result is that the number of points is halved.
points = meshes.extract_points_aap(mesh=mesh, axis="y", upper=0.0)
assert len(points) == len(mesh.vertices) // 2
assert all(points[:, 1] < 0.0)

# Test 2: A sphere.
# The sphere is centered at the origin and has a radius of 1.0.
# Points are expected to be halved.
mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)

# Remove all points above y=0.0.
points = meshes.extract_points_aap(mesh=mesh, axis="y", lower=0.0)
assert all(points[:, 1] >= 0.0)
assert len(points) < len(mesh.vertices)


def test_mesh_wrapping_points_over_axis():
"""
Test the points over axis method on different meshes.
1. A simple box
1.1: Select 10 points from the lower end of the x-axis
1.2: Select 10 points from the higher end of the y-axis
2. A sphere
"""

# Test 1.1: Remove 10 points from the lower end of the x-axis.
# First, create a box with origin at (0,0,0) and extents (3,3,3),
# i.e. points span from -1.5 to 1.5 on the axis.
mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0])
points = meshes.extract_points_select_points_over_axis(
mesh=mesh, axis="x", direction="lower", n=4
)
assert len(points) == 4
assert all(points[:, 0] < 0.0)

# Test 1.2: Select 10 points from the higher end of the y-axis.
points = meshes.extract_points_select_points_over_axis(
mesh=mesh, axis="y", direction="higher", n=4
)
assert len(points) == 4
assert all(points[:, 1] > 0.0)

# Test 2: A sphere.
# The sphere is centered at the origin and has a radius of 1.0.
mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)
sphere_n_vertices = len(mesh.vertices)

# Select 10 points from the higher end of the z-axis.
points = meshes.extract_points_select_points_over_axis(
mesh=mesh, axis="z", direction="higher", n=sphere_n_vertices // 2
)
assert len(points) == sphere_n_vertices // 2
assert all(points[:, 2] >= 0.0)