Skip to content

Commit

Permalink
Merge pull request #325 from ami-iit/speedup_rigid_contacts
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti authored Jan 8, 2025
2 parents 5e2192c + 753bb78 commit 1d96dac
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 52 deletions.
12 changes: 5 additions & 7 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2295,9 +2295,6 @@ def step(
# Hence, here we need to reset the velocity after each impact to guarantee that
# the linear velocity of the active collidable points is zero.
case jaxsim.rbda.contacts.RigidContacts():
assert isinstance(
data_tf.contacts_params, jaxsim.rbda.contacts.RigidContactsParams
)

# Raise runtime error for not supported case in which Rigid contacts and
# Baumgarte stabilization are enabled and used with ForwardEuler integrator.
Expand Down Expand Up @@ -2331,21 +2328,22 @@ def step(
indices_of_enabled_collidable_points
]
M = js.model.free_floating_mass_matrix(model, data_tf)
BW_ν_pre_impact = data_tf.generalized_velocity()

# Compute the impact velocity.
# It may be discontinuous in case new contacts are made.
BW_nu_post_impact = (
BW_ν_post_impact = (
jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
data=data_tf,
generalized_velocity=BW_ν_pre_impact,
inactive_collidable_points=(δ <= 0),
M=M,
J_WC=J_WC,
)
)

# Reset the generalized velocity.
data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
data_tf = data_tf.reset_base_velocity(BW_ν_post_impact[0:6])
data_tf = data_tf.reset_joint_velocities(BW_ν_post_impact[6:])

# Restore the input velocity representation.
data_tf = data_tf.replace(
Expand Down
72 changes: 27 additions & 45 deletions src/jaxsim/rbda/contacts/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def compute_impact_velocity(
inactive_collidable_points: jtp.ArrayLike,
M: jtp.MatrixLike,
J_WC: jtp.MatrixLike,
data: js.data.JaxSimModelData,
generalized_velocity: jtp.VectorLike,
) -> jtp.Vector:
"""
Return the new velocity of the system after a potential impact.
Expand All @@ -182,55 +182,37 @@ def compute_impact_velocity(
inactive_collidable_points: The activation state of the collidable points.
M: The mass matrix of the system (in mixed representation).
J_WC: The Jacobian matrix of the collidable points (in mixed representation).
data: The `JaxSimModelData` instance.
"""
generalized_velocity: The generalized velocity of the system.
def impact_velocity(
inactive_collidable_points: jtp.ArrayLike,
nu_pre: jtp.ArrayLike,
M: jtp.MatrixLike,
J_WC: jtp.MatrixLike,
data: js.data.JaxSimModelData,
):
# Compute system velocity after impact maintaining zero linear velocity of active points
with data.switch_velocity_representation(VelRepr.Mixed):
sl = jnp.s_[:, 0:3, :]
Jl_WC = J_WC[sl]
# Zero out the jacobian rows of inactive points
Jl_WC = jnp.vstack(
jnp.where(
inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
jnp.zeros_like(Jl_WC),
Jl_WC,
)
)
Note:
The mass matrix `M`, the Jacobian `J_WC`, and the generalized velocity `generalized_velocity`
must be expressed in the same velocity representation.
"""

A = jnp.vstack(
[
jnp.hstack([M, -Jl_WC.T]),
jnp.hstack(
[Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]
),
]
)
b = jnp.hstack([M @ nu_pre, jnp.zeros(Jl_WC.shape[0])])
x = jnp.linalg.lstsq(A, b)[0]
nu_post = x[0 : M.shape[0]]
# Compute system velocity after impact maintaining zero linear velocity of active points.
sl = jnp.s_[:, 0:3, :]
Jl_WC = J_WC[sl]

# Zero out the jacobian rows of inactive points.
Jl_WC = jnp.vstack(
jnp.where(
inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
jnp.zeros_like(Jl_WC),
Jl_WC,
)
)

return nu_post
A = jnp.vstack(
[
jnp.hstack([M, -Jl_WC.T]),
jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]),
]
)
b = jnp.hstack([M @ generalized_velocity, jnp.zeros(Jl_WC.shape[0])])

with data.switch_velocity_representation(VelRepr.Mixed):
BW_ν_pre_impact = data.generalized_velocity()

BW_ν_post_impact = impact_velocity(
data=data,
inactive_collidable_points=inactive_collidable_points,
nu_pre=BW_ν_pre_impact,
M=M,
J_WC=J_WC,
)
BW_ν_post_impact = jnp.linalg.lstsq(A, b)[0]

return BW_ν_post_impact
return BW_ν_post_impact[0 : M.shape[0]]

@jax.jit
def compute_contact_forces(
Expand Down

0 comments on commit 1d96dac

Please sign in to comment.