From a002331ae38f65bcdba4759442764cce77bc8ea5 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 9 Dec 2024 15:15:45 +0100 Subject: [PATCH 1/4] Parametrize simulation test on `integrators` module --- tests/test_simulations.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 99f90d899..2db8721c8 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -201,8 +201,19 @@ def run_simulation( return data +@pytest.mark.parametrize( + "integrator", + [ + jaxsim.integrators.fixed_step.ForwardEuler, + jaxsim.integrators.fixed_step.ForwardEulerSO3, + jaxsim.integrators.fixed_step.RungeKutta4, + jaxsim.integrators.fixed_step.RungeKutta4SO3, + jaxsim.integrators.variable_step.BogackiShampineSO3, + ], +) def test_simulation_with_soft_contacts( jaxsim_model_box: js.model.JaxSimModel, + integrator: jaxsim.integrators.Integrator, ): model = jaxsim_model_box @@ -218,6 +229,11 @@ def test_simulation_with_soft_contacts( model.kin_dyn_parameters.contact_parameters.enabled = tuple( enabled_collidable_points_mask.tolist() ) + model.integrator = integrator.build( + dynamics=js.ode.wrap_system_dynamics_for_integration( + system_dynamics=js.ode.system_dynamics + ) + ) assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 From 2dbcdcc9fcdcf50177a5293d5b2edd089feb7c29 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 9 Dec 2024 15:53:54 +0100 Subject: [PATCH 2/4] Fix boolean conversion in variable step integrators --- src/jaxsim/integrators/variable_step.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/integrators/variable_step.py b/src/jaxsim/integrators/variable_step.py index 5a754cc97..a1dca115f 100644 --- a/src/jaxsim/integrators/variable_step.py +++ b/src/jaxsim/integrators/variable_step.py @@ -481,10 +481,10 @@ def reject_step(): metadata_next, discarded_steps, ) = jax.lax.cond( - pred=discarded_steps - >= self.max_step_rejections | local_error - <= 1.0 | Δt_next - < self.dt_min | integrator_init, + pred=(discarded_steps >= self.max_step_rejections) + | (local_error <= 1.0) + | (Δt_next < self.dt_min) + | integrator_init, true_fun=accept_step, false_fun=reject_step, ) From 087f11298fa16adae75a88424b6e1b09ee25419e Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 9 Dec 2024 15:52:39 +0100 Subject: [PATCH 3/4] Remove `metadata` attribute in `ExplicitRungeKutta` --- src/jaxsim/integrators/common.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index e858bf890..caa7c4ac2 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -53,10 +53,6 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]): repr=False, hash=False, compare=False, kw_only=True ) - metadata: dict[str, Any] = dataclasses.field( - default_factory=dict, repr=False, hash=False, compare=False, kw_only=True - ) - @classmethod def build( cls: type[Self], @@ -102,10 +98,7 @@ def step( metadata = metadata if metadata is not None else {} - with self.editable(validate=False) as integrator: - integrator.metadata = metadata - - with integrator.mutable_context(mutability=Mutability.MUTABLE): + with self.mutable_context(mutability=Mutability.MUTABLE) as integrator: xf, metadata_step = integrator(x0, t0, dt, **kwargs) return ( @@ -315,6 +308,9 @@ def _compute_next_state( b = self.b A = self.A + # Extract metadata from the kwargs. + metadata = kwargs.pop("metadata", {}) + # Close f over optional kwargs. f = lambda x, t: self.dynamics(x=x, t=t, **kwargs) @@ -327,7 +323,7 @@ def _compute_next_state( # or to use the previous state derivative (only integrators supporting FSAL). def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]: ẋ0, aux_dict = f(x0, t0) - return self.metadata.get("dxdt0", ẋ0), aux_dict + return metadata.get("dxdt0", ẋ0), aux_dict # We use a `jax.lax.scan` to compile the `f` function only once. # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code @@ -381,7 +377,8 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]: # Update the FSAL property for the next iteration. if self.has_fsal: - self.metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K) + # Store the first derivative of the next step in the metadata. + metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K) # Compute the output state. # Note that z contains as many new states as the rows of `b.T`. @@ -394,7 +391,7 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]: lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt) )(z) - return z_transformed, aux_dict + return z_transformed, aux_dict | {"metadata": metadata} @staticmethod def butcher_tableau_is_valid( From 7cdf0f40a2f76fbf18f10634fd8cbeec4aabdd6a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 9 Dec 2024 15:54:15 +0100 Subject: [PATCH 4/4] Refactor `metadata` handling in variable step integrators --- src/jaxsim/integrators/variable_step.py | 57 +++++++++++-------------- 1 file changed, 25 insertions(+), 32 deletions(-) diff --git a/src/jaxsim/integrators/variable_step.py b/src/jaxsim/integrators/variable_step.py index a1dca115f..5e063b8d0 100644 --- a/src/jaxsim/integrators/variable_step.py +++ b/src/jaxsim/integrators/variable_step.py @@ -14,7 +14,6 @@ import jaxsim.utils.tracing from jaxsim import typing as jtp -from jaxsim.utils import Mutability from .common import ( ExplicitRungeKutta, @@ -271,30 +270,27 @@ def init( # Inject this key to signal that the integrator is initializing. # This is used to allocate the arrays of the metadata dictionary, # that are then filled with NaNs. - integrator.metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)} + metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)} # Run a dummy call of the integrator. # It is used only to get the metadata so that we know the structure # of the corresponding pytree. _ = integrator( - x0, jnp.array(t0, dtype=float), jnp.array(dt, dtype=float), **kwargs + x0, + jnp.array(t0, dtype=float), + jnp.array(dt, dtype=float), + **(kwargs | {"metadata": metadata}), ) # Remove the injected key. - _ = integrator.metadata.pop(EmbeddedRungeKutta.InitializingKey) + _ = metadata.pop(EmbeddedRungeKutta.InitializingKey) # Make sure that all leafs of the dictionary are JAX arrays. # Also, since these are dummy parameters, set them all to NaN. metadata_after_init = jax.tree.map( - lambda l: jnp.nan * jnp.zeros_like(l), integrator.metadata + lambda l: jnp.nan * jnp.zeros_like(l), metadata ) - # Store the zero parameters in the integrator. - # When the integrator is stepped, this is used to check if the passed - # parameters are valid. - with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): - self.metadata = metadata_after_init - return metadata_after_init def __call__( @@ -307,7 +303,7 @@ def __call__( # The metadata is a dictionary of float JAX arrays, that are initialized # with the right shape and filled with NaNs. # 2. During the first step, this method operates on the Nan-filled - # `self.metadata` attribute, and it populates with the actual metadata. + # `metadata` argument, and it populates with the actual metadata. # 3. After the first step, this method operates on the actual metadata. # # In particular, we store the following information in the metadata: @@ -318,8 +314,10 @@ def __call__( # evaluate the dynamics at the final state of the previous step, that matches # the initial state of the current step. # + metadata = kwargs.pop("metadata", {}) + integrator_init = jnp.array( - self.metadata.get(self.InitializingKey, False), dtype=bool + metadata.get(self.InitializingKey, False), dtype=bool ) # Close f over optional kwargs. @@ -335,24 +333,23 @@ def __call__( # The value of dt0 is NaN (or, at least, it should be) only after initialization # and before the first step. - self.metadata["dt0"], self.metadata["dxdt0"] = jax.lax.cond( - pred=("dt0" in self.metadata) - & ~jnp.isnan(self.metadata.get("dt0", 0.0)).any(), + metadata["dt0"], metadata["dxdt0"] = jax.lax.cond( + pred=("dt0" in metadata) & ~jnp.isnan(metadata.get("dt0", 0.0)).any(), true_fun=lambda metadata: ( metadata.get("dt0", jnp.array(0.0, dtype=float)), - self.metadata.get("dxdt0", f(x0, t0)[0]), + metadata.get("dxdt0", f(x0, t0)[0]), ), false_fun=lambda aux: estimate_step_size( x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol ), - operand=self.metadata, + operand=metadata, ) # Clip the estimated initial step size to the given bounds, if necessary. - self.metadata["dt0"] = jnp.clip( - self.metadata["dt0"], - jnp.minimum(self.dt_min, self.metadata["dt0"]), - jnp.minimum(self.dt_max, self.metadata["dt0"]), + metadata["dt0"] = jnp.clip( + metadata["dt0"], + jnp.minimum(self.dt_min, metadata["dt0"]), + jnp.minimum(self.dt_max, metadata["dt0"]), ) # ========================================================= @@ -364,7 +361,7 @@ def __call__( carry0: Carry = ( x0, jnp.array(t0).astype(float), - self.metadata, + metadata, jnp.array(0, dtype=int), jnp.array(False).astype(bool), ) @@ -392,9 +389,10 @@ def while_loop_body(carry: Carry) -> Carry: # Run the underlying explicit RK integrator. # The output z contains multiple solutions (depending on the rows of b.T). with self.editable(validate=True) as integrator: - integrator.metadata = metadata - z, _ = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs) - metadata_next = integrator.metadata + z, aux_dict = integrator._compute_next_state( + x0=x0, t0=t0, dt=Δt0, **kwargs + ) + metadata_next = aux_dict["metadata"] # Extract the high-order solution xf and the low-order estimate x̂f. xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z) @@ -510,12 +508,7 @@ def reject_step(): init_val=carry0, ) - # Store the parameters. - # They will be returned to the caller in a functional way in the step method. - with self.mutable_context(mutability=Mutability.MUTABLE): - self.metadata = metadata_tf - - return xf, {} + return xf, {"metadata": metadata_tf} @property def order_of_solution(self) -> int: