diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index f2122ced1..401706739 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -26,7 +26,7 @@ def com_position( m = js.model.total_mass(model=model) - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics W_H_B = data.base_transform() B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B) @@ -269,7 +269,7 @@ def bias_acceleration( """ # Compute the pose of all links with forward kinematics. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics # Compute the bias acceleration of all links by zeroing the generalized velocity # in the active representation. diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 7d723120e..163f421ef 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -232,3 +232,52 @@ def other_representation_to_inertial( case _: raise ValueError(other_representation) + + +def convert_mass_matrix( + M: jtp.Matrix, + base_transform: jtp.Matrix, + dofs: jtp.Int, + velocity_representation: VelRepr, +): + + # The mass matrix is always save in body-fixed representation. + + match velocity_representation: + case VelRepr.Body: + return M + + case VelRepr.Inertial: + + B_X_W = Adjoint.from_transform(transform=base_transform, inverse=True) + invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(dofs)) + + return invT.T @ M @ invT + + case VelRepr.Mixed: + + BW_H_B = base_transform.at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(dofs)) + + return invT.T @ M @ invT + + +def convert_jacobian( + J: jtp.Matrix, + base_transform: jtp.Matrix, + dofs: jtp.Int, + velocity_representation: VelRepr, +): + # TODO (flferretti): save actual Jacobian instead of full doubly left and perform conversion. + return J + + +def convert_jacobian_derivative( + Jd: jtp.Matrix, + base_transform: jtp.Matrix, + dofs: jtp.Int, + velocity_representation: VelRepr, +): + # TODO (flferretti): save actual Jacobian derivative instead of full doubly left and perform conversion. + return Jd diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 294413f7e..67699ff5d 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -42,12 +42,14 @@ def collidable_point_kinematics( W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( model=model, - base_position=data.base_position(), + base_position=data.base_position, base_quaternion=data.base_orientation(dcm=False), joint_positions=data.joint_positions(model=model), base_linear_velocity=data.base_velocity()[0:3], base_angular_velocity=data.base_velocity()[3:6], joint_velocities=data.joint_velocities(model=model), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) return W_p_Ci, W_ṗ_Ci @@ -460,7 +462,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt )[indices_of_enabled_collidable_points] # Get the transforms of the parent link of all collidable points. - W_H_L = js.model.forward_kinematics(model=model, data=data)[ + W_H_L = data.kyn_dyn.forward_kinematics[ parent_link_idx_of_enabled_collidable_points ] @@ -612,7 +614,7 @@ def jacobian_derivative( ] # Get the transforms of all the parent links. - W_H_Li = js.model.forward_kinematics(model=model, data=data) + W_H_Li = data.kyn_dyn.forward_kinematics # ===================================================== # Compute quantities to adjust the input representation diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index b880547b9..38e4db70b 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -13,11 +13,15 @@ import jaxsim.math import jaxsim.rbda import jaxsim.typing as jtp -from jaxsim.utils import Mutability from jaxsim.utils.tracing import not_tracing from . import common -from .common import VelRepr +from .common import ( + VelRepr, + convert_jacobian, + convert_jacobian_derivative, + convert_mass_matrix, +) from .ode_data import ODEState try: @@ -26,6 +30,103 @@ from typing_extensions import Self +class KynDynProxy: + """ + Proxy class for KynDynComputation that ensures attribute-specific + velocity representation consistency. + """ + + _data: JaxSimModelData + _kyn_dyn: KynDynComputation + + def __init__(self, data, kyn_dyn): + self._data = data + self._kyn_dyn = kyn_dyn + + def __convert_attribute(self, value, name): + + if name in ["motion_subspaces", "joint_transforms", "forward_kinematics"]: + return value + + match name: + + case "jacobian_full_doubly_left": + if ( + self._data.velocity_representation + != self._kyn_dyn.velocity_representation + ): + value = convert_jacobian( + value, + dofs=len(self._data.state.physics_model.joint_positions), + base_transform=self._kyn_dyn.velocity_representation, + velocity_representation=self._data.velocity_representation, + ) + + case "jacobian_derivative_full_doubly_left": + if ( + self._data.velocity_representation + != self._kyn_dyn.velocity_representation + ): + value = convert_jacobian_derivative( + value, + dofs=len(self._data.state.physics_model.joint_positions), + base_transform=self._kyn_dyn.velocity_representation, + velocity_representation=self._data.velocity_representation, + ) + + case "mass_matrix": + if ( + self._data.velocity_representation + != self._kyn_dyn.velocity_representation + ): + value = convert_mass_matrix( + value, + dofs=len(self._data.state.physics_model.joint_positions), + base_transform=self._kyn_dyn.velocity_representation, + velocity_representation=self._data.velocity_representation, + ) + + case _: + raise AttributeError( + f"'{type(self._kyn_dyn).__name__}' object has no attribute '{name}'" + ) + + return value + + def __getattr__(self, name: str): + + if name in ["_data", "_kyn_dyn"]: + return super().__getattribute__(name) + + value = getattr(self._kyn_dyn, name) + + return self.__convert_attribute(value=value, name=name) + + def __setattr__(self, name, value): + + if name in ["_data", "_kyn_dyn"]: + return super().__setattr__(name, value) + + value = self.__convert_attribute(value=value, name=name) + + self._kyn_dyn.replace(**{name: value}) + + +@jax_dataclasses.pytree_dataclass +class KynDynComputation(common.ModelDataWithVelocityRepresentation): + motion_subspaces: jtp.Matrix + + joint_transforms: jtp.Matrix + + forward_kinematics: jtp.Matrix + + jacobian_full_doubly_left: jtp.Matrix + + jacobian_derivative_full_doubly_left: jtp.Matrix + + mass_matrix: jtp.Matrix + + @jax_dataclasses.pytree_dataclass class JaxSimModelData(common.ModelDataWithVelocityRepresentation): """ @@ -34,10 +135,26 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): state: ODEState + _kyn_dyn: KynDynComputation + gravity: jtp.Vector contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False) + @property + def kyn_dyn(self): + + # Return proxy object that handles attribute-specific conversions. + return KynDynProxy(data=self, kyn_dyn=self._kyn_dyn) + + @kyn_dyn.setter + def kyn_dyn(self, new_kyn_dyn: KynDynComputation): + + if not isinstance(new_kyn_dyn, KynDynComputation): + raise ValueError("kyn_dyn must be an instance of KynDynComputation") + + self._kyn_dyn = new_kyn_dyn + def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray @@ -232,11 +349,51 @@ def build( else: contacts_params = model.contact_model._parameters_class() + base_orientation = jaxsim.math.Quaternion.to_dcm(base_quaternion) + base_transform = jnp.vstack( + [ + jnp.block([base_orientation, base_position.reshape(3, -1)]), + jnp.array([0, 0, 0, 1]), + ] + ) + + i_X_λ, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + joint_positions=joint_positions, base_transform=base_transform + ) + M = jaxsim.rbda.crba( + model=model, + joint_positions=joint_positions, + joint_transforms=i_X_λ, + motion_subspaces=S, + ) + J, W_H_L = jaxsim.rbda.jacobian_full_doubly_left( + model=model, + joint_positions=joint_positions, + joint_transforms=i_X_λ, + motion_subspaces=S, + ) + J̇, _ = jaxsim.rbda.jacobian_derivative_full_doubly_left( + model=model, + joint_positions=joint_positions, + joint_velocities=joint_velocities, + ) + + kyn_dyn = KynDynComputation( + velocity_representation=velocity_representation, + jacobian_full_doubly_left=J, + jacobian_derivative_full_doubly_left=J̇, + motion_subspaces=S, + joint_transforms=i_X_λ, + mass_matrix=M, + forward_kinematics=W_H_L, + ) + return JaxSimModelData( state=ode_state, gravity=gravity, contacts_params=contacts_params, velocity_representation=velocity_representation, + _kyn_dyn=kyn_dyn, ) # ================== @@ -287,12 +444,6 @@ def joint_positions( return self.state.physics_model.joint_positions - if not_tracing(self.state.physics_model.joint_positions) and not self.valid( - model=model - ): - msg = "The data object is not compatible with the provided model" - raise ValueError(msg) - joint_idxs = ( js.joint.names_to_idxs(joint_names=joint_names, model=model) if joint_names is not None @@ -335,12 +486,6 @@ def joint_velocities( return self.state.physics_model.joint_velocities - if not_tracing(self.state.physics_model.joint_velocities) and not self.valid( - model=model - ): - msg = "The data object is not compatible with the provided model" - raise ValueError(msg) - joint_idxs = ( js.joint.names_to_idxs(joint_names=joint_names, model=model) if joint_names is not None @@ -349,8 +494,7 @@ def joint_velocities( return self.state.physics_model.joint_velocities[joint_idxs] - @js.common.named_scope - @jax.jit + @property def base_position(self) -> jtp.Vector: """ Get the base position. @@ -359,7 +503,7 @@ def base_position(self) -> jtp.Vector: The base position. """ - return self.state.physics_model.base_position.squeeze() + return self.state.physics_model.base_position @js.common.named_scope @functools.partial(jax.jit, static_argnames=["dcm"]) @@ -401,7 +545,7 @@ def base_transform(self) -> jtp.Matrix: """ W_R_B = self.base_orientation(dcm=True) - W_p_B = jnp.vstack(self.base_position()) + W_p_B = jnp.vstack(self.base_position) return jnp.vstack( [ @@ -429,16 +573,12 @@ def base_velocity(self) -> jtp.Vector: W_H_B = self.base_transform() - return ( - JaxSimModelData.inertial_to_other_representation( - array=W_v_WB, - other_representation=self.velocity_representation, - transform=W_H_B, - is_force=False, - ) - .squeeze() - .astype(float) - ) + return JaxSimModelData.inertial_to_other_representation( + array=W_v_WB, + other_representation=self.velocity_representation, + transform=W_H_B, + is_force=False, + ).astype(float) @js.common.named_scope @jax.jit @@ -834,74 +974,60 @@ def random_model_data( ω_max = jnp.array(base_vel_ang_bounds[1], dtype=float) ṡ_min, ṡ_max = joint_vel_bounds - random_data = JaxSimModelData.zero( - model=model, - **( - dict(velocity_representation=velocity_representation) - if velocity_representation is not None - else {} - ), - ) - - with random_data.mutable_context( - mutability=Mutability.MUTABLE, restore_after_exception=False - ): + base_position = jax.random.uniform(key=k1, shape=(3,), minval=p_min, maxval=p_max) - physics_model_state = random_data.state.physics_model - - physics_model_state.base_position = jax.random.uniform( - key=k1, shape=(3,), minval=p_min, maxval=p_max - ) + base_quaternion = jaxsim.math.Quaternion.to_wxyz( + xyzw=jax.scipy.spatial.transform.Rotation.from_euler( + seq=base_rpy_seq, + angles=jax.random.uniform( + key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max + ), + ).as_quat() + ) - physics_model_state.base_quaternion = jaxsim.math.Quaternion.to_wxyz( - xyzw=jax.scipy.spatial.transform.Rotation.from_euler( - seq=base_rpy_seq, - angles=jax.random.uniform( - key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max - ), - ).as_quat() + ( + joint_positions, + joint_velocities, + base_linear_velocity, + base_angular_velocity, + standard_gravity, + contacts_params, + ) = (None,) * 6 + + if model.number_of_joints() > 0: + + s_min, s_max = ( + jnp.array(joint_pos_bounds, dtype=float) + if joint_pos_bounds is not None + else (None, None) ) - if model.number_of_joints() > 0: - - s_min, s_max = ( - jnp.array(joint_pos_bounds, dtype=float) - if joint_pos_bounds is not None - else (None, None) - ) - - physics_model_state.joint_positions = ( - js.joint.random_joint_positions(model=model, key=k3) - if (s_min is None or s_max is None) - else jax.random.uniform( - key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max - ) + joint_positions = ( + js.joint.random_joint_positions(model=model, key=k3) + if (s_min is None or s_max is None) + else jax.random.uniform( + key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max ) + ) - physics_model_state.joint_velocities = jax.random.uniform( - key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max - ) + joint_velocities = jax.random.uniform( + key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max + ) - if model.floating_base(): - physics_model_state.base_linear_velocity = jax.random.uniform( - key=k5, shape=(3,), minval=v_min, maxval=v_max - ) + if model.floating_base(): + base_linear_velocity = jax.random.uniform( + key=k5, shape=(3,), minval=v_min, maxval=v_max + ) - physics_model_state.base_angular_velocity = jax.random.uniform( - key=k6, shape=(3,), minval=ω_min, maxval=ω_max - ) + base_angular_velocity = jax.random.uniform( + key=k6, shape=(3,), minval=ω_min, maxval=ω_max + ) - random_data.gravity = ( - jnp.zeros(3, dtype=random_data.gravity.dtype) - .at[2] - .set( - -jax.random.uniform( - key=k7, - shape=(), - minval=standard_gravity_bounds[0], - maxval=standard_gravity_bounds[1], - ) - ) + standard_gravity = jax.random.uniform( + key=k7, + shape=(), + minval=standard_gravity_bounds[0], + maxval=standard_gravity_bounds[1], ) if contacts_params is None: @@ -912,17 +1038,21 @@ def random_model_data( | jaxsim.rbda.contacts.ViscoElasticContacts, ): - random_data = random_data.replace( - contacts_params=js.contact.estimate_good_contact_parameters( - model=model, standard_gravity=random_data.gravity - ), - validate=False, + contacts_params = js.contact.estimate_good_contact_parameters( + model=model, standard_gravity=standard_gravity ) - else: - random_data = random_data.replace( - contacts_params=model.contact_model._parameters_class(), - validate=False, - ) + contacts_params = (model.contact_model._parameters_class(),) - return random_data + return JaxSimModelData.build( + model=model, + velocity_representation=velocity_representation, + base_position=base_position, + base_quaternion=base_quaternion, + joint_positions=joint_positions, + joint_velocities=joint_velocities, + base_linear_velocity=base_linear_velocity, + base_angular_velocity=base_angular_velocity, + standard_gravity=standard_gravity, + contacts_params=contacts_params, + ) diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 23b4d3732..50aafab24 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -187,7 +187,7 @@ def transform( idx=link_index, ) - return js.model.forward_kinematics(model=model, data=data)[link_index] + return data.kyn_dyn.forward_kinematics[link_index] @jax.jit @@ -276,6 +276,8 @@ def jacobian( B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions(), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) # Compute the actual doubly-left free-floating jacobian of the link. @@ -422,9 +424,7 @@ def jacobian_derivative( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) - O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative( - model=model, data=data, output_vel_repr=output_vel_repr - )[link_index] + O_J̇_WL_I = data.kyn_dyn.jacobian_derivative[link_index] return O_J̇_WL_I diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 35c327791..82dc9dc87 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -548,9 +548,10 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp W_H_LL = jaxsim.rbda.forward_kinematics_model( model=model, - base_position=data.base_position(), + base_position=data.base_position, base_quaternion=data.base_orientation(dcm=False), joint_positions=data.joint_positions(model=model), + joint_transforms=data.kyn_dyn.joint_transforms, ) return jnp.atleast_3d(W_H_LL).astype(float) @@ -590,6 +591,8 @@ def generalized_free_floating_jacobian( B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions(), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) # ====================================================================== @@ -713,18 +716,14 @@ def generalized_free_floating_jacobian_derivative( ) # Compute the derivative of the doubly-left free-floating full jacobian. - B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left( - model=model, - joint_positions=data.joint_positions(), - joint_velocities=data.joint_velocities(), + B_J̇_full_WX_B, B_H_L = ( + data.kyn_dyn.jacobian_derivative_full_doubly_left, + data.kyn_dyn.forward_kinematics, ) # The derivative of the equation to change the input and output representations # of the Jacobian derivative needs the computation of the plain link Jacobian. - B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left( - model=model, - joint_positions=data.joint_positions(), - ) + B_J_full_WL_B = data.kyn_dyn.jacobian_full_doubly_left # Compute the actual doubly-left free-floating jacobian derivative of the link # by zeroing the columns not in the path π_B(L) using the boolean κ(i). @@ -979,7 +978,7 @@ def forward_dynamics_aba( # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): - W_p_B = data.base_position() + W_p_B = data.base_position W_v_WB = data.base_velocity() W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model, joint_names=joint_names) @@ -1005,6 +1004,8 @@ def forward_dynamics_aba( joint_forces=τ, link_forces=W_f_L, standard_gravity=data.standard_gravity(), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) # ============= @@ -1172,10 +1173,7 @@ def free_floating_mass_matrix( The free-floating mass matrix of the model. """ - M_body = jaxsim.rbda.crba( - model=model, - joint_positions=data.state.physics_model.joint_positions, - ) + M_body = data.kyn_dyn.mass_matrix match data.velocity_representation: case VelRepr.Body: @@ -1431,7 +1429,7 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): - W_p_B = data.base_position() + W_p_B = data.base_position W_v_WB = data.base_velocity() W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model, joint_names=joint_names) @@ -1458,6 +1456,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): joint_accelerations=s̈, link_forces=W_f_L, standard_gravity=data.standard_gravity(), + joint_transforms=data.kyn_dyn.joint_transforms, + motion_subspaces=data.kyn_dyn.motion_subspaces, ) # ============= @@ -1766,7 +1766,7 @@ def average_velocity_jacobian( case VelRepr.Body: GB_J = G_J - W_p_B = data.base_position() + W_p_B = data.base_position W_p_CoM = js.com.com_position(model=model, data=data) B_R_W = data.base_orientation(dcm=True).transpose() @@ -1778,7 +1778,7 @@ def average_velocity_jacobian( case VelRepr.Mixed: GW_J = G_J - W_p_B = data.base_position() + W_p_B = data.base_position W_p_CoM = js.com.com_position(model=model, data=data) BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B) @@ -1980,11 +1980,11 @@ def body_to_other_representation( ) case VelRepr.Inertial: - C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data) + C_H_L = W_H_L = data.kyn_dyn.forward_kinematics L_v_CL = L_v_WL case VelRepr.Mixed: - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) C_H_L = LW_H_L L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841 @@ -2147,6 +2147,47 @@ def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.F # ========== +@jax.jit +@js.common.named_scope +def forward( + model: JaxSimModel, + data: js.data.JaxSimModelData, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, +) -> js.data.JaxSimModelData: + + # Kinematics computation. + i_X_λ, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + joint_positions=data.joint_positions(), base_transform=data.base_transform() + ) + M = jaxsim.rbda.crba( + model=model, + joint_positions=data.state.physics_model.joint_positions, + joint_transforms=i_X_λ, + motion_subspaces=S, + ) + J, W_H_L = jaxsim.rbda.jacobian_full_doubly_left( + model=model, + joint_positions=data.joint_positions(), + joint_transforms=i_X_λ, + motion_subspaces=S, + ) + J̇, _ = jaxsim.rbda.jacobian_derivative_full_doubly_left( + model=model, + joint_positions=data.joint_positions(), + joint_velocities=data.joint_velocities(), + ) + + data.kyn_dyn.jacobian_full_doubly_left = J + data.kyn_dyn.jacobian_derivative_full_doubly_left = J̇ + data.kyn_dyn.joint_transforms = i_X_λ + data.kyn_dyn.motion_subspaces = S + data.kyn_dyn.mass_matrix = M + data.kyn_dyn.forward_kinematics = W_H_L + + return data + + @jax.jit @js.common.named_scope def step( @@ -2254,7 +2295,7 @@ def step( f_L = references.link_forces(model=model, data=data) τ_references = references.joint_force_references(model=model) - # Step the dynamics forward. + # Integrate the system dynamics. state_tf, integrator_metadata_tf = integrator.step( x0=state_t0, t0=t0, @@ -2285,6 +2326,14 @@ def step( # Phase 3: post-step # ================== + # Step the dynamics forward. + data_tf = forward( + model=model, + data=data, + link_forces=f_L, + joint_force_references=τ_references, + ) + # Post process the simulation state, if needed. match model.contact_model: diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 1c0a11078..bdfa4214f 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -242,7 +242,7 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: )(W_f_L, W_H_L) # The f_L output is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) return f_L @@ -450,7 +450,7 @@ def convert_using_link_frame( )(f_L, W_H_L) # The f_L input is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = data.kyn_dyn.forward_kinematics W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L)) diff --git a/src/jaxsim/exceptions.py b/src/jaxsim/exceptions.py index 58e3fc1a0..2c78c2880 100644 --- a/src/jaxsim/exceptions.py +++ b/src/jaxsim/exceptions.py @@ -1,3 +1,5 @@ +import os + import jax @@ -18,7 +20,9 @@ def raise_if( """ # Disable host callback if running on TPU. - if jax.devices()[0].platform == "tpu": + if jax.devices()[0].platform == "tpu" or os.environ.get( + "JAXSIM_DISABLE_EXCEPTIONS", False + ): return # Check early that the format string is well-formed. diff --git a/src/jaxsim/mujoco/utils.py b/src/jaxsim/mujoco/utils.py index 2afff1732..c80bfa79e 100644 --- a/src/jaxsim/mujoco/utils.py +++ b/src/jaxsim/mujoco/utils.py @@ -59,7 +59,7 @@ def mujoco_data_from_jaxsim( if jaxsim_model.floating_base(): # Set the model position. - model_helper.set_base_position(position=np.array(jaxsim_data.base_position())) + model_helper.set_base_position(position=np.array(jaxsim_data.base_position)) # Set the model orientation. model_helper.set_base_orientation( diff --git a/src/jaxsim/rbda/aba.py b/src/jaxsim/rbda/aba.py index b01f46698..8a0aa1cbb 100644 --- a/src/jaxsim/rbda/aba.py +++ b/src/jaxsim/rbda/aba.py @@ -21,6 +21,8 @@ def aba( joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, standard_gravity: jtp.FloatLike = StandardGravity, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute forward dynamics using the Articulated Body Algorithm (ABA). @@ -85,12 +87,10 @@ def aba( W_X_B = W_H_B.adjoint() B_X_W = W_H_B.inverse().adjoint() - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffers. v = jnp.zeros(shape=(model.number_of_links(), 6, 1)) diff --git a/src/jaxsim/rbda/collidable_points.py b/src/jaxsim/rbda/collidable_points.py index 543be5328..dba3d98b9 100644 --- a/src/jaxsim/rbda/collidable_points.py +++ b/src/jaxsim/rbda/collidable_points.py @@ -1,6 +1,5 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js import jaxsim.typing as jtp @@ -18,6 +17,8 @@ def collidable_points_pos_vel( base_linear_velocity: jtp.Vector, base_angular_velocity: jtp.Vector, joint_velocities: jtp.Vector, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Matrix, jtp.Matrix]: """ @@ -54,7 +55,7 @@ def collidable_points_pos_vel( if len(indices_of_enabled_collidable_points) == 0: return jnp.array(0).astype(float), jnp.empty(0).astype(float) - W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs( + _, _, _, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs( model=model, base_position=base_position, base_quaternion=base_quaternion, @@ -68,18 +69,10 @@ def collidable_points_pos_vel( # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array - # Compute the base transform. - W_H_B = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3(wxyz=W_Q_B), - translation=W_p_B, - ) - - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffer of transforms world -> link and initialize the base pose. W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 56b403fa7..4641fbd46 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -251,7 +251,7 @@ def link_forces_from_contact_forces( # Compute the link transforms. W_H_L = ( - js.model.forward_kinematics(model=model, data=data) + data.kyn_dyn.forward_kinematics if data.velocity_representation is not jaxsim.VelRepr.Inertial else jnp.zeros(shape=(model.number_of_links(), 4, 4)) ) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index c433fe23d..11101cb0c 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -827,7 +827,7 @@ def integrate_data_with_average_contact_forces( """ s_t0 = data.joint_positions() - W_p_B_t0 = data.base_position() + W_p_B_t0 = data.base_position W_Q_B_t0 = data.base_orientation(dcm=False) ṡ_t0 = data.joint_velocities() @@ -1005,7 +1005,7 @@ def step( # Compute the link transforms. W_H_L = ( - js.model.forward_kinematics(model=model, data=data) + data.kyn_dyn.forward_kinematics if data.velocity_representation is not jaxsim.VelRepr.Inertial else jnp.zeros(shape=(model.number_of_links(), 4, 4)) ) diff --git a/src/jaxsim/rbda/crba.py b/src/jaxsim/rbda/crba.py index adb3506ae..6272be898 100644 --- a/src/jaxsim/rbda/crba.py +++ b/src/jaxsim/rbda/crba.py @@ -4,10 +4,14 @@ import jaxsim.api as js import jaxsim.typing as jtp -from . import utils - -def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Matrix: +def crba( + model: js.model.JaxSimModel, + *, + joint_positions: jtp.Vector, + joint_transforms, + motion_subspaces, +) -> jtp.Matrix: """ Compute the free-floating mass matrix using the Composite Rigid-Body Algorithm (CRBA). @@ -19,10 +23,6 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat The free-floating mass matrix of the model in body-fixed representation. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) - # Get the 6D spatial inertia matrices of all links. Mc = js.model.link_spatial_inertia_matrices(model=model) @@ -30,12 +30,10 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms link -> base. i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index d11e9b45d..371b25aaf 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -1,13 +1,10 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.math import Adjoint -from . import utils - def forward_kinematics_model( model: js.model.JaxSimModel, @@ -15,6 +12,7 @@ def forward_kinematics_model( base_position: jtp.VectorLike, base_quaternion: jtp.VectorLike, joint_positions: jtp.VectorLike, + joint_transforms, ) -> jtp.Array: """ Compute the forward kinematics. @@ -29,29 +27,14 @@ def forward_kinematics_model( A 3D array containing the SE(3) transforms of all links belonging to the model. """ - W_p_B, W_Q_B, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, - base_position=base_position, - base_quaternion=base_quaternion, - joint_positions=joint_positions, - ) - # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array - # Compute the base transform. - W_H_B = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3(wxyz=W_Q_B), - translation=W_p_B, - ) - - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, _ = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + i_X_λi = joint_transforms # Allocate the buffer of transforms world -> link and initialize the base pose. W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/jacobian.py b/src/jaxsim/rbda/jacobian.py index e8f44d088..3aedfa5bf 100644 --- a/src/jaxsim/rbda/jacobian.py +++ b/src/jaxsim/rbda/jacobian.py @@ -14,6 +14,8 @@ def jacobian( *, link_index: jtp.Int, joint_positions: jtp.VectorLike, + joint_transforms, + motion_subspaces, ) -> jtp.Matrix: """ Compute the free-floating Jacobian of a link. @@ -27,20 +29,14 @@ def jacobian( The free-floating left-trivialized Jacobian of the link :math:`{}^L J_{W,L/B}`. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) - # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms link -> base. i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) @@ -127,6 +123,8 @@ def jacobian_full_doubly_left( model: js.model.JaxSimModel, *, joint_positions: jtp.VectorLike, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Matrix, jtp.Array]: r""" Compute the doubly-left full free-floating Jacobian of a model. @@ -144,10 +142,6 @@ def jacobian_full_doubly_left( The doubly-left full free-floating Jacobian of a model. """ - _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( - model=model, joint_positions=joint_positions - ) - # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array @@ -155,9 +149,7 @@ def jacobian_full_doubly_left( # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=jnp.eye(4) - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate the buffer of transforms base -> link. B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) diff --git a/src/jaxsim/rbda/rnea.py b/src/jaxsim/rbda/rnea.py index 025d85a62..8f3d18036 100644 --- a/src/jaxsim/rbda/rnea.py +++ b/src/jaxsim/rbda/rnea.py @@ -23,6 +23,8 @@ def rnea( joint_accelerations: jtp.Vector | None = None, link_forces: jtp.Matrix | None = None, standard_gravity: jtp.FloatLike = StandardGravity, + joint_transforms, + motion_subspaces, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA). @@ -88,12 +90,10 @@ def rnea( W_X_B = W_H_B.adjoint() B_X_W = W_H_B.inverse().adjoint() - # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # Extract the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. - i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( - joint_positions=s, base_transform=W_H_B.as_matrix() - ) + i_X_λi, S = joint_transforms, motion_subspaces # Allocate buffers. v = jnp.zeros(shape=(model.number_of_links(), 6, 1)) diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 2c7d55aa8..2e4b4eb9b 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -20,9 +20,9 @@ dict[Hashable, TypeVar("PyTree")] | list[TypeVar("PyTree")] | tuple[TypeVar("PyTree")] - | None | jax.Array | Any + | None ) # ======================= diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index 4a0882737..63666431d 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -132,7 +132,7 @@ def test_contact_jacobian_derivative( # Rebuild the JaxSim data. data_with_frames = js.data.JaxSimModelData.build( model=model_with_frames, - base_position=data.base_position(), + base_position=data.base_position, base_quaternion=data.base_orientation(dcm=False), joint_positions=data.joint_positions(), base_linear_velocity=data.base_velocity()[0:3], diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index d20dbc5ff..3be031cbc 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -249,7 +249,7 @@ def J(q, frame_idxs) -> jax.Array: def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( [ - data.base_position(), + data.base_position, data.base_orientation(), data.joint_positions(), ] diff --git a/tests/test_api_link.py b/tests/test_api_link.py index 7f89e0cc5..150ef574f 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -341,7 +341,7 @@ def J(q) -> jax.Array: def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( - [data.base_position(), data.base_orientation(), data.joint_positions()] + [data.base_position, data.base_orientation(), data.joint_positions()] ) return q diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 6b51da58d..836a4c58b 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -447,7 +447,7 @@ def M(q) -> jax.Array: def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( - [data.base_position(), data.base_orientation(), data.joint_positions()] + [data.base_position, data.base_orientation(), data.joint_positions()] ) return q diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 477f6245d..8c8e9fc9b 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -75,7 +75,7 @@ def test_ad_aba( g = jaxsim.math.StandardGravity # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) W_v_WB = data.base_velocity() @@ -129,7 +129,7 @@ def test_ad_rnea( g = jaxsim.math.StandardGravity # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) W_v_WB = data.base_velocity() @@ -217,7 +217,7 @@ def test_ad_fk( ) # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) @@ -344,7 +344,7 @@ def test_ad_integration( ) # State in VelRepr.Inertial representation. - W_p_B = data.base_position() + W_p_B = data.base_position W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) W_v_WB = data.base_velocity() diff --git a/tests/test_pytree.py b/tests/test_pytree.py index c27179254..c3bf82576 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -51,7 +51,7 @@ def my_jit_function(model: js.model.JaxSimModel, data: js.data.JaxSimModelData): # Return random elements from model and data, just to have something returned. return ( jnp.sum(model.kin_dyn_parameters.link_parameters.mass), - data.base_position(), + data.base_position, ) data1 = js.data.JaxSimModelData.build(model=model1) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 99f90d899..230ff11e4 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -78,7 +78,7 @@ def test_box_with_external_forces( ) # Check that the box didn't move. - assert data.base_position() == pytest.approx(data0.base_position()) + assert data.base_position == pytest.approx(data0.base_position()) assert data.base_orientation() == pytest.approx(data0.base_orientation()) @@ -158,7 +158,7 @@ def test_box_with_zero_gravity( ) # Check that the box moved as expected. - assert data.base_position() == pytest.approx( + assert data.base_position == pytest.approx( data0.base_position() + 0.5 * LW_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2, abs=1e-3,