From 9605aee35dd0a5ae7c986f2b08506aba18ed84b9 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 29 Jul 2024 18:09:52 +0200 Subject: [PATCH] Fix `JaxSimModelReferences.apply_frame_forces` --- src/jaxsim/api/references.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 901b741b6..c34aa2240 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -9,7 +9,6 @@ import jaxsim.api as js import jaxsim.typing as jtp from jaxsim import exceptions -from jaxsim.math import Adjoint from jaxsim.utils.tracing import not_tracing from .common import VelRepr @@ -493,9 +492,9 @@ def apply_frame_forces( # Extract the frame indices. frame_idxs = js.frame.names_to_idxs(frame_names=frame_names, model=model) - parent_link_idxs = jax.vmap( - lambda frame_idx: js.frame.idx_of_parent_link, in_axes=(None,) - )(model, frame_idx=frame_idxs) + parent_link_idxs = jax.vmap(js.frame.idx_of_parent_link, in_axes=(None,))( + model, frame_index=frame_idxs + ) exceptions.raise_value_error_if( condition=jnp.logical_not(data.valid(model=model)), @@ -527,25 +526,14 @@ def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix: case _: raise ValueError("Invalid velocity representation.") - W_H_L = js.model.forward_kinematics(model=model, data=data) - - def convert_to_link_force( - W_f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike, parent_link_idx: jtp.ArrayLike - ) -> jtp.Matrix: - L_Xf_W = Adjoint.from_transform(W_H_L[parent_link_idx]).T - - return L_Xf_W @ W_f_F - - W_f_L_i = jax.vmap(convert_to_link_force)(W_f_F, W_H_Fi, parent_link_idxs) - # Sum the forces on the parent links. mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links()) - W_f_L = mask.T @ W_f_L_i + W_f_L = mask.T @ W_f_F with self.switch_velocity_representation( velocity_representation=VelRepr.Inertial ): - return self.apply_link_forces( + references = self.apply_link_forces( model=model, data=data, link_names=js.link.idxs_to_names( @@ -554,3 +542,8 @@ def convert_to_link_force( forces=W_f_L, additive=additive, ) + + with references.switch_velocity_representation( + velocity_representation=self.velocity_representation + ): + return references