From fb3cb28c6d18de0693f6b91275b44507a1f7495a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 13 Sep 2024 11:01:38 +0200 Subject: [PATCH] Fix system acceleration representation in `relaxed_rigid_contacts` --- src/jaxsim/api/contact.py | 1 + src/jaxsim/rbda/contacts/relaxed_rigid.py | 27 +++++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 5592541c9..d6a51af45 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -211,6 +211,7 @@ def collidable_point_dynamics( velocity=W_ṗ_Ci, model=model, data=data, + link_forces=link_forces, ) aux_data = dict() diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index cc3cb0bcc..93933d630 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -188,8 +188,22 @@ def compute_contact_forces( velocity: jtp.Vector, model: js.model.JaxSimModel, data: js.data.JaxSimModelData, + link_forces: jtp.MatrixLike | None = None, ) -> tuple[jtp.Vector, tuple[Any, ...]]: + link_forces = ( + link_forces + if link_forces is not None + else jnp.zeros((model.number_of_links(), 6)) + ) + + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + velocity_representation=data.velocity_representation, + link_forces=link_forces, + ) + def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: x, y, z = jax.tree_map(jnp.squeeze, (x, y, z)) @@ -201,7 +215,10 @@ def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: # Compute the activation state of the collidable points δ = jax.vmap(_detect_contact)(*position.T) - with data.switch_velocity_representation(VelRepr.Mixed): + with ( + references.switch_velocity_representation(VelRepr.Mixed), + data.switch_velocity_representation(VelRepr.Mixed), + ): M = js.model.free_floating_mass_matrix(model=model, data=data) Jl_WC = jnp.vstack( jax.vmap(lambda J, height: J * (height < 0))( @@ -209,7 +226,13 @@ def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: ) ) W_H_C = js.contact.transforms(model=model, data=data) - W_ν̇ = jnp.hstack(js.ode.system_acceleration(model=model, data=data)) + W_ν̇ = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + ) + ) W_ν = data.generalized_velocity() J̇_WC = jnp.vstack( jax.vmap(lambda J̇, height: J̇ * (height < 0))(