diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index c5eb8314f..6e232e5c5 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -128,6 +128,8 @@ class MeshCollision(CollisionShape): def __eq__(self, other: Any) -> bool: if not isinstance(other, MeshCollision): return False - return len(self.collidable_points) == len( - other.collidable_points - ) and super().__eq__(other) and (self.center == other.center).all() + return ( + len(self.collidable_points) == len(other.collidable_points) + and super().__eq__(other) + and (self.center == other.center).all() + ) diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 40b9563d1..c53622af6 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -330,7 +330,7 @@ def extract_model_data( mesh_collision = utils.create_mesh_collision( collision=collision, link_description=links_dict[link.name], - method=utils.MeshMappingMethods.UniformSurfaceSampling + method=utils.MeshMappingMethods.UniformSurfaceSampling, ) if mesh_collision is not None: collisions.append(mesh_collision) diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index 2c90e47c7..002c816fc 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -244,20 +244,13 @@ def create_mesh_collision( case MeshMappingMethods.RandomSurfaceSampling: points = mesh.sample(nsamples) case MeshMappingMethods.UniformSurfaceSampling: - points = trimesh.sample.sample_surface_even( - mesh=mesh, - count=nsamples - ) + points = trimesh.sample.sample_surface_even(mesh=mesh, count=nsamples) case _: raise ValueError("Invalid mesh mapping method") points = mesh.vertices - H = ( - collision.pose.transform() if collision.pose is not None else np.eye(4) - ) - center_of_collision_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[ - 0:-1 - ] + H = collision.pose.transform() if collision.pose is not None else np.eye(4) + center_of_collision_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[0:-1] mesh_points_wrt_link = ( H @ np.hstack([points, np.vstack([1.0] * points.shape[0])]).T )[0:3, :]