diff --git a/src/jaxsim/integrators/variable_step.py b/src/jaxsim/integrators/variable_step.py index a1dca115f..5e063b8d0 100644 --- a/src/jaxsim/integrators/variable_step.py +++ b/src/jaxsim/integrators/variable_step.py @@ -14,7 +14,6 @@ import jaxsim.utils.tracing from jaxsim import typing as jtp -from jaxsim.utils import Mutability from .common import ( ExplicitRungeKutta, @@ -271,30 +270,27 @@ def init( # Inject this key to signal that the integrator is initializing. # This is used to allocate the arrays of the metadata dictionary, # that are then filled with NaNs. - integrator.metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)} + metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)} # Run a dummy call of the integrator. # It is used only to get the metadata so that we know the structure # of the corresponding pytree. _ = integrator( - x0, jnp.array(t0, dtype=float), jnp.array(dt, dtype=float), **kwargs + x0, + jnp.array(t0, dtype=float), + jnp.array(dt, dtype=float), + **(kwargs | {"metadata": metadata}), ) # Remove the injected key. - _ = integrator.metadata.pop(EmbeddedRungeKutta.InitializingKey) + _ = metadata.pop(EmbeddedRungeKutta.InitializingKey) # Make sure that all leafs of the dictionary are JAX arrays. # Also, since these are dummy parameters, set them all to NaN. metadata_after_init = jax.tree.map( - lambda l: jnp.nan * jnp.zeros_like(l), integrator.metadata + lambda l: jnp.nan * jnp.zeros_like(l), metadata ) - # Store the zero parameters in the integrator. - # When the integrator is stepped, this is used to check if the passed - # parameters are valid. - with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): - self.metadata = metadata_after_init - return metadata_after_init def __call__( @@ -307,7 +303,7 @@ def __call__( # The metadata is a dictionary of float JAX arrays, that are initialized # with the right shape and filled with NaNs. # 2. During the first step, this method operates on the Nan-filled - # `self.metadata` attribute, and it populates with the actual metadata. + # `metadata` argument, and it populates with the actual metadata. # 3. After the first step, this method operates on the actual metadata. # # In particular, we store the following information in the metadata: @@ -318,8 +314,10 @@ def __call__( # evaluate the dynamics at the final state of the previous step, that matches # the initial state of the current step. # + metadata = kwargs.pop("metadata", {}) + integrator_init = jnp.array( - self.metadata.get(self.InitializingKey, False), dtype=bool + metadata.get(self.InitializingKey, False), dtype=bool ) # Close f over optional kwargs. @@ -335,24 +333,23 @@ def __call__( # The value of dt0 is NaN (or, at least, it should be) only after initialization # and before the first step. - self.metadata["dt0"], self.metadata["dxdt0"] = jax.lax.cond( - pred=("dt0" in self.metadata) - & ~jnp.isnan(self.metadata.get("dt0", 0.0)).any(), + metadata["dt0"], metadata["dxdt0"] = jax.lax.cond( + pred=("dt0" in metadata) & ~jnp.isnan(metadata.get("dt0", 0.0)).any(), true_fun=lambda metadata: ( metadata.get("dt0", jnp.array(0.0, dtype=float)), - self.metadata.get("dxdt0", f(x0, t0)[0]), + metadata.get("dxdt0", f(x0, t0)[0]), ), false_fun=lambda aux: estimate_step_size( x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol ), - operand=self.metadata, + operand=metadata, ) # Clip the estimated initial step size to the given bounds, if necessary. - self.metadata["dt0"] = jnp.clip( - self.metadata["dt0"], - jnp.minimum(self.dt_min, self.metadata["dt0"]), - jnp.minimum(self.dt_max, self.metadata["dt0"]), + metadata["dt0"] = jnp.clip( + metadata["dt0"], + jnp.minimum(self.dt_min, metadata["dt0"]), + jnp.minimum(self.dt_max, metadata["dt0"]), ) # ========================================================= @@ -364,7 +361,7 @@ def __call__( carry0: Carry = ( x0, jnp.array(t0).astype(float), - self.metadata, + metadata, jnp.array(0, dtype=int), jnp.array(False).astype(bool), ) @@ -392,9 +389,10 @@ def while_loop_body(carry: Carry) -> Carry: # Run the underlying explicit RK integrator. # The output z contains multiple solutions (depending on the rows of b.T). with self.editable(validate=True) as integrator: - integrator.metadata = metadata - z, _ = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs) - metadata_next = integrator.metadata + z, aux_dict = integrator._compute_next_state( + x0=x0, t0=t0, dt=Δt0, **kwargs + ) + metadata_next = aux_dict["metadata"] # Extract the high-order solution xf and the low-order estimate x̂f. xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z) @@ -510,12 +508,7 @@ def reject_step(): init_val=carry0, ) - # Store the parameters. - # They will be returned to the caller in a functional way in the step method. - with self.mutable_context(mutability=Mutability.MUTABLE): - self.metadata = metadata_tf - - return xf, {} + return xf, {"metadata": metadata_tf} @property def order_of_solution(self) -> int: