Skip to content

Commit

Permalink
Merge pull request #323 from ami-iit/parametrize_integrators_tests
Browse files Browse the repository at this point in the history
Remove `metadata` attribute from integrators and parametrize simulation tests
  • Loading branch information
flferretti authored Jan 3, 2025
2 parents ccc1038 + 7cdf0f4 commit 1ee249c
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 47 deletions.
19 changes: 8 additions & 11 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
repr=False, hash=False, compare=False, kw_only=True
)

metadata: dict[str, Any] = dataclasses.field(
default_factory=dict, repr=False, hash=False, compare=False, kw_only=True
)

@classmethod
def build(
cls: type[Self],
Expand Down Expand Up @@ -102,10 +98,7 @@ def step(

metadata = metadata if metadata is not None else {}

with self.editable(validate=False) as integrator:
integrator.metadata = metadata

with integrator.mutable_context(mutability=Mutability.MUTABLE):
with self.mutable_context(mutability=Mutability.MUTABLE) as integrator:
xf, metadata_step = integrator(x0, t0, dt, **kwargs)

return (
Expand Down Expand Up @@ -315,6 +308,9 @@ def _compute_next_state(
b = self.b
A = self.A

# Extract metadata from the kwargs.
metadata = kwargs.pop("metadata", {})

# Close f over optional kwargs.
f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)

Expand All @@ -327,7 +323,7 @@ def _compute_next_state(
# or to use the previous state derivative (only integrators supporting FSAL).
def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]:
ẋ0, aux_dict = f(x0, t0)
return self.metadata.get("dxdt0", ẋ0), aux_dict
return metadata.get("dxdt0", ẋ0), aux_dict

# We use a `jax.lax.scan` to compile the `f` function only once.
# Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
Expand Down Expand Up @@ -381,7 +377,8 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]:

# Update the FSAL property for the next iteration.
if self.has_fsal:
self.metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)
# Store the first derivative of the next step in the metadata.
metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)

# Compute the output state.
# Note that z contains as many new states as the rows of `b.T`.
Expand All @@ -394,7 +391,7 @@ def compute_ki() -> tuple[jax.Array, dict[str, Any]]:
lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
)(z)

return z_transformed, aux_dict
return z_transformed, aux_dict | {"metadata": metadata}

@staticmethod
def butcher_tableau_is_valid(
Expand Down
65 changes: 29 additions & 36 deletions src/jaxsim/integrators/variable_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import jaxsim.utils.tracing
from jaxsim import typing as jtp
from jaxsim.utils import Mutability

from .common import (
ExplicitRungeKutta,
Expand Down Expand Up @@ -271,30 +270,27 @@ def init(
# Inject this key to signal that the integrator is initializing.
# This is used to allocate the arrays of the metadata dictionary,
# that are then filled with NaNs.
integrator.metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)}
metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)}

# Run a dummy call of the integrator.
# It is used only to get the metadata so that we know the structure
# of the corresponding pytree.
_ = integrator(
x0, jnp.array(t0, dtype=float), jnp.array(dt, dtype=float), **kwargs
x0,
jnp.array(t0, dtype=float),
jnp.array(dt, dtype=float),
**(kwargs | {"metadata": metadata}),
)

# Remove the injected key.
_ = integrator.metadata.pop(EmbeddedRungeKutta.InitializingKey)
_ = metadata.pop(EmbeddedRungeKutta.InitializingKey)

# Make sure that all leafs of the dictionary are JAX arrays.
# Also, since these are dummy parameters, set them all to NaN.
metadata_after_init = jax.tree.map(
lambda l: jnp.nan * jnp.zeros_like(l), integrator.metadata
lambda l: jnp.nan * jnp.zeros_like(l), metadata
)

# Store the zero parameters in the integrator.
# When the integrator is stepped, this is used to check if the passed
# parameters are valid.
with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
self.metadata = metadata_after_init

return metadata_after_init

