Skip to content

Commit

Permalink
Adjust output names and return type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Aug 23, 2024
1 parent 0aae8fc commit f307618
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 149 deletions.
37 changes: 19 additions & 18 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,50 +301,50 @@ def other_representation_to_body(

def to_body() -> jtp.Vector:
L_a_bias_WL = v̇_bias_WL

return L_a_bias_WL

def to_inertial() -> jtp.Vector:

C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
W_v̇_bias_WL = v̇_bias_WL
W_v_WW = jnp.zeros(6)

L_H_C = L_H_W = jax.vmap(lambda W_H_L: Transform.inverse(W_H_L))( # noqa: F841
W_H_L
)
L_H_W = jax.vmap(lambda W_H_L: Transform.inverse(W_H_L))(W_H_L)

L_v_LC = L_v_LW = jax.vmap( # noqa: F841
L_v_LW = jax.vmap(
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
)(jnp.arange(model.number_of_links()))

L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=C_v̇_WL[i],
C_v_WC=C_v_WC,
L_H_C=L_H_C[i],
L_v_LC=L_v_LC[i],
C_v̇_WL=W_v̇_bias_WL[i],
C_v_WC=W_v_WW,
L_H_C=L_H_W[i],
L_v_LC=L_v_LW[i],
)
)(jnp.arange(model.number_of_links()))

return L_a_bias_WL

def to_mixed() -> jtp.Vector:

C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841
LW_v̇_bias_WL = v̇_bias_WL

C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841
LW_v_W_LW = jax.vmap(
lambda i: js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed
)
.at[3:6]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))

L_H_C = L_H_LW = jax.vmap( # noqa: F841
L_H_LW = jax.vmap(
lambda W_H_L: Transform.inverse(W_H_L.at[0:3, 3].set(jnp.zeros(3)))
)(W_H_L)

L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841
L_v_L_LW = jax.vmap(
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
Expand All @@ -354,12 +354,13 @@ def to_mixed() -> jtp.Vector:

L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=C_v̇_WL[i],
C_v_WC=C_v_WC[i],
L_H_C=L_H_C[i],
L_v_LC=L_v_LC[i],
C_v̇_WL=LW_v̇_bias_WL[i],
C_v_WC=LW_v_W_LW[i],
L_H_C=L_H_LW[i],
L_v_LC=L_v_L_LW[i],
)
)(jnp.arange(model.number_of_links()))

return L_a_bias_WL

# We need here to get the body-fixed bias acceleration of the links.
Expand Down
9 changes: 6 additions & 3 deletions src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,21 @@ def inertial_to_other_representation(
if W_H_O.shape != (4, 4):
raise ValueError(W_H_O.shape, (4, 4))

def to_inertial():
def to_inertial() -> jtp.Array:

return W_array

def to_body():
def to_body() -> jtp.Array:
if not is_force:
O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)
O_array = O_Xv_W @ W_array
else:
O_Xf_W = Adjoint.from_transform(transform=W_H_O).T
O_array = O_Xf_W @ W_array

return O_array

def to_mixed():
def to_mixed() -> jtp.Array:
W_p_O = W_H_O[0:3, 3]
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
if not is_force:
Expand All @@ -133,6 +135,7 @@ def to_mixed():
else:
OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T
OW_array = OW_Xf_W @ W_array

return OW_array

return jax.lax.switch(
Expand Down
20 changes: 11 additions & 9 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,8 @@ def jacobian(
)

def to_inertial() -> jtp.Matrix:
O_J_WC = W_J_WC
return O_J_WC

return W_J_WC

def to_body() -> jtp.Matrix:

Expand All @@ -385,8 +385,9 @@ def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
C_J_WC = C_X_W @ W_J_WC
return C_J_WC

O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)
return O_J_WC
C_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)

return C_J_WC

def to_mixed() -> jtp.Matrix:

Expand All @@ -401,8 +402,9 @@ def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
CW_J_WC = CW_X_W @ W_J_WC
return CW_J_WC

O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)
return O_J_WC
CW_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC)

