diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index 5950a0fa8..ca69e12e0 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -117,22 +117,19 @@ def create_box_collision( center = np.array([x / 2, y / 2, z / 2]) - box_corners = ( - np.vstack( - [ - np.array([0, 0, 0]), - np.array([x, 0, 0]), - np.array([x, y, 0]), - np.array([0, y, 0]), - np.array([0, 0, z]), - np.array([x, 0, z]), - np.array([x, y, z]), - np.array([0, y, z]), - ] - ) - - center + # Define the bottom corners. + bottom_corners = np.array([[0, 0, 0], [x, 0, 0], [x, y, 0], [0, y, 0]]) + + # Conditionally add the top corners based on the environment variable. + top_corners = ( + np.array([[0, 0, z], [x, 0, z], [x, y, z], [0, y, z]]) + if not os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0") + else [] ) + # Combine and shift by the center + box_corners = np.vstack([bottom_corners, *top_corners]) - center + H = collision.pose.transform() if collision.pose is not None else np.eye(4) center_wrt_link = (H @ np.hstack([center, 1.0]))[0:-1] @@ -170,23 +167,31 @@ def create_sphere_collision( # From https://stackoverflow.com/a/26127012 def fibonacci_sphere(samples: int) -> npt.NDArray: - points = [] - phi = np.pi * (3.0 - np.sqrt(5.0)) # golden angle in radians - - for i in range(samples): - y = 1 - (i / float(samples - 1)) * 2 # y goes from 1 to -1 - radius = np.sqrt(1 - y * y) # radius at y - - theta = phi * i # golden angle increment - - x = np.cos(theta) * radius - z = np.sin(theta) * radius + # Get the golden ratio in radians. + phi = np.pi * (3.0 - np.sqrt(5.0)) + + # Generate the points. + points = [ + np.array( + [ + np.cos(phi * i) + * np.sqrt(1 - (y := 1 - 2 * i / (samples - 1)) ** 2), + y, + np.sin(phi * i) * np.sqrt(1 - y**2), + ] + ) + for i in range(samples) + ] - points.append(np.array([x, y, z])) + # Filter to keep only the bottom half if required. + if os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0"): + # Keep only the points with y <= 0. + points = [point for point in points if point[2] <= 0] return np.vstack(points) r = collision.geometry.sphere.radius + sphere_points = r * fibonacci_sphere( samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="50")) )