Skip to content

Commit

Permalink
Add converters for mass matrix, jacobian and jacobian derivative
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 17, 2024
1 parent f039932 commit a48f7e1
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,69 @@ def other_representation_to_inertial(

case _:
raise ValueError(other_representation)


def convert_mass_matrix(
M: jtp.Matrix,
base_transform: jtp.Matrix,
dofs: jtp.Int,
velocity_representation: VelRepr,
):

return M

match velocity_representation:
case VelRepr.Body:
return M

case VelRepr.Inertial:

B_X_W = Adjoint.from_transform(transform=base_transform, inverse=True)
invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(dofs))

return invT.T @ M @ invT

case VelRepr.Mixed:

BW_H_B = base_transform.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(dofs))

return invT.T @ M @ invT


def convert_jacobian(
J: jtp.Matrix,
base_transform: jtp.Matrix,
dofs: jtp.Int,
velocity_representation: VelRepr,
):
return J

match velocity_representation:

case VelRepr.Inertial:

W_H_B = base_transform
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)

return J @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(dofs))

case VelRepr.Body:

return J

case VelRepr.Mixed:

W_R_B = base_transform[:3, :3]
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)

return J @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(dofs))

case _:
raise ValueError(velocity_representation)


def convert_jacobian_derivative(J: jtp.Matrix, velocity_representation: VelRepr):
return J

0 comments on commit a48f7e1

Please sign in to comment.