Skip to content

Commit

Permalink
[WIP] Save some kindyn computation in JaxSimModelData
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 16, 2024
1 parent 30472ed commit 66173de
Show file tree
Hide file tree
Showing 13 changed files with 149 additions and 63 deletions.
4 changes: 3 additions & 1 deletion src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ def collidable_point_kinematics(

W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
model=model,
base_position=data.base_position(),
base_position=data.base_position,
base_quaternion=data.base_orientation(dcm=False),
joint_positions=data.joint_positions(model=model),
base_linear_velocity=data.base_velocity()[0:3],
base_angular_velocity=data.base_velocity()[3:6],
joint_velocities=data.joint_velocities(model=model),
joint_transforms=data.joint_transforms,
motion_subspaces=data.motion_subspaces,
)

return W_p_Ci, W_ṗ_Ci
Expand Down
34 changes: 30 additions & 4 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):

state: ODEState

jacobian: jtp.Matrix

jacobian_derivative: jtp.Matrix

motion_subspaces: jtp.Matrix

joint_transforms: jtp.Matrix

mass_matrix: jtp.Matrix

gravity: jtp.Vector

contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)
Expand Down Expand Up @@ -232,11 +242,28 @@ def build(
else:
contacts_params = model.contact_model._parameters_class()

n = model.dofs()
n_fb = n + 6 * model.floating_base()

jacobian = jnp.zeros((model.number_of_links(), 6, n_fb))
jacobian_derivative = jnp.zeros((model.number_of_links(), 6, n_fb))
motion_subspaces = jnp.zeros((model.number_of_links(), 6, 1))
joint_transforms = jnp.zeros((model.number_of_links(), 6, 6))
mass_matrix = jnp.zeros((n_fb, n_fb))

print(jacobian.shape)

return JaxSimModelData(
state=ode_state,
gravity=gravity,
contacts_params=contacts_params,
velocity_representation=velocity_representation,
#
jacobian=jacobian,
jacobian_derivative=jacobian_derivative,
motion_subspaces=motion_subspaces,
joint_transforms=joint_transforms,
mass_matrix=mass_matrix,
)

# ==================
Expand Down Expand Up @@ -349,8 +376,7 @@ def joint_velocities(

return self.state.physics_model.joint_velocities[joint_idxs]

@js.common.named_scope
@jax.jit
@property
def base_position(self) -> jtp.Vector:
"""
Get the base position.
Expand All @@ -359,7 +385,7 @@ def base_position(self) -> jtp.Vector:
The base position.
"""

return self.state.physics_model.base_position.squeeze()
return self.state.physics_model.base_position

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["dcm"])
Expand Down Expand Up @@ -400,7 +426,7 @@ def base_transform(self) -> jtp.Matrix:
"""

W_R_B = self.base_orientation(dcm=True)
W_p_B = jnp.vstack(self.base_position())
W_p_B = jnp.vstack(self.base_position)

return jnp.vstack(
[
Expand Down
2 changes: 2 additions & 0 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ def jacobian(
B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
joint_positions=data.joint_positions(),
joint_transforms=data.joint_transforms,
motion_subspaces=data.motion_subspaces,
)

# Compute the actual doubly-left free-floating jacobian of the link.
Expand Down
39 changes: 34 additions & 5 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,9 +574,10 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp

W_H_LL = jaxsim.rbda.forward_kinematics_model(
model=model,
base_position=data.base_position(),
base_position=data.base_position,
base_quaternion=data.base_orientation(dcm=False),
joint_positions=data.joint_positions(model=model),
joint_transforms=data.joint_transforms,
)

return jnp.atleast_3d(W_H_LL).astype(float)
Expand Down Expand Up @@ -616,6 +617,8 @@ def generalized_free_floating_jacobian(
B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
joint_positions=data.joint_positions(),
joint_transforms=data.joint_transforms,
motion_subspaces=data.motion_subspaces,
)

# ======================================================================
Expand Down Expand Up @@ -743,13 +746,17 @@ def generalized_free_floating_jacobian_derivative(
model=model,
joint_positions=data.joint_positions(),
joint_velocities=data.joint_velocities(),
# joint_transforms=data.joint_transforms,
# motion_subspaces=data.motion_subspaces,
)

# The derivative of the equation to change the input and output representations
# of the Jacobian derivative needs the computation of the plain link Jacobian.
B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
joint_positions=data.joint_positions(),
joint_transforms=data.joint_transforms,
motion_subspaces=data.motion_subspaces,
)

# Compute the actual doubly-left free-floating jacobian derivative of the link
Expand Down Expand Up @@ -1005,7 +1012,7 @@ def forward_dynamics_aba(

# Extract the state in inertial-fixed representation.
with data.switch_velocity_representation(VelRepr.Inertial):
W_p_B = data.base_position()
W_p_B = data.base_position
W_v_WB = data.base_velocity()
W_Q_B = data.base_orientation(dcm=False)
s = data.joint_positions(model=model, joint_names=joint_names)
Expand All @@ -1031,6 +1038,8 @@ def forward_dynamics_aba(
joint_forces=τ,
link_forces=W_f_L,
standard_gravity=data.standard_gravity(),
joint_transforms=data.joint_transforms,
motion_subspaces=data.motion_subspaces,
)

# =============
Expand Down Expand Up @@ -1201,6 +1210,8 @@ def free_floating_mass_matrix(
M_body = jaxsim.rbda.crba(
model=model,
joint_positions=data.state.physics_model.joint_positions,
joint_transforms=data.joint_transforms,
motion_subspaces=data.motion_subspaces,
)

match data.velocity_representation:
Expand Down Expand Up @@ -1457,7 +1468,7 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):

# Extract the state in inertial-fixed representation.
with data.switch_velocity_representation(VelRepr.Inertial):
W_p_B = data.base_position()
W_p_B = data.base_position
W_v_WB = data.base_velocity()
W_Q_B = data.base_orientation(dcm=False)
s = data.joint_positions(model=model, joint_names=joint_names)
Expand All @@ -1484,6 +1495,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
joint_accelerations=,
link_forces=W_f_L,
standard_gravity=data.standard_gravity(),
joint_transforms=data.joint_transforms,
motion_subspaces=data.motion_subspaces,
)

# =============
Expand Down Expand Up @@ -1792,7 +1805,7 @@ def average_velocity_jacobian(
case VelRepr.Body:

GB_J = G_J
W_p_B = data.base_position()
W_p_B = data.base_position
W_p_CoM = js.com.com_position(model=model, data=data)
B_R_W = data.base_orientation(dcm=True).transpose()

Expand All @@ -1804,7 +1817,7 @@ def average_velocity_jacobian(
case VelRepr.Mixed:

GW_J = G_J
W_p_B = data.base_position()
W_p_B = data.base_position
W_p_CoM = js.com.com_position(model=model, data=data)

BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
Expand Down Expand Up @@ -2182,6 +2195,22 @@ def forward(
joint_force_references: jtp.VectorLike | None = None,
) -> js.data.JaxSimModelData:

# Kinematics computation.
M = js.model.free_floating_mass_matrix(model=model, data=data)
J = js.model.generalized_free_floating_jacobian(model=model, data=data)
= js.model.generalized_free_floating_jacobian_derivative(model=model, data=data)
i_X_λ, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
joint_positions=data.joint_positions(), base_transform=data.base_transform()
)

data = data.replace(
jacobian=J,
jacobian_derivative=,
joint_transforms=i_X_λ,
motion_subspaces=S,
mass_matrix=M,
)

# TODO: some contact models here may want to perform a dynamic filtering of
# the enabled collidable points.

Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/mujoco/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def mujoco_data_from_jaxsim(
if jaxsim_model.floating_base():

# Set the model position.
model_helper.set_base_position(position=np.array(jaxsim_data.base_position()))
model_helper.set_base_position(position=np.array(jaxsim_data.base_position))

# Set the model orientation.
model_helper.set_base_orientation(
Expand Down
9 changes: 6 additions & 3 deletions src/jaxsim/rbda/aba.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def aba(
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
standard_gravity: jtp.FloatLike = StandardGravity,
joint_transforms,
motion_subspaces,
) -> tuple[jtp.Vector, jtp.Vector]:
"""
Compute forward dynamics using the Articulated Body Algorithm (ABA).
Expand Down Expand Up @@ -88,9 +90,10 @@ def aba(
# Compute the parent-to-child adjoints and the motion subspaces of the joints.
# These transforms define the relative kinematics of the entire model, including
# the base transform for both floating-base and fixed-base models.
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
joint_positions=s, base_transform=W_H_B.as_matrix()
)
# i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
# joint_positions=s, base_transform=W_H_B.as_matrix()
# )
i_X_λi, S = joint_transforms, motion_subspaces

# Allocate buffers.
v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
Expand Down
22 changes: 13 additions & 9 deletions src/jaxsim/rbda/collidable_points.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import jax
import jax.numpy as jnp
import jaxlie

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Skew

from . import utils

# import jaxlie


def collidable_points_pos_vel(
model: js.model.JaxSimModel,
Expand All @@ -18,6 +19,8 @@ def collidable_points_pos_vel(
base_linear_velocity: jtp.Vector,
base_angular_velocity: jtp.Vector,
joint_velocities: jtp.Vector,
joint_transforms,
motion_subspaces,
) -> tuple[jtp.Matrix, jtp.Matrix]:
"""
Expand Down Expand Up @@ -54,7 +57,7 @@ def collidable_points_pos_vel(
if len(indices_of_enabled_collidable_points) == 0:
return jnp.array(0).astype(float), jnp.empty(0).astype(float)

W_p_B, W_Q_B, s, W_v_WB, , _, _, _, _, _ = utils.process_inputs(
_, _, _, W_v_WB, , _, _, _, _, _ = utils.process_inputs(
model=model,
base_position=base_position,
base_quaternion=base_quaternion,
Expand All @@ -69,17 +72,18 @@ def collidable_points_pos_vel(
λ = model.kin_dyn_parameters.parent_array

# Compute the base transform.
W_H_B = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3(wxyz=W_Q_B),
translation=W_p_B,
)
# W_H_B = jaxlie.SE3.from_rotation_and_translation(
# rotation=jaxlie.SO3(wxyz=W_Q_B),
# translation=W_p_B,
# )

# Compute the parent-to-child adjoints and the motion subspaces of the joints.
# These transforms define the relative kinematics of the entire model, including
# the base transform for both floating-base and fixed-base models.
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
joint_positions=s, base_transform=W_H_B.as_matrix()
)
# i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
# joint_positions=s, base_transform=W_H_B.as_matrix()
# )
i_X_λi, S = joint_transforms, motion_subspaces

# Allocate buffer of transforms world -> link and initialize the base pose.
W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def compute_contact_forces(
)
)

M = js.model.free_floating_mass_matrix(model=model, data=data)
M = data.mass_matrix

Jl_WC = jnp.vstack(
jax.vmap(lambda J, δ: J * (δ > 0))(
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/rbda/contacts/visco_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ def integrate_data_with_average_contact_forces(
"""

s_t0 = data.joint_positions()
W_p_B_t0 = data.base_position()
W_p_B_t0 = data.base_position
W_Q_B_t0 = data.base_orientation(dcm=False)

ṡ_t0 = data.joint_velocities()
Expand Down
23 changes: 15 additions & 8 deletions src/jaxsim/rbda/crba.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@
import jaxsim.api as js
import jaxsim.typing as jtp

from . import utils
# from . import utils


def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Matrix:
def crba(
model: js.model.JaxSimModel,
*,
joint_positions: jtp.Vector,
joint_transforms,
motion_subspaces,
) -> jtp.Matrix:
"""
Compute the free-floating mass matrix using the Composite Rigid-Body Algorithm (CRBA).
Expand All @@ -19,9 +25,9 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
The free-floating mass matrix of the model in body-fixed representation.
"""

_, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
model=model, joint_positions=joint_positions
)
# _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
# model=model, joint_positions=joint_positions
# )

# Get the 6D spatial inertia matrices of all links.
Mc = js.model.link_spatial_inertia_matrices(model=model)
Expand All @@ -33,9 +39,10 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
# Compute the parent-to-child adjoints and the motion subspaces of the joints.
# These transforms define the relative kinematics of the entire model, including
# the base transform for both floating-base and fixed-base models.
i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
joint_positions=s, base_transform=jnp.eye(4)
)
# i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
# joint_positions=s, base_transform=W_H_B.as_matrix()
# )
i_X_λi, S = joint_transforms, motion_subspaces

# Allocate the buffer of transforms link -> base.
i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
Expand Down
Loading

0 comments on commit 66173de

Please sign in to comment.