From 66173de3cc64ff7432015465039034f0d954c27c Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 16 Dec 2024 15:11:03 +0100 Subject: [PATCH] [WIP] Save some kindyn computation in `JaxSimModelData` --- src/jaxsim/api/contact.py | 4 ++- src/jaxsim/api/data.py | 34 +++++++++++++++++--- src/jaxsim/api/link.py | 2 ++ src/jaxsim/api/model.py | 39 ++++++++++++++++++++--- src/jaxsim/mujoco/utils.py | 2 +- src/jaxsim/rbda/aba.py | 9 ++++-- src/jaxsim/rbda/collidable_points.py | 22 +++++++------ src/jaxsim/rbda/contacts/relaxed_rigid.py | 2 +- src/jaxsim/rbda/contacts/visco_elastic.py | 2 +- src/jaxsim/rbda/crba.py | 23 ++++++++----- src/jaxsim/rbda/forward_kinematics.py | 34 +++++++++++--------- src/jaxsim/rbda/jacobian.py | 30 ++++++++++------- src/jaxsim/rbda/rnea.py | 9 ++++-- 13 files changed, 149 insertions(+), 63 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 31d2245d4..bdb08a911 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -42,12 +42,14 @@ def collidable_point_kinematics( W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( model=model, - base_position=data.base_position(), + base_position=data.base_position, base_quaternion=data.base_orientation(dcm=False), joint_positions=data.joint_positions(model=model), base_linear_velocity=data.base_velocity()[0:3], base_angular_velocity=data.base_velocity()[3:6], joint_velocities=data.joint_velocities(model=model), + joint_transforms=data.joint_transforms, + motion_subspaces=data.motion_subspaces, ) return W_p_Ci, W_ṗ_Ci diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index af3ea7ef2..ea504fb07 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -34,6 +34,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): state: ODEState + jacobian: jtp.Matrix + + jacobian_derivative: jtp.Matrix + + motion_subspaces: jtp.Matrix + + joint_transforms: jtp.Matrix + + mass_matrix: jtp.Matrix + gravity: jtp.Vector contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False) @@ -232,11 +242,28 @@ def build( else: contacts_params = model.contact_model._parameters_class() + n = model.dofs() + n_fb = n + 6 * model.floating_base() + + jacobian = jnp.zeros((model.number_of_links(), 6, n_fb)) + jacobian_derivative = jnp.zeros((model.number_of_links(), 6, n_fb)) + motion_subspaces = jnp.zeros((model.number_of_links(), 6, 1)) + joint_transforms = jnp.zeros((model.number_of_links(), 6, 6)) + mass_matrix = jnp.zeros((n_fb, n_fb)) + + print(jacobian.shape) + return JaxSimModelData( state=ode_state, gravity=gravity, contacts_params=contacts_params, velocity_representation=velocity_representation, + # + jacobian=jacobian, + jacobian_derivative=jacobian_derivative, + motion_subspaces=motion_subspaces, + joint_transforms=joint_transforms, + mass_matrix=mass_matrix, ) # ================== @@ -349,8 +376,7 @@ def joint_velocities( return self.state.physics_model.joint_velocities[joint_idxs] - @js.common.named_scope - @jax.jit + @property def base_position(self) -> jtp.Vector: """ Get the base position. @@ -359,7 +385,7 @@ def base_position(self) -> jtp.Vector: The base position. """ - return self.state.physics_model.base_position.squeeze() + return self.state.physics_model.base_position @js.common.named_scope @functools.partial(jax.jit, static_argnames=["dcm"]) @@ -400,7 +426,7 @@ def base_transform(self) -> jtp.Matrix: """ W_R_B = self.base_orientation(dcm=True) - W_p_B = jnp.vstack(self.base_position()) + W_p_B = jnp.vstack(self.base_position) return jnp.vstack( [ diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 23b4d3732..118ac7150 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -276,6 +276,8 @@ def jacobian( B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions(), + joint_transforms=data.joint_transforms, + motion_subspaces=data.motion_subspaces, ) # Compute the actual doubly-left free-floating jacobian of the link. diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 27e7cd616..3dceb39b0 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -574,9 +574,10 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp W_H_LL = jaxsim.rbda.forward_kinematics_model( model=model, - base_position=data.base_position(), + base_position=data.base_position, base_quaternion=data.base_orientation(dcm=False), joint_positions=data.joint_positions(model=model), + joint_transforms=data.joint_transforms, ) return jnp.atleast_3d(W_H_LL).astype(float) @@ -616,6 +617,8 @@ def generalized_free_floating_jacobian( B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions(), + joint_transforms=data.joint_transforms, + motion_subspaces=data.motion_subspaces, ) # ====================================================================== @@ -743,6 +746,8 @@ def generalized_free_floating_jacobian_derivative( model=model, joint_positions=data.joint_positions(), joint_velocities=data.joint_velocities(), + # joint_transforms=data.joint_transforms, + # motion_subspaces=data.motion_subspaces, ) # The derivative of the equation to change the input and output representations @@ -750,6 +755,8 @@ def generalized_free_floating_jacobian_derivative( B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions(), + joint_transforms=data.joint_transforms, + motion_subspaces=data.motion_subspaces, ) # Compute the actual doubly-left free-floating jacobian derivative of the link @@ -1005,7 +1012,7 @@ def forward_dynamics_aba( # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): - W_p_B = data.base_position() + W_p_B = data.base_position W_v_WB = data.base_velocity() W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model, joint_names=joint_names) @@ -1031,6 +1038,8 @@ def forward_dynamics_aba( joint_forces=τ, link_forces=W_f_L, standard_gravity=data.standard_gravity(), + joint_transforms=data.joint_transforms, + motion_subspaces=data.motion_subspaces, ) # ============= @@ -1201,6 +1210,8 @@ def free_floating_mass_matrix( M_body = jaxsim.rbda.crba( model=model, joint_positions=data.state.physics_model.joint_positions, + joint_transforms=data.joint_transforms, + motion_subspaces=data.motion_subspaces, ) match data.velocity_representation: @@ -1457,7 +1468,7 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): - W_p_B = data.base_position() + W_p_B = data.base_position W_v_WB = data.base_velocity() W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model, joint_names=joint_names) @@ -1484,6 +1495,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): joint_accelerations=s̈, link_forces=W_f_L, standard_gravity=data.standard_gravity(), + joint_transforms=data.joint_transforms, + motion_subspaces=data.motion_subspaces, ) # ============= @@ -1792,7 +1805,7 @@ def average_velocity_jacobian( case VelRepr.Body: GB_J = G_J - W_p_B = data.base_position() + W_p_B = data.base_position W_p_CoM = js.com.com_position(model=model, data=data) B_R_W = data.base_orientation(dcm=True).transpose() @@ -1804,7 +1817,7 @@ def average_velocity_jacobian( case VelRepr.Mixed: GW_J = G_J - W_p_B = data.base_position() + W_p_B = data.base_position W_p_CoM = js.com.com_position(model=model, data=data) BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B) @@ -2182,6 +2195,22 @@ def forward( joint_force_references: jtp.VectorLike | None = None, ) -> js.data.JaxSimModelData: + # Kinematics computation. + M = js.model.free_floating_mass_matrix(model=model, data=data) + J = js.model.generalized_free_floating_jacobian(model=model, data=data) + J̇ = js.model.generalized_free_floating_jacobian_derivative(model=model, data=data) + i_X_λ, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + joint_positions=data.joint_positions(), base_transform=data.base_transform() + ) + + data = data.replace( + jacobian=J, + jacobian_derivative=J̇, + joint_transforms=i_X_λ, + motion_subspaces=S, + mass_matrix=M, + ) + # TODO: some contact models here may want to perform a dynamic filtering of # the enabled collidable points. diff --git a/src/jaxsim/mujoco/utils.py b/src/jaxsim/mujoco/utils.py index 2afff1732..c80bfa79e 100644 --- a/src/jaxsim/mujoco/utils.py +++ b/src/jaxsim/mujoco/utils.py @@ -59,7 +59,7 @@ def mujoco_data_from_jaxsim( if jaxsim_model.floating_base(): # Set the model position. - model_helper.set_base_position(position=np.array(jaxsim_data.base_position())) + model_helper.set_base_position(position=np.array(jaxsim_data.base_position)) # Set the model orientation. model_helper.set_base_orientation( diff --git a/src/jaxsim/rbda/aba.py b/src/jaxsim/rbda/aba.py index b01f46698..d7b596e6e 100644 --- a/src/jaxsim/rbda/aba.py +++ b/src/jaxsim/rbda/aba.py @@ -21,6 +21,8 @@ def aba( joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, standard_gravity: jtp.FloatLike = StandardGravity, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute forward dynamics using the Articulated Body Algorithm (ABA). @@ -88,9 +90,10 @@ def aba( # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffers. v = jnp.zeros(shape=(model.number_of_links(), 6, 1)) diff --git a/src/jaxsim/rbda/collidable_points.py b/src/jaxsim/rbda/collidable_points.py index 543be5328..f85bfc49a 100644 --- a/src/jaxsim/rbda/collidable_points.py +++ b/src/jaxsim/rbda/collidable_points.py @@ -1,6 +1,5 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js import jaxsim.typing as jtp @@ -8,6 +7,8 @@ from . import utils +# import jaxlie + def collidable_points_pos_vel( model: js.model.JaxSimModel, @@ -18,6 +19,8 @@ def collidable_points_pos_vel( base_linear_velocity: jtp.Vector, base_angular_velocity: jtp.Vector, joint_velocities: jtp.Vector, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Matrix, jtp.Matrix]: """ @@ -54,7 +57,7 @@ def collidable_points_pos_vel( if len(indices_of_enabled_collidable_points) == 0: return jnp.array(0).astype(float), jnp.empty(0).astype(float) - W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs( + _, _, _, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs( model=model, base_position=base_position, base_quaternion=base_quaternion, @@ -69,17 +72,18 @@ def collidable_points_pos_vel( λ = model.kin_dyn_parameters.parent_array # Compute the base transform. - W_H_B = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3(wxyz=W_Q_B), - translation=W_p_B, - ) + # W_H_B = jaxlie.SE3.from_rotation_and_translation( + # rotation=jaxlie.SO3(wxyz=W_Q_B), + # translation=W_p_B, + # ) # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffer of transforms world -> link and initialize the base pose. W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index fd84c6941..85de1404f 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -320,7 +320,7 @@ def compute_contact_forces( ) ) - M = js.model.free_floating_mass_matrix(model=model, data=data) + M = data.mass_matrix Jl_WC = jnp.vstack( jax.vmap(lambda J, δ: J * (δ > 0))( diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index c433fe23d..cd82661ae 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -827,7 +827,7 @@ def integrate_data_with_average_contact_forces( """ s_t0 = data.joint_positions() - W_p_B_t0 = data.base_position() + W_p_B_t0 = data.base_position W_Q_B_t0 = data.base_orientation(dcm=False) ṡ_t0 = data.joint_velocities() diff --git a/src/jaxsim/rbda/crba.py b/src/jaxsim/rbda/crba.py index adb3506ae..20b77adf3 100644 --- a/src/jaxsim/rbda/crba.py +++ b/src/jaxsim/rbda/crba.py @@ -4,10 +4,16 @@ import jaxsim.api as js import jaxsim.typing as jtp -from . import utils +# from . import utils -def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Matrix: +def crba( + model: js.model.JaxSimModel, + *, + joint_positions: jtp.Vector, + joint_transforms, + motion_subspaces, +) -> jtp.Matrix: """ Compute the free-floating mass matrix using the Composite Rigid-Body Algorithm (CRBA). @@ -19,9 +25,9 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat The free-floating mass matrix of the model in body-fixed representation. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) + # _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( + # model=model, joint_positions=joint_positions + # ) # Get the 6D spatial inertia matrices of all links. Mc = js.model.link_spatial_inertia_matrices(model=model) @@ -33,9 +39,10 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms link -> base. i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index d11e9b45d..b3bed3460 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -1,12 +1,14 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.math import Adjoint -from . import utils +# import jaxlie + + +# from . import utils def forward_kinematics_model( @@ -15,6 +17,7 @@ def forward_kinematics_model( base_position: jtp.VectorLike, base_quaternion: jtp.VectorLike, joint_positions: jtp.VectorLike, + joint_transforms, ) -> jtp.Array: """ Compute the forward kinematics. @@ -29,29 +32,30 @@ def forward_kinematics_model( A 3D array containing the SE(3) transforms of all links belonging to the model. """ - W_p_B, W_Q_B, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, - base_position=base_position, - base_quaternion=base_quaternion, - joint_positions=joint_positions, - ) + # W_p_B, W_Q_B, s, _, _, _, _, _, _, _ = utils.process_inputs( + # model=model, + # base_position=base_position, + # base_quaternion=base_quaternion, + # joint_positions=joint_positions, + # ) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array # Compute the base transform. - W_H_B = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3(wxyz=W_Q_B), - translation=W_p_B, - ) + # W_H_B = jaxlie.SE3.from_rotation_and_translation( + # rotation=jaxlie.SO3(wxyz=W_Q_B), + # translation=W_p_B, + # ) # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + # i_X_λi, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi = joint_transforms # Allocate the buffer of transforms world -> link and initialize the base pose. W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/jacobian.py b/src/jaxsim/rbda/jacobian.py index e8f44d088..e79970019 100644 --- a/src/jaxsim/rbda/jacobian.py +++ b/src/jaxsim/rbda/jacobian.py @@ -14,6 +14,8 @@ def jacobian( *, link_index: jtp.Int, joint_positions: jtp.VectorLike, + joint_transforms, + motion_subspaces, ) -> jtp.Matrix: """ Compute the free-floating Jacobian of a link. @@ -27,9 +29,9 @@ def jacobian( The free-floating left-trivialized Jacobian of the link :math:`{}^L J_{W,L/B}`. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) + # _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( + # model=model, joint_positions=joint_positions + # ) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. @@ -38,9 +40,10 @@ def jacobian( # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms link -> base. i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) @@ -127,6 +130,8 @@ def jacobian_full_doubly_left( model: js.model.JaxSimModel, *, joint_positions: jtp.VectorLike, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Matrix, jtp.Array]: r""" Compute the doubly-left full free-floating Jacobian of a model. @@ -144,9 +149,9 @@ def jacobian_full_doubly_left( The doubly-left full free-floating Jacobian of a model. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) + # _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( + # model=model, joint_positions=joint_positions + # ) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. @@ -155,9 +160,10 @@ def jacobian_full_doubly_left( # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms base -> link. B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/rnea.py b/src/jaxsim/rbda/rnea.py index 025d85a62..3b7ba5975 100644 --- a/src/jaxsim/rbda/rnea.py +++ b/src/jaxsim/rbda/rnea.py @@ -23,6 +23,8 @@ def rnea( joint_accelerations: jtp.Vector | None = None, link_forces: jtp.Matrix | None = None, standard_gravity: jtp.FloatLike = StandardGravity, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA). @@ -91,9 +93,10 @@ def rnea( # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffers. v = jnp.zeros(shape=(model.number_of_links(), 6, 1))