Skip to content

Commit

Permalink
Add test of jit compiling functions taking JaxSimModel as input
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Mar 12, 2024
1 parent b072a1c commit d394a5b
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions tests/test_pytree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import io
from contextlib import redirect_stdout

import jax
import jax.numpy as jnp
import rod.builder.primitives
import rod.urdf.exporter

import jaxsim.api as js


# https://github.com/ami-iit/jaxsim/issues/103
def test_call_jit_compiled_function_passing_different_objects():

# Create on-the-fly a ROD model of a box.
rod_model = (
rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name="box")
.build_model()
.add_link()
.add_inertial()
.add_visual()
.add_collision()
.build()
)

# Export the URDF string.
urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string(
sdf=rod_model, pretty=True
)

model1 = js.model.JaxSimModel.build_from_model_description(
model_description=urdf_string,
gravity=jnp.array([0, 0, -10]),
is_urdf=True,
)

model2 = js.model.JaxSimModel.build_from_model_description(
model_description=urdf_string,
gravity=jnp.array([0, 0, -10]),
is_urdf=True,
)

assert model1 == model2
assert hash(model1) == hash(model2)

# If this function has never been compiled by any other test, JAX will
# jit-compile it here.
_ = js.contact.estimate_good_soft_contacts_parameters(model=model1)

# Now JAX should not compile it again.
with jax.log_compiles():
with io.StringIO() as buf, redirect_stdout(buf):
# Beyond running without any JIT recompilations, the following function
# should work on different JaxSimModel objects without raising any errors
# related to the comparison of Static fields.
_ = js.contact.estimate_good_soft_contacts_parameters(model=model2)
stdout = buf.getvalue()

assert (
f"Compiling {js.contact.estimate_good_soft_contacts_parameters.__name__}"
not in stdout
)

0 comments on commit d394a5b

Please sign in to comment.