Skip to content

Commit

Permalink
Fix system acceleration representation in relaxed_rigid_contacts
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Sep 13, 2024
1 parent 618376d commit ae1ff74
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def collidable_point_dynamics(
velocity=W_ṗ_Ci,
model=model,
data=data,
link_forces=link_forces,
)

aux_data = dict()
Expand Down
27 changes: 25 additions & 2 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,22 @@ def compute_contact_forces(
velocity: jtp.Vector,
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
link_forces: jtp.MatrixLike | None = None,
) -> tuple[jtp.Vector, tuple[Any, ...]]:

link_forces = (
link_forces
if link_forces is not None
else jnp.zeros((model.number_of_links(), 6))
)

references = js.references.JaxSimModelReferences.build(
model=model,
data=data,
velocity_representation=data.velocity_representation,
link_forces=link_forces,
)

def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
x, y, z = jax.tree_map(jnp.squeeze, (x, y, z))

Expand All @@ -201,15 +215,24 @@ def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
# Compute the activation state of the collidable points
δ = jax.vmap(_detect_contact)(*position.T)

with data.switch_velocity_representation(VelRepr.Mixed):
with (
references.switch_velocity_representation(VelRepr.Mixed),
data.switch_velocity_representation(VelRepr.Mixed),
):
M = js.model.free_floating_mass_matrix(model=model, data=data)
Jl_WC = jnp.vstack(
jax.vmap(lambda J, height: J * (height < 0))(
js.contact.jacobian(model=model, data=data)[:, :3, :], δ
)
)
W_H_C = js.contact.transforms(model=model, data=data)
W_ν̇ = jnp.hstack(js.ode.system_acceleration(model=model, data=data))
W_ν̇ = jnp.hstack(
js.ode.system_acceleration(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
)
)
W_ν = data.generalized_velocity()
J̇_WC = jnp.vstack(
jax.vmap(lambda , height: * (height < 0))(
Expand Down

0 comments on commit ae1ff74

Please sign in to comment.