Skip to content

Commit

Permalink
Speed up computation for RigidContacts model
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 11, 2024
1 parent e2911f8 commit f2889a7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 43 deletions.
3 changes: 0 additions & 3 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2361,9 +2361,6 @@ def step(
# Hence, here we need to reset the velocity after each impact to guarantee that
# the linear velocity of the active collidable points is zero.
case jaxsim.rbda.contacts.RigidContacts():
assert isinstance(
data_tf.contacts_params, jaxsim.rbda.contacts.RigidContactsParams
)

# Raise runtime error for not supported case in which Rigid contacts and
# Baumgarte stabilization are enabled and used with ForwardEuler integrator.
Expand Down
62 changes: 22 additions & 40 deletions src/jaxsim/rbda/contacts/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,52 +182,34 @@ def compute_impact_velocity(
data: The `JaxSimModelData` instance.
"""

def impact_velocity(
inactive_collidable_points: jtp.ArrayLike,
nu_pre: jtp.ArrayLike,
M: jtp.MatrixLike,
J_WC: jtp.MatrixLike,
data: js.data.JaxSimModelData,
):
# Compute system velocity after impact maintaining zero linear velocity of active points
with data.switch_velocity_representation(VelRepr.Mixed):
sl = jnp.s_[:, 0:3, :]
Jl_WC = J_WC[sl]
# Zero out the jacobian rows of inactive points
Jl_WC = jnp.vstack(
jnp.where(
inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
jnp.zeros_like(Jl_WC),
Jl_WC,
)
)
# Compute system velocity after impact maintaining zero linear velocity of active points.
with data.switch_velocity_representation(VelRepr.Mixed):

A = jnp.vstack(
[
jnp.hstack([M, -Jl_WC.T]),
jnp.hstack(
[Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]
),
]
)
b = jnp.hstack([M @ nu_pre, jnp.zeros(Jl_WC.shape[0])])
x = jnp.linalg.lstsq(A, b)[0]
nu_post = x[0 : M.shape[0]]
BW_ν_pre_impact = data.generalized_velocity()

return nu_post
sl = jnp.s_[:, 0:3, :]
Jl_WC = J_WC[sl]

with data.switch_velocity_representation(VelRepr.Mixed):
BW_ν_pre_impact = data.generalized_velocity()
# Zero out the jacobian rows of inactive points.
Jl_WC = jnp.vstack(
jnp.where(
inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
jnp.zeros_like(Jl_WC),
Jl_WC,
)
)

BW_ν_post_impact = impact_velocity(
data=data,
inactive_collidable_points=inactive_collidable_points,
nu_pre=BW_ν_pre_impact,
M=M,
J_WC=J_WC,
A = jnp.vstack(
[
jnp.hstack([M, -Jl_WC.T]),
jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]),
]
)
b = jnp.hstack([M @ BW_ν_pre_impact, jnp.zeros(Jl_WC.shape[0])])

BW_ν_post_impact = jnp.linalg.pinv(A, b)[0]

return BW_ν_post_impact
return BW_ν_post_impact

@jax.jit
def compute_contact_forces(
Expand Down

0 comments on commit f2889a7

Please sign in to comment.