From f2889a783ec5478135357dc2c1697186fd4de4c4 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 11 Dec 2024 14:31:51 +0100 Subject: [PATCH] Speed up computation for `RigidContacts` model --- src/jaxsim/api/model.py | 3 -- src/jaxsim/rbda/contacts/rigid.py | 62 +++++++++++-------------------- 2 files changed, 22 insertions(+), 43 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index e12837a20..27e7cd616 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2361,9 +2361,6 @@ def step( # Hence, here we need to reset the velocity after each impact to guarantee that # the linear velocity of the active collidable points is zero. case jaxsim.rbda.contacts.RigidContacts(): - assert isinstance( - data_tf.contacts_params, jaxsim.rbda.contacts.RigidContactsParams - ) # Raise runtime error for not supported case in which Rigid contacts and # Baumgarte stabilization are enabled and used with ForwardEuler integrator. diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 591451172..bc72bbee9 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -182,52 +182,34 @@ def compute_impact_velocity( data: The `JaxSimModelData` instance. """ - def impact_velocity( - inactive_collidable_points: jtp.ArrayLike, - nu_pre: jtp.ArrayLike, - M: jtp.MatrixLike, - J_WC: jtp.MatrixLike, - data: js.data.JaxSimModelData, - ): - # Compute system velocity after impact maintaining zero linear velocity of active points - with data.switch_velocity_representation(VelRepr.Mixed): - sl = jnp.s_[:, 0:3, :] - Jl_WC = J_WC[sl] - # Zero out the jacobian rows of inactive points - Jl_WC = jnp.vstack( - jnp.where( - inactive_collidable_points[:, jnp.newaxis, jnp.newaxis], - jnp.zeros_like(Jl_WC), - Jl_WC, - ) - ) + # Compute system velocity after impact maintaining zero linear velocity of active points. + with data.switch_velocity_representation(VelRepr.Mixed): - A = jnp.vstack( - [ - jnp.hstack([M, -Jl_WC.T]), - jnp.hstack( - [Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))] - ), - ] - ) - b = jnp.hstack([M @ nu_pre, jnp.zeros(Jl_WC.shape[0])]) - x = jnp.linalg.lstsq(A, b)[0] - nu_post = x[0 : M.shape[0]] + BW_ν_pre_impact = data.generalized_velocity() - return nu_post + sl = jnp.s_[:, 0:3, :] + Jl_WC = J_WC[sl] - with data.switch_velocity_representation(VelRepr.Mixed): - BW_ν_pre_impact = data.generalized_velocity() + # Zero out the jacobian rows of inactive points. + Jl_WC = jnp.vstack( + jnp.where( + inactive_collidable_points[:, jnp.newaxis, jnp.newaxis], + jnp.zeros_like(Jl_WC), + Jl_WC, + ) + ) - BW_ν_post_impact = impact_velocity( - data=data, - inactive_collidable_points=inactive_collidable_points, - nu_pre=BW_ν_pre_impact, - M=M, - J_WC=J_WC, + A = jnp.vstack( + [ + jnp.hstack([M, -Jl_WC.T]), + jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]), + ] ) + b = jnp.hstack([M @ BW_ν_pre_impact, jnp.zeros(Jl_WC.shape[0])]) + + BW_ν_post_impact = jnp.linalg.pinv(A, b)[0] - return BW_ν_post_impact + return BW_ν_post_impact @jax.jit def compute_contact_forces(