def __call__(
Expand All @@ -307,7 +303,7 @@ def __call__(
# The metadata is a dictionary of float JAX arrays, that are initialized
# with the right shape and filled with NaNs.
# 2. During the first step, this method operates on the Nan-filled
# `self.metadata` attribute, and it populates with the actual metadata.
# `metadata` argument, and it populates with the actual metadata.
# 3. After the first step, this method operates on the actual metadata.
#
# In particular, we store the following information in the metadata:
Expand All @@ -318,8 +314,10 @@ def __call__(
# evaluate the dynamics at the final state of the previous step, that matches
# the initial state of the current step.
#
metadata = kwargs.pop("metadata", {})

integrator_init = jnp.array(
self.metadata.get(self.InitializingKey, False), dtype=bool
metadata.get(self.InitializingKey, False), dtype=bool
)

# Close f over optional kwargs.
Expand All @@ -335,24 +333,23 @@ def __call__(

# The value of dt0 is NaN (or, at least, it should be) only after initialization
# and before the first step.
self.metadata["dt0"], self.metadata["dxdt0"] = jax.lax.cond(
pred=("dt0" in self.metadata)
& ~jnp.isnan(self.metadata.get("dt0", 0.0)).any(),
metadata["dt0"], metadata["dxdt0"] = jax.lax.cond(
pred=("dt0" in metadata) & ~jnp.isnan(metadata.get("dt0", 0.0)).any(),
true_fun=lambda metadata: (
metadata.get("dt0", jnp.array(0.0, dtype=float)),
self.metadata.get("dxdt0", f(x0, t0)[0]),
metadata.get("dxdt0", f(x0, t0)[0]),
),
false_fun=lambda aux: estimate_step_size(
x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
),
operand=self.metadata,
operand=metadata,
)

# Clip the estimated initial step size to the given bounds, if necessary.
self.metadata["dt0"] = jnp.clip(
self.metadata["dt0"],
jnp.minimum(self.dt_min, self.metadata["dt0"]),
jnp.minimum(self.dt_max, self.metadata["dt0"]),
metadata["dt0"] = jnp.clip(
metadata["dt0"],
jnp.minimum(self.dt_min, metadata["dt0"]),
jnp.minimum(self.dt_max, metadata["dt0"]),
)

# =========================================================
Expand All @@ -364,7 +361,7 @@ def __call__(
carry0: Carry = (
x0,
jnp.array(t0).astype(float),
self.metadata,
metadata,
jnp.array(0, dtype=int),
jnp.array(False).astype(bool),
)
Expand Down Expand Up @@ -392,9 +389,10 @@ def while_loop_body(carry: Carry) -> Carry:
# Run the underlying explicit RK integrator.
# The output z contains multiple solutions (depending on the rows of b.T).
with self.editable(validate=True) as integrator:
integrator.metadata = metadata
z, _ = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
metadata_next = integrator.metadata
z, aux_dict = integrator._compute_next_state(
x0=x0, t0=t0, dt=Δt0, **kwargs
)
metadata_next = aux_dict["metadata"]

# Extract the high-order solution xf and the low-order estimate x̂f.
xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
Expand Down Expand Up @@ -481,10 +479,10 @@ def reject_step():
metadata_next,
discarded_steps,
) = jax.lax.cond(
pred=discarded_steps
>= self.max_step_rejections | local_error
<= 1.0 | Δt_next
< self.dt_min | integrator_init,
pred=(discarded_steps >= self.max_step_rejections)
| (local_error <= 1.0)
| (Δt_next < self.dt_min)
| integrator_init,
true_fun=accept_step,
false_fun=reject_step,
)
Expand All @@ -510,12 +508,7 @@ def reject_step():
init_val=carry0,
)

# Store the parameters.
# They will be returned to the caller in a functional way in the step method.
with self.mutable_context(mutability=Mutability.MUTABLE):
self.metadata = metadata_tf

return xf, {}
return xf, {"metadata": metadata_tf}

@property
def order_of_solution(self) -> int:
Expand Down
16 changes: 16 additions & 0 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,19 @@ def run_simulation(
return data


@pytest.mark.parametrize(
"integrator",
[
jaxsim.integrators.fixed_step.ForwardEuler,
jaxsim.integrators.fixed_step.ForwardEulerSO3,
jaxsim.integrators.fixed_step.RungeKutta4,
jaxsim.integrators.fixed_step.RungeKutta4SO3,
jaxsim.integrators.variable_step.BogackiShampineSO3,
],
)
def test_simulation_with_soft_contacts(
jaxsim_model_box: js.model.JaxSimModel,
integrator: jaxsim.integrators.Integrator,
):

model = jaxsim_model_box
Expand All @@ -218,6 +229,11 @@ def test_simulation_with_soft_contacts(
model.kin_dyn_parameters.contact_parameters.enabled = tuple(
enabled_collidable_points_mask.tolist()
)
model.integrator = integrator.build(
dynamics=js.ode.wrap_system_dynamics_for_integration(
system_dynamics=js.ode.system_dynamics
)
)

assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4

Expand Down

0 comments on commit 1ee249c

Please sign in to comment.