diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index 11b9b2d4a..3b5f9aced 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -26,7 +26,7 @@ def com_position( m = js.model.total_mass(model=model) - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics W_H_B = data.base_transform() B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B) @@ -269,7 +269,7 @@ def bias_acceleration( """ # Compute the pose of all links with forward kinematics. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics # Compute the bias acceleration of all links by zeroing the generalized velocity # in the active representation. diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 31d2245d4..ab5e308d5 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -42,12 +42,14 @@ def collidable_point_kinematics( W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( model=model, - base_position=data.base_position(), + base_position=data.base_position, base_quaternion=data.base_orientation(dcm=False), joint_positions=data.joint_positions(model=model), base_linear_velocity=data.base_velocity()[0:3], base_angular_velocity=data.base_velocity()[3:6], joint_velocities=data.joint_velocities(model=model), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) return W_p_Ci, W_ṗ_Ci @@ -460,7 +462,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt )[indices_of_enabled_collidable_points] # Get the transforms of the parent link of all collidable points. - W_H_L = js.model.forward_kinematics(model=model, data=data)[ + W_H_L = data.kyn_dyn.forward_kinematics[ parent_link_idx_of_enabled_collidable_points ] @@ -518,9 +520,7 @@ def jacobian( )[indices_of_enabled_collidable_points] # Compute the Jacobians of all links. - W_J_WL = js.model.generalized_free_floating_jacobian( - model=model, data=data, output_vel_repr=VelRepr.Inertial - ) + W_J_WL = data.kyn_dyn.jacobian # Compute the contact Jacobian. # In inertial-fixed output representation, the Jacobian of the parent link is also @@ -612,7 +612,7 @@ def jacobian_derivative( ] # Get the transforms of all the parent links. - W_H_Li = js.model.forward_kinematics(model=model, data=data) + W_H_Li = data.kyn_dyn.forward_kinematics # ===================================================== # Compute quantities to adjust the input representation @@ -670,17 +670,9 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: with data.switch_velocity_representation(VelRepr.Inertial): # Compute the Jacobian of the parent link in inertial representation. - W_J_WL_W = js.model.generalized_free_floating_jacobian( - model=model, - data=data, - output_vel_repr=VelRepr.Inertial, - ) + W_J_WL_W = data.kyn_dyn.jacobian # Compute the Jacobian derivative of the parent link in inertial representation. - W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative( - model=model, - data=data, - output_vel_repr=VelRepr.Inertial, - ) + W_J̇_WL_W = data.kyn_dyn.jacobian_derivative # Get the Jacobian of the enabled collidable points in the mixed representation. with data.switch_velocity_representation(VelRepr.Mixed): diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index af3ea7ef2..3e9f662ae 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -26,6 +26,22 @@ from typing_extensions import Self +@jax_dataclasses.pytree_dataclass +class KynDynComputation: + + jacobian: jtp.Matrix + + jacobian_derivative: jtp.Matrix + + motion_subspaces: jtp.Matrix + + joint_transforms: jtp.Matrix + + mass_matrix: jtp.Matrix + + forward_kinematics: jtp.Matrix + + @jax_dataclasses.pytree_dataclass class JaxSimModelData(common.ModelDataWithVelocityRepresentation): """ @@ -34,6 +50,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): state: ODEState + kyn_dyn: KynDynComputation + gravity: jtp.Vector contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False) @@ -232,11 +250,33 @@ def build( else: contacts_params = model.contact_model._parameters_class() + n = model.dofs() + n_fb = n + 6 * model.floating_base() + + jacobian = jnp.zeros((model.number_of_links(), 6, n_fb)) + jacobian_derivative = jnp.zeros((model.number_of_links(), 6, n_fb)) + motion_subspaces = jnp.zeros((model.number_of_links(), 6, 1)) + joint_transforms = jnp.zeros((model.number_of_links(), 6, 6)) + mass_matrix = jnp.zeros((n_fb, n_fb)) + forward_kinematics = jnp.zeros((model.number_of_links(), 4, 4)) + + kyn_dyn = KynDynComputation( + jacobian=jacobian, + jacobian_derivative=jacobian_derivative, + motion_subspaces=motion_subspaces, + joint_transforms=joint_transforms, + mass_matrix=mass_matrix, + forward_kinematics=forward_kinematics, + ) + + print(jacobian.shape) + return JaxSimModelData( state=ode_state, gravity=gravity, contacts_params=contacts_params, velocity_representation=velocity_representation, + kyn_dyn=kyn_dyn, ) # ================== @@ -349,8 +389,7 @@ def joint_velocities( return self.state.physics_model.joint_velocities[joint_idxs] - @js.common.named_scope - @jax.jit + @property def base_position(self) -> jtp.Vector: """ Get the base position. @@ -359,7 +398,7 @@ def base_position(self) -> jtp.Vector: The base position. """ - return self.state.physics_model.base_position.squeeze() + return self.state.physics_model.base_position @js.common.named_scope @functools.partial(jax.jit, static_argnames=["dcm"]) @@ -400,7 +439,7 @@ def base_transform(self) -> jtp.Matrix: """ W_R_B = self.base_orientation(dcm=True) - W_p_B = jnp.vstack(self.base_position()) + W_p_B = jnp.vstack(self.base_position) return jnp.vstack( [ diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 23b4d3732..50aafab24 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -187,7 +187,7 @@ def transform( idx=link_index, ) - return js.model.forward_kinematics(model=model, data=data)[link_index] + return data.kyn_dyn.forward_kinematics[link_index] @jax.jit @@ -276,6 +276,8 @@ def jacobian( B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions(), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) # Compute the actual doubly-left free-floating jacobian of the link. @@ -422,9 +424,7 @@ def jacobian_derivative( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) - O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative( - model=model, data=data, output_vel_repr=output_vel_repr - )[link_index] + O_J̇_WL_I = data.kyn_dyn.jacobian_derivative[link_index] return O_J̇_WL_I diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 27e7cd616..68deb1902 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -271,9 +271,9 @@ def build( integrator_cls = integrator integrator = integrator_cls.build( - dynamics=js.ode.wrap_system_dynamics_for_integration( - system_dynamics=js.ode.system_dynamics - ) + # dynamics=js.ode.wrap_system_dynamics_for_integration( + # system_dynamics=js.ode.system_dynamics + # ) ) case _: @@ -574,9 +574,10 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp W_H_LL = jaxsim.rbda.forward_kinematics_model( model=model, - base_position=data.base_position(), + base_position=data.base_position, base_quaternion=data.base_orientation(dcm=False), joint_positions=data.joint_positions(model=model), + joint_transforms=data.kyn_dyn.joint_transforms, ) return jnp.atleast_3d(W_H_LL).astype(float) @@ -616,6 +617,8 @@ def generalized_free_floating_jacobian( B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions(), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) # ====================================================================== @@ -743,6 +746,8 @@ def generalized_free_floating_jacobian_derivative( model=model, joint_positions=data.joint_positions(), joint_velocities=data.joint_velocities(), + # joint_transforms=data.kyn_dyn.joint_transforms, + # motion_subspaces=data.kyn_dyn.motion_subspaces, ) # The derivative of the equation to change the input and output representations @@ -750,6 +755,8 @@ def generalized_free_floating_jacobian_derivative( B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions(), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) # Compute the actual doubly-left free-floating jacobian derivative of the link @@ -1005,7 +1012,7 @@ def forward_dynamics_aba( # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): - W_p_B = data.base_position() + W_p_B = data.base_position W_v_WB = data.base_velocity() W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model, joint_names=joint_names) @@ -1031,6 +1038,8 @@ def forward_dynamics_aba( joint_forces=τ, link_forces=W_f_L, standard_gravity=data.standard_gravity(), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) # ============= @@ -1201,6 +1210,8 @@ def free_floating_mass_matrix( M_body = jaxsim.rbda.crba( model=model, joint_positions=data.state.physics_model.joint_positions, + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) match data.velocity_representation: @@ -1457,7 +1468,7 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): - W_p_B = data.base_position() + W_p_B = data.base_position W_v_WB = data.base_velocity() W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model, joint_names=joint_names) @@ -1484,6 +1495,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): joint_accelerations=s̈, link_forces=W_f_L, standard_gravity=data.standard_gravity(), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) # ============= @@ -1792,7 +1805,7 @@ def average_velocity_jacobian( case VelRepr.Body: GB_J = G_J - W_p_B = data.base_position() + W_p_B = data.base_position W_p_CoM = js.com.com_position(model=model, data=data) B_R_W = data.base_orientation(dcm=True).transpose() @@ -1804,7 +1817,7 @@ def average_velocity_jacobian( case VelRepr.Mixed: GW_J = G_J - W_p_B = data.base_position() + W_p_B = data.base_position W_p_CoM = js.com.com_position(model=model, data=data) BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B) @@ -2006,11 +2019,11 @@ def body_to_other_representation( ) case VelRepr.Inertial: - C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data) + C_H_L = W_H_L = data.kyn_dyn.forward_kinematics L_v_CL = L_v_WL case VelRepr.Mixed: - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) C_H_L = LW_H_L L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841 @@ -2182,6 +2195,25 @@ def forward( joint_force_references: jtp.VectorLike | None = None, ) -> js.data.JaxSimModelData: + # Kinematics computation. + M = js.model.free_floating_mass_matrix(model=model, data=data) + J = js.model.generalized_free_floating_jacobian(model=model, data=data) + J̇ = js.model.generalized_free_floating_jacobian_derivative(model=model, data=data) + i_X_λ, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + joint_positions=data.joint_positions(), base_transform=data.base_transform() + ) + FK = js.model.forward_kinematics(model=model, data=data) + kyn_dyn = js.data.KynDynComputation( + jacobian=J, + jacobian_derivative=J̇, + joint_transforms=i_X_λ, + motion_subspaces=S, + mass_matrix=M, + forward_kinematics=FK, + ) + + data = data.replace(kyn_dyn=kyn_dyn) + # TODO: some contact models here may want to perform a dynamic filtering of # the enabled collidable points. diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 1c0a11078..bdfa4214f 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -242,7 +242,7 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: )(W_f_L, W_H_L) # The f_L output is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) return f_L @@ -450,7 +450,7 @@ def convert_using_link_frame( )(f_L, W_H_L) # The f_L input is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L)) diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index 912858fd9..482b81306 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -342,11 +342,11 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]: ti = t0 + c[i] * Δt # Evaluate the dynamics. - ki, aux_dict = f(x=xi, t=ti) - return ki, aux_dict + ki = f(x=xi, t=ti) + return ki # This selector enables FSAL property in the first iteration (i=0). - ki, aux_dict = jax.lax.cond( + ki = jax.lax.cond( pred=jnp.logical_and(i == 0, self.has_fsal), true_fun=lambda: x0, false_fun=compute_ki, @@ -357,7 +357,7 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]: K = jax.tree.map(op, K, ki) carry = K - return carry, aux_dict + return carry, None # Compute the state derivatives kᵢ. K, _ = jax.lax.scan( diff --git a/src/jaxsim/mujoco/utils.py b/src/jaxsim/mujoco/utils.py index 2afff1732..c80bfa79e 100644 --- a/src/jaxsim/mujoco/utils.py +++ b/src/jaxsim/mujoco/utils.py @@ -59,7 +59,7 @@ def mujoco_data_from_jaxsim( if jaxsim_model.floating_base(): # Set the model position. - model_helper.set_base_position(position=np.array(jaxsim_data.base_position())) + model_helper.set_base_position(position=np.array(jaxsim_data.base_position)) # Set the model orientation. model_helper.set_base_orientation( diff --git a/src/jaxsim/rbda/aba.py b/src/jaxsim/rbda/aba.py index b01f46698..d7b596e6e 100644 --- a/src/jaxsim/rbda/aba.py +++ b/src/jaxsim/rbda/aba.py @@ -21,6 +21,8 @@ def aba( joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, standard_gravity: jtp.FloatLike = StandardGravity, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute forward dynamics using the Articulated Body Algorithm (ABA). @@ -88,9 +90,10 @@ def aba( # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffers. v = jnp.zeros(shape=(model.number_of_links(), 6, 1)) diff --git a/src/jaxsim/rbda/collidable_points.py b/src/jaxsim/rbda/collidable_points.py index 543be5328..f85bfc49a 100644 --- a/src/jaxsim/rbda/collidable_points.py +++ b/src/jaxsim/rbda/collidable_points.py @@ -1,6 +1,5 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js import jaxsim.typing as jtp @@ -8,6 +7,8 @@ from . import utils +# import jaxlie + def collidable_points_pos_vel( model: js.model.JaxSimModel, @@ -18,6 +19,8 @@ def collidable_points_pos_vel( base_linear_velocity: jtp.Vector, base_angular_velocity: jtp.Vector, joint_velocities: jtp.Vector, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Matrix, jtp.Matrix]: """ @@ -54,7 +57,7 @@ def collidable_points_pos_vel( if len(indices_of_enabled_collidable_points) == 0: return jnp.array(0).astype(float), jnp.empty(0).astype(float) - W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs( + _, _, _, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs( model=model, base_position=base_position, base_quaternion=base_quaternion, @@ -69,17 +72,18 @@ def collidable_points_pos_vel( λ = model.kin_dyn_parameters.parent_array # Compute the base transform. - W_H_B = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3(wxyz=W_Q_B), - translation=W_p_B, - ) + # W_H_B = jaxlie.SE3.from_rotation_and_translation( + # rotation=jaxlie.SO3(wxyz=W_Q_B), + # translation=W_p_B, + # ) # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffer of transforms world -> link and initialize the base pose. W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 56b403fa7..4641fbd46 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -251,7 +251,7 @@ def link_forces_from_contact_forces( # Compute the link transforms. W_H_L = ( - js.model.forward_kinematics(model=model, data=data) + data.kyn_dyn.forward_kinematics if data.velocity_representation is not jaxsim.VelRepr.Inertial else jnp.zeros(shape=(model.number_of_links(), 4, 4)) ) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index fd84c6941..642a198e2 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -320,7 +320,7 @@ def compute_contact_forces( ) ) - M = js.model.free_floating_mass_matrix(model=model, data=data) + M = data.kyn_dyn.mass_matrix Jl_WC = jnp.vstack( jax.vmap(lambda J, δ: J * (δ > 0))( diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index c433fe23d..11101cb0c 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -827,7 +827,7 @@ def integrate_data_with_average_contact_forces( """ s_t0 = data.joint_positions() - W_p_B_t0 = data.base_position() + W_p_B_t0 = data.base_position W_Q_B_t0 = data.base_orientation(dcm=False) ṡ_t0 = data.joint_velocities() @@ -1005,7 +1005,7 @@ def step( # Compute the link transforms. W_H_L = ( - js.model.forward_kinematics(model=model, data=data) + data.kyn_dyn.forward_kinematics if data.velocity_representation is not jaxsim.VelRepr.Inertial else jnp.zeros(shape=(model.number_of_links(), 4, 4)) ) diff --git a/src/jaxsim/rbda/crba.py b/src/jaxsim/rbda/crba.py index adb3506ae..20b77adf3 100644 --- a/src/jaxsim/rbda/crba.py +++ b/src/jaxsim/rbda/crba.py @@ -4,10 +4,16 @@ import jaxsim.api as js import jaxsim.typing as jtp -from . import utils +# from . import utils -def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Matrix: +def crba( + model: js.model.JaxSimModel, + *, + joint_positions: jtp.Vector, + joint_transforms, + motion_subspaces, +) -> jtp.Matrix: """ Compute the free-floating mass matrix using the Composite Rigid-Body Algorithm (CRBA). @@ -19,9 +25,9 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat The free-floating mass matrix of the model in body-fixed representation. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) + # _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( + # model=model, joint_positions=joint_positions + # ) # Get the 6D spatial inertia matrices of all links. Mc = js.model.link_spatial_inertia_matrices(model=model) @@ -33,9 +39,10 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms link -> base. i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index d11e9b45d..b3bed3460 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -1,12 +1,14 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.math import Adjoint -from . import utils +# import jaxlie + + +# from . import utils def forward_kinematics_model( @@ -15,6 +17,7 @@ def forward_kinematics_model( base_position: jtp.VectorLike, base_quaternion: jtp.VectorLike, joint_positions: jtp.VectorLike, + joint_transforms, ) -> jtp.Array: """ Compute the forward kinematics. @@ -29,29 +32,30 @@ def forward_kinematics_model( A 3D array containing the SE(3) transforms of all links belonging to the model. """ - W_p_B, W_Q_B, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, - base_position=base_position, - base_quaternion=base_quaternion, - joint_positions=joint_positions, - ) + # W_p_B, W_Q_B, s, _, _, _, _, _, _, _ = utils.process_inputs( + # model=model, + # base_position=base_position, + # base_quaternion=base_quaternion, + # joint_positions=joint_positions, + # ) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array # Compute the base transform. - W_H_B = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3(wxyz=W_Q_B), - translation=W_p_B, - ) + # W_H_B = jaxlie.SE3.from_rotation_and_translation( + # rotation=jaxlie.SO3(wxyz=W_Q_B), + # translation=W_p_B, + # ) # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + # i_X_λi, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi = joint_transforms # Allocate the buffer of transforms world -> link and initialize the base pose. W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/jacobian.py b/src/jaxsim/rbda/jacobian.py index e8f44d088..e79970019 100644 --- a/src/jaxsim/rbda/jacobian.py +++ b/src/jaxsim/rbda/jacobian.py @@ -14,6 +14,8 @@ def jacobian( *, link_index: jtp.Int, joint_positions: jtp.VectorLike, + joint_transforms, + motion_subspaces, ) -> jtp.Matrix: """ Compute the free-floating Jacobian of a link. @@ -27,9 +29,9 @@ def jacobian( The free-floating left-trivialized Jacobian of the link :math:`{}^L J_{W,L/B}`. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) + # _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( + # model=model, joint_positions=joint_positions + # ) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. @@ -38,9 +40,10 @@ def jacobian( # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms link -> base. i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) @@ -127,6 +130,8 @@ def jacobian_full_doubly_left( model: js.model.JaxSimModel, *, joint_positions: jtp.VectorLike, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Matrix, jtp.Array]: r""" Compute the doubly-left full free-floating Jacobian of a model. @@ -144,9 +149,9 @@ def jacobian_full_doubly_left( The doubly-left full free-floating Jacobian of a model. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) + # _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( + # model=model, joint_positions=joint_positions + # ) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. @@ -155,9 +160,10 @@ def jacobian_full_doubly_left( # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms base -> link. B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/rnea.py b/src/jaxsim/rbda/rnea.py index 025d85a62..3b7ba5975 100644 --- a/src/jaxsim/rbda/rnea.py +++ b/src/jaxsim/rbda/rnea.py @@ -23,6 +23,8 @@ def rnea( joint_accelerations: jtp.Vector | None = None, link_forces: jtp.Matrix | None = None, standard_gravity: jtp.FloatLike = StandardGravity, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA). @@ -91,9 +93,10 @@ def rnea( # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + # i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + # joint_positions=s, base_transform=W_H_B.as_matrix() + # ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffers. v = jnp.zeros(shape=(model.number_of_links(), 6, 1)) diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index 4a0882737..63666431d 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -132,7 +132,7 @@ def test_contact_jacobian_derivative( # Rebuild the JaxSim data. data_with_frames = js.data.JaxSimModelData.build( model=model_with_frames, - base_position=data.base_position(), + base_position=data.base_position, base_quaternion=data.base_orientation(dcm=False), joint_positions=data.joint_positions(), base_linear_velocity=data.base_velocity()[0:3], diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index d20dbc5ff..3be031cbc 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -249,7 +249,7 @@ def J(q, frame_idxs) -> jax.Array: def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( [ - data.base_position(), + data.base_position, data.base_orientation(), data.joint_positions(), ] diff --git a/tests/test_api_link.py b/tests/test_api_link.py index 7f89e0cc5..150ef574f 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -341,7 +341,7 @@ def J(q) -> jax.Array: def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( - [data.base_position(), data.base_orientation(), data.joint_positions()] + [data.base_position, data.base_orientation(), data.joint_positions()] ) return q diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 41f19f67f..20912ff62 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -438,7 +438,7 @@ def M(q) -> jax.Array: def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( - [data.base_position(), data.base_orientation(), data.joint_positions()] + [data.base_position, data.base_orientation(), data.joint_positions()] ) return q diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 477f6245d..8c8e9fc9b 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -75,7 +75,7 @@ def test_ad_aba( g = jaxsim.math.StandardGravity # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) W_v_WB = data.base_velocity() @@ -129,7 +129,7 @@ def test_ad_rnea( g = jaxsim.math.StandardGravity # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) W_v_WB = data.base_velocity() @@ -217,7 +217,7 @@ def test_ad_fk( ) # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) @@ -344,7 +344,7 @@ def test_ad_integration( ) # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) W_v_WB = data.base_velocity() diff --git a/tests/test_pytree.py b/tests/test_pytree.py index c27179254..c3bf82576 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -51,7 +51,7 @@ def my_jit_function(model: js.model.JaxSimModel, data: js.data.JaxSimModelData): # Return random elements from model and data, just to have something returned. return ( jnp.sum(model.kin_dyn_parameters.link_parameters.mass), - data.base_position(), + data.base_position, ) data1 = js.data.JaxSimModelData.build(model=model1) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index ae3a4cfc3..230ff11e4 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -78,7 +78,7 @@ def test_box_with_external_forces( ) # Check that the box didn't move. - assert data.base_position() == pytest.approx(data0.base_position()) + assert data.base_position == pytest.approx(data0.base_position()) assert data.base_orientation() == pytest.approx(data0.base_orientation()) @@ -158,7 +158,7 @@ def test_box_with_zero_gravity( ) # Check that the box moved as expected. - assert data.base_position() == pytest.approx( + assert data.base_position == pytest.approx( data0.base_position() + 0.5 * LW_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2, abs=1e-3, @@ -201,19 +201,8 @@ def run_simulation( return data -@pytest.mark.parametrize( - "integrator", - [ - jaxsim.integrators.fixed_step.ForwardEuler, - jaxsim.integrators.fixed_step.ForwardEulerSO3, - jaxsim.integrators.fixed_step.RungeKutta4, - jaxsim.integrators.fixed_step.RungeKutta4SO3, - jaxsim.integrators.variable_step.BogackiShampineSO3, - ], -) def test_simulation_with_soft_contacts( jaxsim_model_box: js.model.JaxSimModel, - integrator: jaxsim.integrators.Integrator, ): model = jaxsim_model_box @@ -229,7 +218,6 @@ def test_simulation_with_soft_contacts( model.kin_dyn_parameters.contact_parameters.enabled = tuple( enabled_collidable_points_mask.tolist() ) - model.integrator = integrator.build() assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4