Skip to content

Commit

Permalink
Merge pull request ami-iit#209 from ami-iit/feature/frame_force
Browse files Browse the repository at this point in the history
Add the possibility to set forces to frames
  • Loading branch information
flferretti authored Jul 24, 2024
2 parents 4189a2c + 92cb0bb commit a58766b
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 53 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ ignore = [
"E501", # Line too long
"E731", # Do not assign a `lambda` expression, use a `def`
"E741", # Ambiguous variable name
"F841", # Local variable is assigned to but never used
"I001", # Import block is unsorted or unformatted
"RUF003", # Ambigous unicode character in comment
]
Expand Down
24 changes: 12 additions & 12 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def centroidal_momentum_jacobian(

match data.velocity_representation:
case VelRepr.Inertial | VelRepr.Mixed:
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841
case VelRepr.Body:
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841
case _:
raise ValueError(data.velocity_representation)

Expand Down Expand Up @@ -172,9 +172,9 @@ def locked_centroidal_spatial_inertia(

match data.velocity_representation:
case VelRepr.Inertial | VelRepr.Mixed:
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841
case VelRepr.Body:
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841
case _:
raise ValueError(data.velocity_representation)

Expand Down Expand Up @@ -290,14 +290,14 @@ def other_representation_to_body(

case VelRepr.Inertial:

C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL
C_v_WC = W_v_WW = jnp.zeros(6)
C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841

L_H_C = L_H_W = jax.vmap(
L_H_C = L_H_W = jax.vmap( # noqa: F841
lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
)(W_H_L)

L_v_LC = L_v_LW = jax.vmap(
L_v_LC = L_v_LW = jax.vmap( # noqa: F841
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
Expand All @@ -314,23 +314,23 @@ def other_representation_to_body(

case VelRepr.Mixed:

C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL
C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841

C_v_WC = LW_v_W_LW = jax.vmap(
C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841
lambda i: js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed
)
.at[3:6]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))

L_H_C = L_H_LW = jax.vmap(
L_H_C = L_H_LW = jax.vmap( # noqa: F841
lambda W_H_L: jaxsim.math.Transform.inverse(
W_H_L.at[0:3, 3].set(jnp.zeros(3))
)
)(W_H_L)

L_v_LC = L_v_L_LW = jax.vmap(
L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
Expand Down
10 changes: 6 additions & 4 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,10 @@ def compute_O_J̇_WC_I(

match output_vel_repr:
case VelRepr.Inertial:
O_X_W = W_X_W = Adjoint.from_transform(transform=jnp.eye(4))
O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6))
O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841
transform=jnp.eye(4)
)
O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6)) # noqa: F841

case VelRepr.Body:
L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
Expand All @@ -548,7 +550,7 @@ def compute_O_J̇_WC_I(
W_nu = data.generalized_velocity()
W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu
W_vx_WC = Cross.vx(W_v_WC)
O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC
O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC # noqa: F841

case VelRepr.Mixed:
L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
Expand All @@ -560,7 +562,7 @@ def compute_O_J̇_WC_I(
CW_v_WC = CW_J_WC_BW @ data.generalized_velocity()
W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])
W_vx_W_CW = Cross.vx(W_v_W_CW)
O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW
O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW # noqa: F841

case _:
raise ValueError(output_vel_repr)
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
W_nu = data.generalized_velocity()
W_v_WF = W_J_WL_W @ W_nu
W_vx_WF = Cross.vx(W_v_WF)
O_Ẋ_W = F_Ẋ_W = -F_X_W @ W_vx_WF
O_Ẋ_W = F_Ẋ_W = -F_X_W @ W_vx_WF # noqa: F841

case VelRepr.Mixed:
W_H_F = transform(model=model, data=data, frame_index=frame_index)
Expand All @@ -401,7 +401,7 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
FW_v_WF = FW_J_WF_FW @ data.generalized_velocity()
W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3])
W_vx_W_FW = Cross.vx(W_v_W_FW)
O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW
O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW # noqa: F841

case _:
raise ValueError(output_vel_repr)
Expand Down
17 changes: 10 additions & 7 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def jacobian(
case VelRepr.Inertial:
W_H_B = data.base_transform()
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(
B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
B_X_W, jnp.eye(model.dofs())
)

Expand All @@ -299,7 +299,7 @@ def jacobian(
W_R_B = data.base_orientation(dcm=True)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag(
B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
B_X_BW, jnp.eye(model.dofs())
)

Expand All @@ -313,7 +313,7 @@ def jacobian(
case VelRepr.Inertial:
W_H_B = data.base_transform()
W_X_B = Adjoint.from_transform(transform=W_H_B)
O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I
O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I # noqa: F841

case VelRepr.Body:
L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True)
Expand Down Expand Up @@ -505,7 +505,7 @@ def jacobian_derivative(
with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()

O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB)
O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841

case VelRepr.Body:

Expand All @@ -519,7 +519,9 @@ def jacobian_derivative(
B_v_WB = data.base_velocity()
L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index)

O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx(B_X_L @ L_v_WL - B_v_WB)
O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
B_X_L @ L_v_WL - B_v_WB
)

case VelRepr.Mixed:

Expand All @@ -544,8 +546,9 @@ def jacobian_derivative(
LW_v_LW_L = LW_v_WL - LW_v_W_LW
LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L

O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx(B_X_LW @ LW_v_B_LW)

O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841
B_X_LW @ LW_v_B_LW
)
case _:
raise ValueError(output_vel_repr)

Expand Down
53 changes: 31 additions & 22 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,9 @@ def generalized_free_floating_jacobian(
W_H_B = data.base_transform()
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)

B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
B_X_W, jnp.eye(model.dofs())
B_J_full_WX_I = B_J_full_WX_W = ( # noqa: F841
B_J_full_WX_B
@ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
)

case VelRepr.Body:
Expand All @@ -509,7 +510,7 @@ def generalized_free_floating_jacobian(
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)

B_J_full_WX_I = B_J_full_WX_BW = (
B_J_full_WX_I = B_J_full_WX_BW = ( # noqa: F841
B_J_full_WX_B
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
)
Expand Down Expand Up @@ -542,11 +543,13 @@ def generalized_free_floating_jacobian(
W_H_B = data.base_transform()
W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)

O_J_WL_I = W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)(B_J_WL_I)
O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841
lambda B_J_WL_I: W_X_B @ B_J_WL_I
)(B_J_WL_I)

case VelRepr.Body:

O_J_WL_I = L_J_WL_I = jax.vmap(
O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841
lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform(
B_H_L, inverse=True
)
Expand All @@ -565,7 +568,7 @@ def generalized_free_floating_jacobian(
lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
)(LW_H_L, B_H_L)

O_J_WL_I = LW_J_WL_I = jax.vmap(
O_J_WL_I = LW_J_WL_I = jax.vmap( # noqa: F841
lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B)
@ B_J_WL_I
)(LW_H_B, B_J_WL_I)
Expand Down Expand Up @@ -756,8 +759,8 @@ def to_active(
match data.velocity_representation:
case VelRepr.Inertial:
# In this case C=W
W_H_C = W_H_W = jnp.eye(4)
W_v_WC = W_v_WW = jnp.zeros(6)
W_H_C = W_H_W = jnp.eye(4) # noqa: F841
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841

case VelRepr.Body:
# In this case C=B
Expand All @@ -767,9 +770,9 @@ def to_active(
case VelRepr.Mixed:
# In this case C=B[W]
W_H_B = data.base_transform()
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
W_ṗ_B = data.base_velocity()[0:3]
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841

case _:
raise ValueError(data.velocity_representation)
Expand Down Expand Up @@ -1124,8 +1127,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):

match data.velocity_representation:
case VelRepr.Inertial:
W_H_C = W_H_W = jnp.eye(4)
W_v_WC = W_v_WW = jnp.zeros(6)
W_H_C = W_H_W = jnp.eye(4) # noqa: F841
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841

case VelRepr.Body:
W_H_C = W_H_B = data.base_transform()
Expand All @@ -1134,9 +1137,9 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):

case VelRepr.Mixed:
W_H_B = data.base_transform()
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
W_ṗ_B = data.base_velocity()[0:3]
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841

case _:
raise ValueError(data.velocity_representation)
Expand Down Expand Up @@ -1571,15 +1574,15 @@ def other_representation_to_inertial(
# a simple C_X_W 6D transform.
match data.velocity_representation:
case VelRepr.Inertial:
W_H_C = W_H_W = jnp.eye(4)
W_v_WC = W_v_WW = jnp.zeros(6)
W_H_C = W_H_W = jnp.eye(4) # noqa: F841
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
with data.switch_velocity_representation(VelRepr.Inertial):
C_v_WB = W_v_WB = data.base_velocity()

case VelRepr.Body:
W_H_C = W_H_B
with data.switch_velocity_representation(VelRepr.Inertial):
W_v_WC = W_v_WB = data.base_velocity()
W_v_WC = W_v_WB = data.base_velocity() # noqa: F841
with data.switch_velocity_representation(VelRepr.Body):
C_v_WB = B_v_WB = data.base_velocity()

Expand All @@ -1590,9 +1593,9 @@ def other_representation_to_inertial(
W_ṗ_B = data.base_velocity()[0:3]
BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW
W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841
with data.switch_velocity_representation(VelRepr.Mixed):
C_v_WB = BW_v_WB = data.base_velocity()
C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841

case _:
raise ValueError(data.velocity_representation)
Expand Down Expand Up @@ -1700,8 +1703,12 @@ def body_to_other_representation(

match data.velocity_representation:
case VelRepr.Body:
C_H_L = L_H_L = jnp.stack([jnp.eye(4)] * model.number_of_links())
L_v_CL = L_v_LL = jnp.zeros(shape=(model.number_of_links(), 6))
C_H_L = L_H_L = jnp.stack( # noqa: F841
[jnp.eye(4)] * model.number_of_links()
)
L_v_CL = L_v_LL = jnp.zeros( # noqa: F841
shape=(model.number_of_links(), 6)
)

case VelRepr.Inertial:
C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data)
Expand All @@ -1711,7 +1718,9 @@ def body_to_other_representation(
W_H_L = js.model.forward_kinematics(model=model, data=data)
LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
C_H_L = LW_H_L
L_v_CL = L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))(L_v_WL)
L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841
lambda v: v.at[0:3].set(jnp.zeros(3))
)(L_v_WL)

case _:
raise ValueError(data.velocity_representation)
Expand Down
Loading

0 comments on commit a58766b

Please sign in to comment.