From 51443df77940c6a2178d83906dcf8d2c2436ea38 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 16 Dec 2024 15:11:03 +0100 Subject: [PATCH] [WIP] Save some kindyn computation in `JaxSimModelData` --- src/jaxsim/api/contact.py | 16 +++------------ src/jaxsim/api/data.py | 18 ++++++++++++++++ src/jaxsim/api/link.py | 4 +--- src/jaxsim/api/model.py | 25 ++++++++++++++++++++--- src/jaxsim/integrators/common.py | 8 ++++---- src/jaxsim/rbda/contacts/relaxed_rigid.py | 2 +- tests/test_simulations.py | 12 ----------- 7 files changed, 49 insertions(+), 36 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 11ec6dc76..eedd68299 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -511,9 +511,7 @@ def jacobian( )[indices_of_enabled_collidable_points] # Compute the Jacobians of all links. - W_J_WL = js.model.generalized_free_floating_jacobian( - model=model, data=data, output_vel_repr=VelRepr.Inertial - ) + W_J_WL = data.kyn_dyn.jacobian # Compute the contact Jacobian. # In inertial-fixed output representation, the Jacobian of the parent link is also @@ -663,17 +661,9 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: with data.switch_velocity_representation(VelRepr.Inertial): # Compute the Jacobian of the parent link in inertial representation. - W_J_WL_W = js.model.generalized_free_floating_jacobian( - model=model, - data=data, - output_vel_repr=VelRepr.Inertial, - ) + W_J_WL_W = data.kyn_dyn.jacobian # Compute the Jacobian derivative of the parent link in inertial representation. - W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative( - model=model, - data=data, - output_vel_repr=VelRepr.Inertial, - ) + W_J̇_WL_W = data.kyn_dyn.jacobian_derivative # Get the Jacobian of the enabled collidable points in the mixed representation. with data.switch_velocity_representation(VelRepr.Mixed): diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 81fe203c9..728b1bc4a 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -154,6 +154,22 @@ class KynDynComputation(common.ModelDataWithVelocityRepresentation): mass_matrix: jtp.Matrix +@jax_dataclasses.pytree_dataclass +class KynDynComputation: + + jacobian: jtp.Matrix + + jacobian_derivative: jtp.Matrix + + motion_subspaces: jtp.Matrix + + joint_transforms: jtp.Matrix + + mass_matrix: jtp.Matrix + + forward_kinematics: jtp.Matrix + + @jax_dataclasses.pytree_dataclass class JaxSimModelData(common.ModelDataWithVelocityRepresentation): """ @@ -162,6 +178,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): state: ODEState + kyn_dyn: KynDynComputation + gravity: jtp.Vector contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False) diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 2719da5d1..77050a91f 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -422,9 +422,7 @@ def jacobian_derivative( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) - O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative( - model=model, data=data, output_vel_repr=output_vel_repr - )[link_index] + O_J̇_WL_I = data.kyn_dyn.jacobian_derivative[link_index] return O_J̇_WL_I diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index bf8d1cdc1..cb489209c 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -271,9 +271,9 @@ def build( integrator_cls = integrator integrator = integrator_cls.build( - dynamics=js.ode.wrap_system_dynamics_for_integration( - system_dynamics=js.ode.system_dynamics - ) + # dynamics=js.ode.wrap_system_dynamics_for_integration( + # system_dynamics=js.ode.system_dynamics + # ) ) case _: @@ -2178,6 +2178,25 @@ def forward( joint_force_references: jtp.VectorLike | None = None, ) -> js.data.JaxSimModelData: + # Kinematics computation. + M = js.model.free_floating_mass_matrix(model=model, data=data) + J = js.model.generalized_free_floating_jacobian(model=model, data=data) + J̇ = js.model.generalized_free_floating_jacobian_derivative(model=model, data=data) + i_X_λ, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + joint_positions=data.joint_positions(), base_transform=data.base_transform() + ) + FK = js.model.forward_kinematics(model=model, data=data) + kyn_dyn = js.data.KynDynComputation( + jacobian=J, + jacobian_derivative=J̇, + joint_transforms=i_X_λ, + motion_subspaces=S, + mass_matrix=M, + forward_kinematics=FK, + ) + + data = data.replace(kyn_dyn=kyn_dyn) + # TODO: some contact models here may want to perform a dynamic filtering of # the enabled collidable points. diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index 912858fd9..482b81306 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -342,11 +342,11 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]: ti = t0 + c[i] * Δt # Evaluate the dynamics. - ki, aux_dict = f(x=xi, t=ti) - return ki, aux_dict + ki = f(x=xi, t=ti) + return ki # This selector enables FSAL property in the first iteration (i=0). - ki, aux_dict = jax.lax.cond( + ki = jax.lax.cond( pred=jnp.logical_and(i == 0, self.has_fsal), true_fun=lambda: x0, false_fun=compute_ki, @@ -357,7 +357,7 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]: K = jax.tree.map(op, K, ki) carry = K - return carry, aux_dict + return carry, None # Compute the state derivatives kᵢ. K, _ = jax.lax.scan( diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index fd84c6941..642a198e2 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -320,7 +320,7 @@ def compute_contact_forces( ) ) - M = js.model.free_floating_mass_matrix(model=model, data=data) + M = data.kyn_dyn.mass_matrix Jl_WC = jnp.vstack( jax.vmap(lambda J, δ: J * (δ > 0))( diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 12026c6a7..2b9aa9ede 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -201,19 +201,8 @@ def run_simulation( return data -@pytest.mark.parametrize( - "integrator", - [ - jaxsim.integrators.fixed_step.ForwardEuler, - jaxsim.integrators.fixed_step.ForwardEulerSO3, - jaxsim.integrators.fixed_step.RungeKutta4, - jaxsim.integrators.fixed_step.RungeKutta4SO3, - jaxsim.integrators.variable_step.BogackiShampineSO3, - ], -) def test_simulation_with_soft_contacts( jaxsim_model_box: js.model.JaxSimModel, - integrator: jaxsim.integrators.Integrator, ): model = jaxsim_model_box @@ -229,7 +218,6 @@ def test_simulation_with_soft_contacts( model.kin_dyn_parameters.contact_parameters.enabled = tuple( enabled_collidable_points_mask.tolist() ) - model.integrator = integrator.build() assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4