From 03d926dd8870338cdb419c6bb6dd9b6dbe45c2c6 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 29 Nov 2024 16:26:18 +0100 Subject: [PATCH 1/7] Fix deprecation warning for Python 3.16 --- src/jaxsim/api/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 8ea60eaff..638b33baf 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2198,7 +2198,9 @@ def step( isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts) & ( ~jnp.allclose(dt, model.time_step) - | ~isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler) + | ~int( + isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler) + ) ) ), msg=msg.format(module, name), From 9d353543e8a2305d7691264697cf71d00c0fc564 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 5 Dec 2024 11:05:32 +0100 Subject: [PATCH 2/7] Remove `ODEInput` class --- src/jaxsim/api/ode_data.py | 100 ++----------------------------------- 1 file changed, 3 insertions(+), 97 deletions(-) diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index d9e6ae3f1..e8667f29e 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -10,108 +10,14 @@ import jaxsim.typing as jtp from jaxsim.utils import JaxsimDataclass -# ============================================================================= -# Define the input and state of the ODE system defining the integrated dynamics -# ============================================================================= +# =================================================================== +# Define the state of the ODE system defining the integrated dynamics +# =================================================================== # Note: the ODE system is the combination of the floating-base dynamics and the # soft-contacts dynamics. -@jax_dataclasses.pytree_dataclass -class ODEInput(JaxsimDataclass): - """ - The input to the ODE system. - - Attributes: - physics_model: The input to the physics model. - """ - - physics_model: PhysicsModelInput - - @staticmethod - def build_from_jaxsim_model( - model: js.model.JaxSimModel | None = None, - link_forces: jtp.MatrixLike | None = None, - joint_force_references: jtp.VectorLike | None = None, - ) -> ODEInput: - """ - Build an `ODEInput` from a `JaxSimModel`. - - Args: - model: The `JaxSimModel` associated with the ODE input. - link_forces: The matrix of external forces applied to the links. - joint_force_references: The vector of joint force references. - - Returns: - The `ODEInput` built from the `JaxSimModel`. - - Note: - If any of the input components are not provided, they are built from the - `JaxSimModel` and initialized to zero. - """ - - return ODEInput.build( - physics_model_input=PhysicsModelInput.build_from_jaxsim_model( - model=model, - link_forces=link_forces, - joint_force_references=joint_force_references, - ), - model=model, - ) - - @staticmethod - def build( - physics_model_input: PhysicsModelInput | None = None, - model: js.model.JaxSimModel | None = None, - ) -> ODEInput: - """ - Build an `ODEInput` from a `PhysicsModelInput`. - - Args: - physics_model_input: The `PhysicsModelInput` associated with the ODE input. - model: The `JaxSimModel` associated with the ODE input. - - Returns: - A `ODEInput` instance. - """ - - physics_model_input = ( - physics_model_input - if physics_model_input is not None - else PhysicsModelInput.zero(model=model) - ) - - return ODEInput(physics_model=physics_model_input) - - @staticmethod - def zero(model: js.model.JaxSimModel) -> ODEInput: - """ - Build a zero `ODEInput` from a `JaxSimModel`. - - Args: - model: The `JaxSimModel` associated with the ODE input. - - Returns: - A zero `ODEInput` instance. - """ - - return ODEInput.build(model=model) - - def valid(self, model: js.model.JaxSimModel) -> bool: - """ - Check if the `ODEInput` is valid for a given `JaxSimModel`. - - Args: - model: The `JaxSimModel` to validate the `ODEInput` against. - - Returns: - `True` if the ODE input is valid for the given model, `False` otherwise. - """ - - return self.physics_model.valid(model=model) - - @jax_dataclasses.pytree_dataclass class ODEState(JaxsimDataclass): """ From 13a884ff823e293f30f330a130d18c6343d7cdb6 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 5 Dec 2024 11:06:02 +0100 Subject: [PATCH 3/7] Remove `PhysicsModelInput` class --- src/jaxsim/api/ode_data.py | 120 ------------------------------------- 1 file changed, 120 deletions(-) diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index e8667f29e..9d5db94f9 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -399,123 +399,3 @@ def valid(self, model: js.model.JaxSimModel) -> bool: return False return True - - -@jax_dataclasses.pytree_dataclass -class PhysicsModelInput(JaxsimDataclass): - """ - Class storing the inputs of the physics model dynamics. - - Attributes: - tau: The vector of joint forces. - f_ext: The matrix of external forces applied to the links. - """ - - tau: jtp.Vector - f_ext: jtp.Matrix - - @staticmethod - def build_from_jaxsim_model( - model: js.model.JaxSimModel | None = None, - link_forces: jtp.MatrixLike | None = None, - joint_force_references: jtp.VectorLike | None = None, - ) -> PhysicsModelInput: - """ - Build a `PhysicsModelInput` from a `JaxSimModel`. - - Args: - model: The `JaxSimModel` associated with the input. - link_forces: The matrix of external forces applied to the links. - joint_force_references: The vector of joint force references. - - Returns: - A `PhysicsModelInput` instance. - - Note: - If any of the input components are not provided, they are built from the - `JaxSimModel` and initialized to zero. - """ - - return PhysicsModelInput.build( - joint_force_references=joint_force_references, - link_forces=link_forces, - number_of_dofs=model.dofs(), - number_of_links=model.number_of_links(), - ) - - @staticmethod - def build( - link_forces: jtp.MatrixLike | None = None, - joint_force_references: jtp.VectorLike | None = None, - number_of_dofs: jtp.Int | None = None, - number_of_links: jtp.Int | None = None, - ) -> PhysicsModelInput: - """ - Build a `PhysicsModelInput`. - - Args: - link_forces: The matrix of external forces applied to the links. - joint_force_references: The vector of joint force references. - number_of_dofs: The number of degrees of freedom of the model. - number_of_links: The number of links of the model. - - Returns: - A `PhysicsModelInput` instance. - """ - - joint_force_references = jnp.atleast_1d( - jnp.array(joint_force_references, dtype=float).squeeze() - if joint_force_references is not None - else jnp.zeros(number_of_dofs) - ).astype(float) - - link_forces = jnp.atleast_2d( - jnp.array(link_forces, dtype=float).squeeze() - if link_forces is not None - else jnp.zeros(shape=(number_of_links, 6)) - ).astype(float) - - return PhysicsModelInput( - tau=joint_force_references, - f_ext=link_forces, - ) - - @staticmethod - def zero(model: js.model.JaxSimModel) -> PhysicsModelInput: - """ - Build a `PhysicsModelInput` with all components initialized to zero. - - Args: - model: The `JaxSimModel` associated with the input. - - Returns: - A `PhysicsModelInput` instance. - """ - - return PhysicsModelInput.build_from_jaxsim_model(model=model) - - def valid(self, model: js.model.JaxSimModel) -> bool: - """ - Check if the `PhysicsModelInput` is valid for a given `JaxSimModel`. - - Args: - model: The `JaxSimModel` to validate the `PhysicsModelInput` against. - - Returns: - `True` if the `PhysicsModelInput` is valid for the given model, - `False` otherwise. - """ - - shape = self.tau.shape - expected_shape = (model.dofs(),) - - if shape != expected_shape: - return False - - shape = self.f_ext.shape - expected_shape = (model.number_of_links(), 6) - - if shape != expected_shape: - return False - - return True From 4b98537c3ba751962bb26a7a45adf321f5cac7bd Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 5 Dec 2024 11:07:16 +0100 Subject: [PATCH 4/7] Refactor `JaxSimModelReferences` to remove `ODEInput` and replace with link and joint force attributes --- src/jaxsim/api/references.py | 78 +++++++++++++++++------------------- 1 file changed, 37 insertions(+), 41 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 3c4c94c23..eca7a858a 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -12,7 +12,6 @@ from jaxsim.utils.tracing import not_tracing from .common import VelRepr -from .ode_data import ODEInput try: from typing import Self @@ -26,7 +25,8 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation): Class containing the references for a `JaxSimModel` object. """ - input: ODEInput + _link_forces: jtp.Matrix + _joint_force_references: jtp.Vector @staticmethod def zero( @@ -94,17 +94,21 @@ def build( velocity_representation = ( velocity_representation if velocity_representation is not None - else ( - data.velocity_representation if data is not None else VelRepr.Inertial - ) + else getattr(data, "velocity_representation", VelRepr.Inertial) ) # Create a zero references object. references = JaxSimModelReferences( - input=ODEInput.zero(model=model), + _link_forces=f_L, + _joint_force_references=joint_force_references, velocity_representation=velocity_representation, ) + # If the velocity representation is inertial-fixed, we can return + # the references directly, as we store the link forces in this frame. + if velocity_representation is VelRepr.Inertial: + return references + # Store the joint force references. references = references.set_joint_force_references( forces=joint_force_references, @@ -135,12 +139,22 @@ def valid(self, model: js.model.JaxSimModel | None = None) -> bool: `False` otherwise. """ - valid = True + if model is None: + return True + + shape = self._joint_force_references.shape + expected_shape = (model.dofs(),) + + if shape != expected_shape: + return False - if model is not None: - valid = valid and self.input.valid(model=model) + shape = self._link_forces.shape + expected_shape = (model.number_of_links(), 6) - return valid + if shape != expected_shape: + return False + + return True # ================== # Extract quantities @@ -178,7 +192,7 @@ def link_forces( e.g. to the contact model and other kinematic constraints. """ - W_f_L = self.input.physics_model.f_ext + W_f_L = self._link_forces # Return all link forces in inertial-fixed representation using the implicit # serialization. @@ -190,7 +204,7 @@ def link_forces( if link_names is not None: raise ValueError("Link names cannot be provided without a model") - return self.input.physics_model.f_ext + return W_f_L # If we have the model, we can extract the link names, if not provided. link_idxs = ( @@ -207,7 +221,7 @@ def link_forces( msg = "Missing model data to use a representation different from {}" raise ValueError(msg.format(VelRepr.Inertial.name)) - if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model): + if not_tracing(self._link_forces) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") # Helper function to convert a single 6D force to the active representation @@ -264,9 +278,9 @@ def joint_force_references( if joint_names is not None: raise ValueError("Joint names cannot be provided without a model") - return self.input.physics_model.tau + return self._joint_force_references - if not_tracing(self.input.physics_model.tau) and not self.valid(model=model): + if not_tracing(self._joint_force_references) and not self.valid(model=model): msg = "The actuation object is not compatible with the provided model" raise ValueError(msg) @@ -277,7 +291,7 @@ def joint_force_references( ) return jnp.atleast_1d( - self.input.physics_model.tau[joint_idxs].squeeze() + self._joint_force_references[joint_idxs].squeeze() ).astype(float) # ================ @@ -310,11 +324,7 @@ def set_joint_force_references( def replace(forces: jtp.Vector) -> JaxSimModelReferences: return self.replace( validate=True, - input=self.input.replace( - physics_model=self.input.physics_model.replace( - tau=jnp.atleast_1d(forces.squeeze()).astype(float) - ) - ), + _joint_force_references=jnp.atleast_1d(forces.squeeze()).astype(float), ) if model is None: @@ -330,7 +340,7 @@ def replace(forces: jtp.Vector) -> JaxSimModelReferences: else jnp.arange(model.number_of_joints()) ) - return replace(forces=self.input.physics_model.tau.at[joint_idxs].set(forces)) + return replace(forces=self._joint_force_references.at[joint_idxs].set(forces)) @functools.partial(jax.jit, static_argnames=["link_names", "additive"]) def apply_link_forces( @@ -370,11 +380,7 @@ def apply_link_forces( def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: return self.replace( validate=True, - input=self.input.replace( - physics_model=self.input.physics_model.replace( - f_ext=jnp.atleast_2d(forces.squeeze()).astype(float) - ) - ), + _link_forces=jnp.atleast_2d(forces.squeeze()).astype(float), ) # In this case, we allow only to set the inertial 6D forces to all links @@ -389,11 +395,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: W_f_L = f_L - W_f0_L = ( - jnp.zeros_like(W_f_L) - if not additive - else self.input.physics_model.f_ext - ) + W_f0_L = jnp.zeros_like(W_f_L) if not additive else self._link_forces return replace(forces=W_f0_L + W_f_L) @@ -410,18 +412,14 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # Compute the bias depending on whether we either set or add the link forces. W_f0_L = ( - jnp.zeros_like(f_L) - if not additive - else self.input.physics_model.f_ext[link_idxs, :] + jnp.zeros_like(f_L) if not additive else self._link_forces[link_idxs, :] ) # If inertial-fixed representation, we can directly store the link forces. if self.velocity_representation is VelRepr.Inertial: W_f_L = f_L return replace( - forces=self.input.physics_model.f_ext.at[link_idxs, :].set( - W_f0_L + W_f_L - ) + forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L) ) if data is None: @@ -450,9 +448,7 @@ def convert_using_link_frame( W_H_L = js.model.forward_kinematics(model=model, data=data) W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) - return replace( - forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L) - ) + return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L)) def apply_frame_forces( self, From 3c9ce4a8264d6ba93fa79135859c901866bc29bf Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 5 Dec 2024 11:16:24 +0100 Subject: [PATCH 5/7] Update tests for link forces to use internal reference instead of physics model input --- tests/test_api_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_api_model.py b/tests/test_api_model.py index e0e428968..41f19f67f 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -364,7 +364,7 @@ def test_model_jacobian( ): f = references.link_forces(model=model, data=data) - assert f == pytest.approx(references.input.physics_model.f_ext) + assert f == pytest.approx(references._link_forces) J = js.model.generalized_free_floating_jacobian(model=model, data=data) JTf_inertial = jnp.einsum("l6g,l6->g", J, f) From 80735e19bd3ee5b00179840f5ca8afd7a08b9a2d Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 5 Dec 2024 11:16:40 +0100 Subject: [PATCH 6/7] Update viscoelastic model to use link forces instead of physics model input for zero initialization --- src/jaxsim/rbda/contacts/visco_elastic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index d674c4d90..c433fe23d 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -851,7 +851,7 @@ def integrate_data_with_average_contact_forces( W_f̅_L = ( jnp.array(average_link_contact_forces_inertial) if average_link_contact_forces_inertial is not None - else jnp.zeros_like(references.input.physics_model.f_ext) + else jnp.zeros_like(references._link_forces) ).astype(float) LW_f̿_L = ( From 22b3eb778acc1fdc219b9b6ba3748c02b06491c0 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 6 Dec 2024 15:53:47 +0100 Subject: [PATCH 7/7] Add attribute documentation for `JaxSimModelReferences` class Co-authored-by: Alessandro Croci --- src/jaxsim/api/references.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index eca7a858a..2a5ea010d 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -23,6 +23,10 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation): """ Class containing the references for a `JaxSimModel` object. + + Attributes: + _link_forces: The link 6D forces in inertial-fixed representation. + _joint_force_references: The joint force references. """ _link_forces: jtp.Matrix