Skip to content

Commit

Permalink
Refactor metadata handling in variable step integrators
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 9, 2024
1 parent 0ba4c7b commit 8907aaf
Showing 1 changed file with 25 additions and 32 deletions.
57 changes: 25 additions & 32 deletions src/jaxsim/integrators/variable_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import jaxsim.utils.tracing
from jaxsim import typing as jtp
from jaxsim.utils import Mutability

from .common import (
ExplicitRungeKutta,
Expand Down Expand Up @@ -271,30 +270,27 @@ def init(
# Inject this key to signal that the integrator is initializing.
# This is used to allocate the arrays of the metadata dictionary,
# that are then filled with NaNs.
integrator.metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)}
metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)}

# Run a dummy call of the integrator.
# It is used only to get the metadata so that we know the structure
# of the corresponding pytree.
_ = integrator(
x0, jnp.array(t0, dtype=float), jnp.array(dt, dtype=float), **kwargs
x0,
jnp.array(t0, dtype=float),
jnp.array(dt, dtype=float),
**(kwargs | {"metadata": metadata}),
)

# Remove the injected key.
_ = integrator.metadata.pop(EmbeddedRungeKutta.InitializingKey)
_ = metadata.pop(EmbeddedRungeKutta.InitializingKey)

# Make sure that all leafs of the dictionary are JAX arrays.
# Also, since these are dummy parameters, set them all to NaN.
metadata_after_init = jax.tree.map(
lambda l: jnp.nan * jnp.zeros_like(l), integrator.metadata
lambda l: jnp.nan * jnp.zeros_like(l), metadata
)

# Store the zero parameters in the integrator.
# When the integrator is stepped, this is used to check if the passed
# parameters are valid.
with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
self.metadata = metadata_after_init

return metadata_after_init

def __call__(
Expand All @@ -307,7 +303,7 @@ def __call__(
# The metadata is a dictionary of float JAX arrays, that are initialized
# with the right shape and filled with NaNs.
# 2. During the first step, this method operates on the Nan-filled
# `self.metadata` attribute, and it populates with the actual metadata.
# `metadata` argument, and it populates with the actual metadata.
# 3. After the first step, this method operates on the actual metadata.
#
# In particular, we store the following information in the metadata:
Expand All @@ -318,8 +314,10 @@ def __call__(
# evaluate the dynamics at the final state of the previous step, that matches
# the initial state of the current step.
#
metadata = kwargs.pop("metadata", {})

integrator_init = jnp.array(
self.metadata.get(self.InitializingKey, False), dtype=bool
metadata.get(self.InitializingKey, False), dtype=bool
)

# Close f over optional kwargs.
Expand All @@ -335,24 +333,23 @@ def __call__(

# The value of dt0 is NaN (or, at least, it should be) only after initialization
# and before the first step.
self.metadata["dt0"], self.metadata["dxdt0"] = jax.lax.cond(
pred=("dt0" in self.metadata)
& ~jnp.isnan(self.metadata.get("dt0", 0.0)).any(),
metadata["dt0"], metadata["dxdt0"] = jax.lax.cond(
pred=("dt0" in metadata) & ~jnp.isnan(metadata.get("dt0", 0.0)).any(),
true_fun=lambda metadata: (
metadata.get("dt0", jnp.array(0.0, dtype=float)),
self.metadata.get("dxdt0", f(x0, t0)[0]),
metadata.get("dxdt0", f(x0, t0)[0]),
),
false_fun=lambda aux: estimate_step_size(
x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
),
operand=self.metadata,
operand=metadata,
)

# Clip the estimated initial step size to the given bounds, if necessary.
self.metadata["dt0"] = jnp.clip(
self.metadata["dt0"],
jnp.minimum(self.dt_min, self.metadata["dt0"]),
jnp.minimum(self.dt_max, self.metadata["dt0"]),
metadata["dt0"] = jnp.clip(
metadata["dt0"],
jnp.minimum(self.dt_min, metadata["dt0"]),
jnp.minimum(self.dt_max, metadata["dt0"]),
)

# =========================================================
Expand All @@ -364,7 +361,7 @@ def __call__(
carry0: Carry = (
x0,
jnp.array(t0).astype(float),
self.metadata,
metadata,
jnp.array(0, dtype=int),
jnp.array(False).astype(bool),
)
Expand Down Expand Up @@ -392,9 +389,10 @@ def while_loop_body(carry: Carry) -> Carry:
# Run the underlying explicit RK integrator.
# The output z contains multiple solutions (depending on the rows of b.T).
with self.editable(validate=True) as integrator:
integrator.metadata = metadata
z, _ = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
metadata_next = integrator.metadata
z, aux_dict = integrator._compute_next_state(
x0=x0, t0=t0, dt=Δt0, **kwargs
)
metadata_next = aux_dict["metadata"]

# Extract the high-order solution xf and the low-order estimate x̂f.
xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
Expand Down Expand Up @@ -510,12 +508,7 @@ def reject_step():
init_val=carry0,
)

# Store the parameters.
# They will be returned to the caller in a functional way in the step method.
with self.mutable_context(mutability=Mutability.MUTABLE):
self.metadata = metadata_tf

return xf, {}
return xf, {"metadata": metadata_tf}

@property
def order_of_solution(self) -> int:
Expand Down

0 comments on commit 8907aaf

Please sign in to comment.