return CW_J_WC

# Adjust the output representation.
O_J_WC = jax.lax.switch(
Expand Down Expand Up @@ -548,13 +550,13 @@ def compute_O_J̇_WC_I(

parent_link_idx = parent_link_idxs[contact_idx]

def to_inertial():
def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]:
W_X_W = Adjoint.from_transform(transform=jnp.eye(4))
W_Ẋ_W = jnp.zeros((6, 6))

return W_X_W, W_Ẋ_W

def to_body():
def to_body() -> tuple[jtp.Matrix, jtp.Matrix]:
L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
W_H_C = W_H_L[parent_link_idx] @ L_H_C
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
Expand All @@ -566,7 +568,7 @@ def to_body():

return C_X_W, C_Ẋ_W

def to_mixed():
def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]:
L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
W_H_C = W_H_L[parent_link_idx] @ L_H_C
W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
Expand Down
9 changes: 6 additions & 3 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def to_inertial() -> jtp.Matrix:
W_H_L = js.link.transform(model=model, data=data, link_index=L)
W_X_L = Adjoint.from_transform(transform=W_H_L)
W_J_WL = W_X_L @ L_J_WL

return W_J_WL

def to_body() -> jtp.Matrix:
Expand All @@ -239,6 +240,7 @@ def to_body() -> jtp.Matrix:
F_H_L = Transform.inverse(W_H_F) @ W_H_L
F_X_L = Adjoint.from_transform(transform=F_H_L)
F_J_WL = F_X_L @ L_J_WL

return F_J_WL

def to_mixed() -> jtp.Matrix:
Expand All @@ -249,6 +251,7 @@ def to_mixed() -> jtp.Matrix:
FW_H_L = FW_H_F @ F_H_L
FW_X_L = Adjoint.from_transform(transform=FW_H_L)
FW_J_WL = FW_X_L @ L_J_WL

return FW_J_WL

# Adjust the output representation
Expand Down Expand Up @@ -388,13 +391,13 @@ def from_mixed():
# Compute quantities to adjust the output representation
# =====================================================

def to_inertial():
def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]:
W_X_W = Adjoint.from_transform(transform=jnp.eye(4))
W_Ẋ_W = jnp.zeros((6, 6))

return W_X_W, W_Ẋ_W

def to_body():
def to_body() -> tuple[jtp.Matrix, jtp.Matrix]:
W_H_F = transform(model=model, data=data, frame_index=frame_index)
F_X_W = Adjoint.from_transform(transform=W_H_F, inverse=True)
with data.switch_velocity_representation(VelRepr.Inertial):
Expand All @@ -405,7 +408,7 @@ def to_body():

return F_X_W, F_Ẋ_W

def to_mixed():
def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]:
W_H_F = transform(model=model, data=data, frame_index=frame_index)
W_H_FW = W_H_F.at[0:3, 0:3].set(jnp.eye(3))
FW_H_W = Transform.inverse(W_H_FW)
Expand Down
54 changes: 29 additions & 25 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,21 +287,22 @@ def jacobian(
def to_inertial() -> jtp.Matrix:
W_H_B = data.base_transform()
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
B_X_W, jnp.eye(model.dofs())
)
B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))

return B_J_WL_W

def to_body() -> jtp.Matrix:

return B_J_WL_B

def to_mixed() -> jtp.Matrix:
W_R_B = data.base_orientation(dcm=True)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag(
B_X_BW, jnp.eye(model.dofs())
)

return B_J_WL_BW

B_J_WL_I = jax.lax.switch(
Expand All @@ -319,11 +320,13 @@ def to_inertial() -> jtp.Matrix:
W_H_B = data.base_transform()
W_X_B = Adjoint.from_transform(transform=W_H_B)
W_J_WL_I = W_X_B @ B_J_WL_I

return W_J_WL_I

def to_body() -> jtp.Matrix:
L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True)
L_J_WL_I = L_X_B @ B_J_WL_I

return L_J_WL_I

