diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index 6057edaed..9701ced92 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -301,18 +301,17 @@ def other_representation_to_body( def to_body() -> jtp.Vector: L_a_bias_WL = v̇_bias_WL + return L_a_bias_WL def to_inertial() -> jtp.Vector: - C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841 - C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 + W_v̇_bias_WL = v̇_bias_WL + W_v_WW = jnp.zeros(6) - L_H_C = L_H_W = jax.vmap(lambda W_H_L: Transform.inverse(W_H_L))( # noqa: F841 - W_H_L - ) + L_H_W = jax.vmap(lambda W_H_L: Transform.inverse(W_H_L))(W_H_L) - L_v_LC = L_v_LW = jax.vmap( # noqa: F841 + L_v_LW = jax.vmap( lambda i: -js.link.velocity( model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body ) @@ -320,19 +319,20 @@ def to_inertial() -> jtp.Vector: L_a_bias_WL = jax.vmap( lambda i: other_representation_to_body( - C_v̇_WL=C_v̇_WL[i], - C_v_WC=C_v_WC, - L_H_C=L_H_C[i], - L_v_LC=L_v_LC[i], + C_v̇_WL=W_v̇_bias_WL[i], + C_v_WC=W_v_WW, + L_H_C=L_H_W[i], + L_v_LC=L_v_LW[i], ) )(jnp.arange(model.number_of_links())) + return L_a_bias_WL def to_mixed() -> jtp.Vector: - C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841 + LW_v̇_bias_WL = v̇_bias_WL - C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841 + LW_v_W_LW = jax.vmap( lambda i: js.link.velocity( model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed ) @@ -340,11 +340,11 @@ def to_mixed() -> jtp.Vector: .set(jnp.zeros(3)) )(jnp.arange(model.number_of_links())) - L_H_C = L_H_LW = jax.vmap( # noqa: F841 + L_H_LW = jax.vmap( lambda W_H_L: Transform.inverse(W_H_L.at[0:3, 3].set(jnp.zeros(3))) )(W_H_L) - L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841 + L_v_L_LW = jax.vmap( lambda i: -js.link.velocity( model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body ) @@ -354,12 +354,13 @@ def to_mixed() -> jtp.Vector: L_a_bias_WL = jax.vmap( lambda i: other_representation_to_body( - C_v̇_WL=C_v̇_WL[i], - C_v_WC=C_v_WC[i], - L_H_C=L_H_C[i], - L_v_LC=L_v_LC[i], + C_v̇_WL=LW_v̇_bias_WL[i], + C_v_WC=LW_v_W_LW[i], + L_H_C=L_H_LW[i], + L_v_LC=L_v_L_LW[i], ) )(jnp.arange(model.number_of_links())) + return L_a_bias_WL # We need here to get the body-fixed bias acceleration of the links. diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index d9dba423e..33bec4567 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -112,19 +112,21 @@ def inertial_to_other_representation( if W_H_O.shape != (4, 4): raise ValueError(W_H_O.shape, (4, 4)) - def to_inertial(): + def to_inertial() -> jtp.Array: + return W_array - def to_body(): + def to_body() -> jtp.Array: if not is_force: O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True) O_array = O_Xv_W @ W_array else: O_Xf_W = Adjoint.from_transform(transform=W_H_O).T O_array = O_Xf_W @ W_array + return O_array - def to_mixed(): + def to_mixed() -> jtp.Array: W_p_O = W_H_O[0:3, 3] W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) if not is_force: @@ -133,6 +135,7 @@ def to_mixed(): else: OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T OW_array = OW_Xf_W @ W_array + return OW_array return jax.lax.switch( diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 6c0785c7c..96d66282f 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -373,8 +373,8 @@ def jacobian( ) def to_inertial() -> jtp.Matrix: - O_J_WC = W_J_WC - return O_J_WC + + return W_J_WC def to_body() -> jtp.Matrix: @@ -385,8 +385,9 @@ def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: C_J_WC = C_X_W @ W_J_WC return C_J_WC - O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) - return O_J_WC + C_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) + + return C_J_WC def to_mixed() -> jtp.Matrix: @@ -401,8 +402,9 @@ def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: CW_J_WC = CW_X_W @ W_J_WC return CW_J_WC - O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) - return O_J_WC + CW_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) + + return CW_J_WC # Adjust the output representation. O_J_WC = jax.lax.switch( @@ -548,13 +550,13 @@ def compute_O_J̇_WC_I( parent_link_idx = parent_link_idxs[contact_idx] - def to_inertial(): + def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: W_X_W = Adjoint.from_transform(transform=jnp.eye(4)) W_Ẋ_W = jnp.zeros((6, 6)) return W_X_W, W_Ẋ_W - def to_body(): + def to_body() -> tuple[jtp.Matrix, jtp.Matrix]: L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) W_H_C = W_H_L[parent_link_idx] @ L_H_C C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) @@ -566,7 +568,7 @@ def to_body(): return C_X_W, C_Ẋ_W - def to_mixed(): + def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) W_H_C = W_H_L[parent_link_idx] @ L_H_C W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 0dc2d88a1..0d8ae9dcc 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -231,6 +231,7 @@ def to_inertial() -> jtp.Matrix: W_H_L = js.link.transform(model=model, data=data, link_index=L) W_X_L = Adjoint.from_transform(transform=W_H_L) W_J_WL = W_X_L @ L_J_WL + return W_J_WL def to_body() -> jtp.Matrix: @@ -239,6 +240,7 @@ def to_body() -> jtp.Matrix: F_H_L = Transform.inverse(W_H_F) @ W_H_L F_X_L = Adjoint.from_transform(transform=F_H_L) F_J_WL = F_X_L @ L_J_WL + return F_J_WL def to_mixed() -> jtp.Matrix: @@ -249,6 +251,7 @@ def to_mixed() -> jtp.Matrix: FW_H_L = FW_H_F @ F_H_L FW_X_L = Adjoint.from_transform(transform=FW_H_L) FW_J_WL = FW_X_L @ L_J_WL + return FW_J_WL # Adjust the output representation @@ -388,13 +391,13 @@ def from_mixed(): # Compute quantities to adjust the output representation # ===================================================== - def to_inertial(): + def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: W_X_W = Adjoint.from_transform(transform=jnp.eye(4)) W_Ẋ_W = jnp.zeros((6, 6)) return W_X_W, W_Ẋ_W - def to_body(): + def to_body() -> tuple[jtp.Matrix, jtp.Matrix]: W_H_F = transform(model=model, data=data, frame_index=frame_index) F_X_W = Adjoint.from_transform(transform=W_H_F, inverse=True) with data.switch_velocity_representation(VelRepr.Inertial): @@ -405,7 +408,7 @@ def to_body(): return F_X_W, F_Ẋ_W - def to_mixed(): + def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: W_H_F = transform(model=model, data=data, frame_index=frame_index) W_H_FW = W_H_F.at[0:3, 0:3].set(jnp.eye(3)) FW_H_W = Transform.inverse(W_H_FW) diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 5906792a9..9519f5c14 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -287,21 +287,22 @@ def jacobian( def to_inertial() -> jtp.Matrix: W_H_B = data.base_transform() B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) - B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 - B_X_W, jnp.eye(model.dofs()) - ) + B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) + return B_J_WL_W def to_body() -> jtp.Matrix: + return B_J_WL_B def to_mixed() -> jtp.Matrix: W_R_B = data.base_orientation(dcm=True) BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 + B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( B_X_BW, jnp.eye(model.dofs()) ) + return B_J_WL_BW B_J_WL_I = jax.lax.switch( @@ -319,11 +320,13 @@ def to_inertial() -> jtp.Matrix: W_H_B = data.base_transform() W_X_B = Adjoint.from_transform(transform=W_H_B) W_J_WL_I = W_X_B @ B_J_WL_I + return W_J_WL_I def to_body() -> jtp.Matrix: L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True) L_J_WL_I = L_X_B @ B_J_WL_I + return L_J_WL_I def to_mixed() -> jtp.Matrix: @@ -333,6 +336,7 @@ def to_mixed() -> jtp.Matrix: LW_H_B = LW_H_L @ Transform.inverse(B_H_L) LW_X_B = Adjoint.from_transform(transform=LW_H_B) LW_J_WL_I = LW_X_B @ B_J_WL_I + return LW_J_WL_I # Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`. @@ -455,7 +459,7 @@ def jacobian_derivative( In = jnp.eye(model.dofs()) On = jnp.zeros(shape=(model.dofs(), model.dofs())) - def from_inertial() -> jtp.Matrix: + def from_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: W_H_B = data.base_transform() B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) @@ -468,9 +472,10 @@ def from_inertial() -> jtp.Matrix: # time derivative. T = jax.scipy.linalg.block_diag(B_X_W, In) Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On) + return T, Ṫ - def from_body() -> jtp.Matrix: + def from_body() -> tuple[jtp.Matrix, jtp.Matrix]: B_X_B = Adjoint.from_rotation_and_translation( translation=jnp.zeros(3), rotation=jnp.eye(3) @@ -482,9 +487,10 @@ def from_body() -> jtp.Matrix: # time derivative. T = jax.scipy.linalg.block_diag(B_X_B, In) Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On) + return T, Ṫ - def from_mixed() -> jtp.Matrix: + def from_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) @@ -500,6 +506,7 @@ def from_mixed() -> jtp.Matrix: # time derivative. T = jax.scipy.linalg.block_diag(B_X_BW, In) Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On) + return T, Ṫ T, Ṫ = jax.lax.switch( @@ -515,22 +522,21 @@ def from_mixed() -> jtp.Matrix: # Compute quantities to adjust the output representation # ====================================================== - def to_inertial() -> jtp.Matrix: + def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: W_H_B = data.base_transform() - O_X_B = W_X_B = Adjoint.from_transform(transform=W_H_B) + W_X_B = Adjoint.from_transform(transform=W_H_B) with data.switch_velocity_representation(VelRepr.Body): B_v_WB = data.base_velocity() - O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841 - return O_X_B, O_Ẋ_B + W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) - def to_body() -> jtp.Matrix: + return W_X_B, W_Ẋ_B - O_X_B = L_X_B = Adjoint.from_transform( - transform=B_H_L[link_index, :, :], inverse=True - ) + def to_body() -> tuple[jtp.Matrix, jtp.Matrix]: + + L_X_B = Adjoint.from_transform(transform=B_H_L[link_index, :, :], inverse=True) B_X_L = Adjoint.inverse(adjoint=L_X_B) @@ -538,19 +544,18 @@ def to_body() -> jtp.Matrix: B_v_WB = data.base_velocity() L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index) - O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841 - B_X_L @ L_v_WL - B_v_WB - ) - return O_X_B, O_Ẋ_B + L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx(B_X_L @ L_v_WL - B_v_WB) - def to_mixed() -> jtp.Matrix: + return L_X_B, L_Ẋ_B + + def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: W_H_B = data.base_transform() W_H_L = W_H_B @ B_H_L[link_index, :, :] LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) LW_H_B = LW_H_L @ Transform.inverse(B_H_L[link_index, :, :]) - O_X_B = LW_X_B = Adjoint.from_transform(transform=LW_H_B) + LW_X_B = Adjoint.from_transform(transform=LW_H_B) B_X_LW = Adjoint.inverse(adjoint=LW_X_B) @@ -564,10 +569,9 @@ def to_mixed() -> jtp.Matrix: LW_v_LW_L = LW_v_WL - LW_v_W_LW LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L - O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841 - B_X_LW @ LW_v_B_LW - ) - return O_X_B, O_Ẋ_B + LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx(B_X_LW @ LW_v_B_LW) + + return LW_X_B, LW_Ẋ_B O_X_B, O_Ẋ_B = jax.lax.switch( index=output_vel_repr, diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 8011d3438..b8f9f5701 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -488,19 +488,21 @@ def generalized_free_floating_jacobian( # Update the input velocity representation such that v_WL = J_WL_I @ I_ν # ====================================================================== - def from_inertial(): + def from_inertial() -> jtp.Matrix: W_H_B = data.base_transform() B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag( B_X_W, jnp.eye(model.dofs()) ) + return B_J_full_WX_W - def from_body(): + def from_body() -> jtp.Matrix: + return B_J_full_WX_B - def from_mixed(): + def from_mixed() -> jtp.Matrix: W_R_B = data.base_orientation(dcm=True) BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) @@ -508,6 +510,7 @@ def from_mixed(): B_J_full_WX_BW = B_J_full_WX_B @ jax.scipy.linalg.block_diag( B_X_BW, jnp.eye(model.dofs()) ) + return B_J_full_WX_BW # Update the input velocity representation such that `J_WL_I @ I_ν`. @@ -520,43 +523,17 @@ def from_mixed(): ), ) - def to_inertial(): - W_H_B = data.base_transform() - W_X_B = Adjoint.from_transform(transform=W_H_B) - W_J_full_WX_I = W_X_B @ B_J_full_WX_I - return W_J_full_WX_I - - def to_body(): - return B_J_full_WX_I - - def to_mixed(): - W_R_B = data.base_orientation(dcm=True) - BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) - BW_X_B = Adjoint.from_transform(transform=BW_H_B) - BW_J_full_WX_I = BW_X_B @ B_J_full_WX_I - return BW_J_full_WX_I - # ==================================================================== # Create stacked Jacobian for each link by filtering the full Jacobian # ==================================================================== - # Update the output velocity representation such that `O_v_WL_I = O_J_WL_I @ I_ν`. - O_J_full_WX_I = jax.lax.switch( - index=output_vel_repr, - branches=( - to_body, # VelRepr.Body - to_mixed, # VelRepr.Mixed - to_inertial, # VelRepr.Inertial - ), - ) - κ_bool = model.kin_dyn_parameters.support_body_array_bool # Keep only the columns of the full Jacobian corresponding to the support # body array of each link. B_J_WL_I = jax.vmap( lambda κ: jnp.where( - jnp.hstack([jnp.ones(5), κ]), O_J_full_WX_I, jnp.zeros_like(O_J_full_WX_I) + jnp.hstack([jnp.ones(5), κ]), B_J_full_WX_I, jnp.zeros_like(B_J_full_WX_I) ) )(κ_bool) @@ -568,19 +545,19 @@ def to_inertial() -> jtp.Matrix: W_H_B = data.base_transform() W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B) - O_J_WL_I = W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)( # noqa: F841 - B_J_WL_I - ) - return O_J_WL_I + W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)(B_J_WL_I) + + return W_J_WL_I def to_body() -> jtp.Matrix: - O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841 + L_J_WL_I = jax.vmap( lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform( B_H_L, inverse=True ) @ B_J_WL_I )(B_H_L, B_J_WL_I) - return O_J_WL_I + + return L_J_WL_I def to_mixed() -> jtp.Matrix: W_H_B = data.base_transform() @@ -591,11 +568,12 @@ def to_mixed() -> jtp.Matrix: lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) )(LW_H_L, B_H_L) - O_J_WL_I = LW_J_WL_I = jax.vmap( # noqa: F841 + LW_J_WL_I = jax.vmap( lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B) @ B_J_WL_I )(LW_H_B, B_J_WL_I) - return O_J_WL_I + + return LW_J_WL_I O_J_WL_I = jax.lax.switch( index=output_vel_repr, @@ -786,25 +764,27 @@ def to_active( C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) return C_X_W @ W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB - def to_inertial(): + def to_inertial() -> tuple[jtp.Vector, jtp.Matrix]: # In this case C=W - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - return W_v_WC, W_H_C + W_H_W = jnp.eye(4) + W_v_WW = jnp.zeros(6) + + return W_v_WW, W_H_W - def to_body(): + def to_body() -> tuple[jtp.Vector, jtp.Matrix]: # In this case C=B - W_H_C = W_H_B = data.base_transform() # noqa: F841 - W_v_WC = W_v_WB - return W_v_WC, W_H_C + W_H_B = data.base_transform() - def to_mixed(): + return W_v_WB, W_H_B + + def to_mixed() -> tuple[jtp.Vector, jtp.Matrix]: # In this case C=B[W] W_H_B = data.base_transform() - W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_ṗ_B = data.base_velocity()[0:3] - W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 - return W_v_WC, W_H_C + W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) + + return W_v_W_BW, W_H_BW W_v_WC, W_H_C = jax.lax.switch( index=data.velocity_representation, @@ -1046,10 +1026,10 @@ def compute_link_contribution(M, v, J, J̇) -> jtp.Array: # Adjust the representation of the Coriolis matrix. # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6. - def to_body(): + def to_body() -> jtp.Matrix: return C_B - def to_inertial(): + def to_inertial() -> jtp.Matrix: n = model.dofs() W_H_B = data.base_transform() B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True) @@ -1068,7 +1048,7 @@ def to_inertial(): return C - def to_mixed(): + def to_mixed() -> jtp.Matrix: n = model.dofs() BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) @@ -1169,22 +1149,25 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB) def convert_inertial() -> jtp.Vector: - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - return W_H_C, W_v_WC + W_H_W = jnp.eye(4) + W_v_WW = jnp.zeros(6) + + return W_H_W, W_v_WW def convert_body() -> jtp.Vector: - W_H_C = W_H_B = data.base_transform() # noqa: F841 + W_H_B = data.base_transform() with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity() # noqa: F841 - return W_H_C, W_v_WC + W_v_WB = data.base_velocity() + + return W_H_B, W_v_WB def convert_mixed() -> jtp.Vector: W_H_B = data.base_transform() - W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_ṗ_B = data.base_velocity()[0:3] - W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 - return W_H_C, W_v_WC + W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) + + return W_H_BW, W_v_W_BW W_H_C, W_v_WC = jax.lax.switch( index=data.velocity_representation, @@ -1447,15 +1430,18 @@ def total_momentum_jacobian( B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6] def to_body() -> jtp.Matrix: + return B_Jh_B def to_inertial() -> jtp.Matrix: B_X_W = Adjoint.from_transform(transform=data.base_transform(), inverse=True) + return B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) def to_mixed() -> jtp.Matrix: BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + return B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) B_Jh = jax.lax.switch( @@ -1638,29 +1624,30 @@ def other_representation_to_inertial( # W_a_WB, and intrinsic accelerations can be expressed in different frames through # a simple C_X_W 6D transform. def to_inertial() -> jtp.Matrix: - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 + W_H_W = jnp.eye(4) + W_v_WW = jnp.zeros(6) with data.switch_velocity_representation(VelRepr.Inertial): - C_v_WB = W_v_WB = data.base_velocity() # noqa: F841 - return W_H_C, W_v_WC, C_v_WB + W_v_WB = data.base_velocity() + + return W_H_W, W_v_WW, W_v_WB def to_body() -> jtp.Matrix: - W_H_C = W_H_B with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity() # noqa: F841 + W_v_WB = data.base_velocity() with data.switch_velocity_representation(VelRepr.Body): - C_v_WB = B_v_WB = data.base_velocity() # noqa: F841 - return W_H_C, W_v_WC, C_v_WB + B_v_WB = data.base_velocity() + + return W_H_B, W_v_WB, B_v_WB def to_mixed() -> jtp.Matrix: W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_H_C = W_H_BW with data.switch_velocity_representation(VelRepr.Mixed): W_ṗ_B = data.base_velocity()[0:3] - W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 + W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) with data.switch_velocity_representation(VelRepr.Mixed): - C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841 - return W_H_C, W_v_WC, C_v_WB + BW_v_WB = data.base_velocity() + + return W_H_BW, W_v_W_BW, BW_v_WB W_H_C, W_v_WC, C_v_WB = jax.lax.switch( index=data.velocity_representation, @@ -1772,26 +1759,23 @@ def body_to_other_representation( C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L) return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL) - def to_body() -> jtp.Matrix: - C_H_L = L_H_L = jnp.stack([jnp.eye(4)] * model.number_of_links()) # noqa: F841 - L_v_CL = L_v_LL = jnp.zeros(shape=(model.number_of_links(), 6)) # noqa: F841 - return C_H_L, L_v_CL + def to_body() -> tuple[jtp.Matrix, jtp.Vector]: + L_H_L = jnp.stack([jnp.eye(4)] * model.number_of_links()) + L_v_LL = jnp.zeros(shape=(model.number_of_links(), 6)) - def to_inertial() -> jtp.Matrix: - C_H_L = W_H_L = js.model.forward_kinematics( # noqa: F841 - model=model, data=data - ) - L_v_CL = L_v_WL - return C_H_L, L_v_CL + return L_H_L, L_v_LL - def to_mixed() -> jtp.Matrix: + def to_inertial() -> tuple[jtp.Matrix, jtp.Vector]: + W_H_L = js.model.forward_kinematics(model=model, data=data) + + return W_H_L, L_v_WL + + def to_mixed() -> tuple[jtp.Matrix, jtp.Vector]: W_H_L = js.model.forward_kinematics(model=model, data=data) LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) - C_H_L = LW_H_L - L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841 - lambda v: v.at[0:3].set(jnp.zeros(3)) - )(L_v_WL) - return C_H_L, L_v_CL + L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))(L_v_WL) + + return LW_H_L, L_v_LW_L C_H_L, L_v_CL = jax.lax.switch( index=data.velocity_representation, diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 4a6ff0474..1e1ed5da3 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -38,4 +38,4 @@ BoolLike = bool | Bool | jax.typing.ArrayLike FloatLike = float | Float | jax.typing.ArrayLike -VelRepr = int +VelRepr = Int