Skip to content

Commit

Permalink
Merge pull request ami-iit#220 from ami-iit/fix/apply_frame_forces
Browse files Browse the repository at this point in the history
Fix `JaxSimModelReferences.apply_frame_forces`
  • Loading branch information
flferretti authored Aug 20, 2024
2 parents 1145d1f + 9605aee commit a0efffd
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import exceptions
from jaxsim.math import Adjoint
from jaxsim.utils.tracing import not_tracing

from .common import VelRepr
Expand Down Expand Up @@ -493,9 +492,9 @@ def apply_frame_forces(

# Extract the frame indices.
frame_idxs = js.frame.names_to_idxs(frame_names=frame_names, model=model)
parent_link_idxs = jax.vmap(
lambda frame_idx: js.frame.idx_of_parent_link, in_axes=(None,)
)(model, frame_idx=frame_idxs)
parent_link_idxs = jax.vmap(js.frame.idx_of_parent_link, in_axes=(None,))(
model, frame_index=frame_idxs
)

exceptions.raise_value_error_if(
condition=jnp.logical_not(data.valid(model=model)),
Expand Down Expand Up @@ -527,25 +526,14 @@ def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix:
case _:
raise ValueError("Invalid velocity representation.")

W_H_L = js.model.forward_kinematics(model=model, data=data)

def convert_to_link_force(
W_f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike, parent_link_idx: jtp.ArrayLike
) -> jtp.Matrix:
L_Xf_W = Adjoint.from_transform(W_H_L[parent_link_idx]).T

return L_Xf_W @ W_f_F

W_f_L_i = jax.vmap(convert_to_link_force)(W_f_F, W_H_Fi, parent_link_idxs)

# Sum the forces on the parent links.
mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links())
W_f_L = mask.T @ W_f_L_i
W_f_L = mask.T @ W_f_F

with self.switch_velocity_representation(
velocity_representation=VelRepr.Inertial
):
return self.apply_link_forces(
references = self.apply_link_forces(
model=model,
data=data,
link_names=js.link.idxs_to_names(
Expand All @@ -554,3 +542,8 @@ def convert_to_link_force(
forces=W_f_L,
additive=additive,
)

with references.switch_velocity_representation(
velocity_representation=self.velocity_representation
):
return references

0 comments on commit a0efffd

Please sign in to comment.