def to_mixed() -> jtp.Matrix:
Expand All @@ -333,6 +336,7 @@ def to_mixed() -> jtp.Matrix:
LW_H_B = LW_H_L @ Transform.inverse(B_H_L)
LW_X_B = Adjoint.from_transform(transform=LW_H_B)
LW_J_WL_I = LW_X_B @ B_J_WL_I

return LW_J_WL_I

# Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
Expand Down Expand Up @@ -455,7 +459,7 @@ def jacobian_derivative(
In = jnp.eye(model.dofs())
On = jnp.zeros(shape=(model.dofs(), model.dofs()))

def from_inertial() -> jtp.Matrix:
def from_inertial() -> tuple[jtp.Matrix, jtp.Matrix]:

W_H_B = data.base_transform()
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
Expand All @@ -468,9 +472,10 @@ def from_inertial() -> jtp.Matrix:
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_W, In)
= jax.scipy.linalg.block_diag(B_Ẋ_W, On)

return T,

def from_body() -> jtp.Matrix:
def from_body() -> tuple[jtp.Matrix, jtp.Matrix]:

B_X_B = Adjoint.from_rotation_and_translation(
translation=jnp.zeros(3), rotation=jnp.eye(3)
Expand All @@ -482,9 +487,10 @@ def from_body() -> jtp.Matrix:
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_B, In)
= jax.scipy.linalg.block_diag(B_Ẋ_B, On)

return T,

def from_mixed() -> jtp.Matrix:
def from_mixed() -> tuple[jtp.Matrix, jtp.Matrix]:

BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3))
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
Expand All @@ -500,6 +506,7 @@ def from_mixed() -> jtp.Matrix:
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_BW, In)
= jax.scipy.linalg.block_diag(B_Ẋ_BW, On)

return T,

T, = jax.lax.switch(
Expand All @@ -515,42 +522,40 @@ def from_mixed() -> jtp.Matrix:
# Compute quantities to adjust the output representation
# ======================================================

def to_inertial() -> jtp.Matrix:
def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]:

W_H_B = data.base_transform()
O_X_B = W_X_B = Adjoint.from_transform(transform=W_H_B)
W_X_B = Adjoint.from_transform(transform=W_H_B)

with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()

O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841
return O_X_B, O_Ẋ_B
W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB)

def to_body() -> jtp.Matrix:
return W_X_B, W_Ẋ_B

O_X_B = L_X_B = Adjoint.from_transform(
transform=B_H_L[link_index, :, :], inverse=True
)
def to_body() -> tuple[jtp.Matrix, jtp.Matrix]:

L_X_B = Adjoint.from_transform(transform=B_H_L[link_index, :, :], inverse=True)

B_X_L = Adjoint.inverse(adjoint=L_X_B)

with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()
L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index)

O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
B_X_L @ L_v_WL - B_v_WB
)
return O_X_B, O_Ẋ_B
L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx(B_X_L @ L_v_WL - B_v_WB)

def to_mixed() -> jtp.Matrix:
return L_X_B, L_Ẋ_B

def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]:

W_H_B = data.base_transform()
W_H_L = W_H_B @ B_H_L[link_index, :, :]
LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
LW_H_B = LW_H_L @ Transform.inverse(B_H_L[link_index, :, :])

O_X_B = LW_X_B = Adjoint.from_transform(transform=LW_H_B)
LW_X_B = Adjoint.from_transform(transform=LW_H_B)

B_X_LW = Adjoint.inverse(adjoint=LW_X_B)

Expand All @@ -564,10 +569,9 @@ def to_mixed() -> jtp.Matrix:
LW_v_LW_L = LW_v_WL - LW_v_W_LW
LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L

O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841
B_X_LW @ LW_v_B_LW
)
return O_X_B, O_Ẋ_B
LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx(B_X_LW @ LW_v_B_LW)

return LW_X_B, LW_Ẋ_B

O_X_B, O_Ẋ_B = jax.lax.switch(
index=output_vel_repr,
Expand Down
Loading

0 comments on commit f307618

Please sign in to comment.