-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #308 from ami-iit/feature/benchmark
Add benchmark tests for model dynamics and kinematics functions
- Loading branch information
Showing
3 changed files
with
159 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from collections.abc import Callable | ||
|
||
import jax | ||
import pytest | ||
|
||
import jaxsim | ||
import jaxsim.api as js | ||
|
||
|
||
def vectorize_data(model: js.model.JaxSimModel, batch_size: int): | ||
key = jax.random.PRNGKey(seed=0) | ||
|
||
return jax.vmap( | ||
lambda key: js.data.random_model_data( | ||
model=model, | ||
key=key, | ||
) | ||
)(jax.numpy.repeat(key[None, :], repeats=batch_size, axis=0)) | ||
|
||
|
||
def benchmark_test_function( | ||
func: Callable, model: js.model.JaxSimModel, benchmark, batch_size | ||
): | ||
"""Reusability wrapper for benchmark tests.""" | ||
data = vectorize_data(model=model, batch_size=batch_size) | ||
|
||
# Warm-up call to avoid including compilation time | ||
jax.vmap(func, in_axes=(None, 0))(model, data) | ||
benchmark(jax.vmap(func, in_axes=(None, 0)), model, data) | ||
|
||
|
||
@pytest.mark.benchmark | ||
def test_forward_dynamics_aba( | ||
jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size | ||
): | ||
model = jaxsim_model_ergocub_reduced | ||
|
||
benchmark_test_function(js.model.forward_dynamics_aba, model, benchmark, batch_size) | ||
|
||
|
||
@pytest.mark.benchmark | ||
def test_free_floating_bias_forces( | ||
jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size | ||
): | ||
model = jaxsim_model_ergocub_reduced | ||
|
||
benchmark_test_function( | ||
js.model.free_floating_bias_forces, model, benchmark, batch_size | ||
) | ||
|
||
|
||
@pytest.mark.benchmark | ||
def test_forward_kinematics( | ||
jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size | ||
): | ||
model = jaxsim_model_ergocub_reduced | ||
|
||
benchmark_test_function(js.model.forward_kinematics, model, benchmark, batch_size) | ||
|
||
|
||
@pytest.mark.benchmark | ||
def test_free_floating_mass_matrix( | ||
jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size | ||
): | ||
model = jaxsim_model_ergocub_reduced | ||
|
||
benchmark_test_function( | ||
js.model.free_floating_mass_matrix, model, benchmark, batch_size | ||
) | ||
|
||
|
||
@pytest.mark.benchmark | ||
def test_free_floating_jacobian( | ||
jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size | ||
): | ||
model = jaxsim_model_ergocub_reduced | ||
|
||
benchmark_test_function( | ||
js.model.generalized_free_floating_jacobian, model, benchmark, batch_size | ||
) | ||
|
||
|
||
@pytest.mark.benchmark | ||
def test_free_floating_jacobian_derivative( | ||
jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size | ||
): | ||
model = jaxsim_model_ergocub_reduced | ||
|
||
benchmark_test_function( | ||
js.model.generalized_free_floating_jacobian_derivative, | ||
model, | ||
benchmark, | ||
batch_size, | ||
) | ||
|
||
|
||
@pytest.mark.benchmark | ||
def test_soft_contact_model( | ||
jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size | ||
): | ||
model = jaxsim_model_ergocub_reduced | ||
|
||
benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size) | ||
|
||
|
||
@pytest.mark.benchmark | ||
def test_rigid_contact_model( | ||
jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size | ||
): | ||
model = jaxsim_model_ergocub_reduced | ||
|
||
with model.editable(validate=False) as model: | ||
model.contact_model = jaxsim.rbda.contacts.RigidContacts() | ||
|
||
benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size) | ||
|
||
|
||
@pytest.mark.benchmark | ||
def test_relaxed_rigid_contact_model( | ||
jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size | ||
): | ||
model = jaxsim_model_ergocub_reduced | ||
|
||
with model.editable(validate=False) as model: | ||
model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts() | ||
|
||
benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size) |