diff --git a/src/jaxsim/rbda/collidable_points.py b/src/jaxsim/rbda/collidable_points.py index 4e6800782..791b07cb3 100644 --- a/src/jaxsim/rbda/collidable_points.py +++ b/src/jaxsim/rbda/collidable_points.py @@ -118,7 +118,7 @@ def propagate_kinematics( # ================================================== def process_point_kinematics( - Li_p_C: jtp.Vector, parent_body: jtp.Int + Li_p_C: jtp.Vector, parent_body: jtp.Int, W_X_i: jtp.Matrix, W_v_Wi: jtp.Matrix ) -> tuple[jtp.Vector, jtp.Vector]: # Compute the position of the collidable point. @@ -135,9 +135,11 @@ def process_point_kinematics( return W_p_Ci, CW_vl_WCi # Process all the collidable points in parallel. - W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)( + W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics, in_axes=(0, 0, None, None))( model.kin_dyn_parameters.contact_parameters.point, jnp.array(model.kin_dyn_parameters.contact_parameters.body), + W_X_i, + W_v_Wi, ) return W_p_Ci, CW_vl_WC