From 063db8ffe9d116c51994fbfc8b42672468fe2d3c Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 5 Dec 2024 15:54:46 +0100 Subject: [PATCH 1/6] Add benchmark tests for model dynamics and kinematics functions --- tests/test_benchmark.py | 141 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 tests/test_benchmark.py diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py new file mode 100644 index 000000000..32cb8b0e7 --- /dev/null +++ b/tests/test_benchmark.py @@ -0,0 +1,141 @@ +import jax +import pytest + +import jaxsim +import jaxsim.api as js + + +def vectorize_data(model: js.model.JaxSimModel, batch_size: str): + key = jax.random.PRNGKey(seed=0) + + return jax.vmap( + lambda key: js.data.random_model_data( + model=model, + key=key, + ) + )(jax.numpy.repeat(key[None, :], repeats=batch_size, axis=0)) + + +def benchmark_test_function(func, model, benchmark, batch_size=None): + """Reusability wrapper for benchmark tests.""" + if batch_size is None: + # Phase 1: Run without batch size + data = js.data.random_model_data(model=model) + + # Warm-up call to avoid including compilation time + func(model, data) + benchmark(func, model, data) + else: + # Phase 2: Run with batch size + data = vectorize_data(model=model, batch_size=batch_size) + + # Warm-up call to avoid including compilation time + jax.vmap(func, in_axes=(None, 0))(model, data) + benchmark(jax.vmap(func, in_axes=(None, 0)), model, data) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("batch_size", [None, 1024]) +def test_forward_dynamics_aba( + jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size +): + model = jaxsim_model_ergocub_reduced + + benchmark_test_function(js.model.forward_dynamics_aba, model, benchmark, batch_size) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("batch_size", [None, 1024]) +def test_free_floating_bias_forces( + jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size +): + model = jaxsim_model_ergocub_reduced + + benchmark_test_function( + js.model.free_floating_bias_forces, model, benchmark, batch_size + ) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("batch_size", [None, 1024]) +def test_forward_kinematics( + jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size +): + model = jaxsim_model_ergocub_reduced + + benchmark_test_function(js.model.forward_kinematics, model, benchmark, batch_size) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("batch_size", [None, 1024]) +def test_free_floating_mass_matrix( + jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size +): + model = jaxsim_model_ergocub_reduced + + benchmark_test_function( + js.model.free_floating_mass_matrix, model, benchmark, batch_size + ) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("batch_size", [None, 1024]) +def test_free_floating_jacobian( + jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size +): + model = jaxsim_model_ergocub_reduced + + benchmark_test_function( + js.model.generalized_free_floating_jacobian, model, benchmark, batch_size + ) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("batch_size", [None, 1024]) +def test_free_floating_jacobian_derivative( + jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size +): + model = jaxsim_model_ergocub_reduced + + benchmark_test_function( + js.model.generalized_free_floating_jacobian_derivative, + model, + benchmark, + batch_size, + ) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("batch_size", [None, 1024]) +def test_soft_contact_model( + jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size +): + model = jaxsim_model_ergocub_reduced + + benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("batch_size", [None, 1024]) +def test_rigid_contact_model( + jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size +): + model = jaxsim_model_ergocub_reduced + + with model.editable(validate=False) as model: + model.contact_model = jaxsim.rbda.contacts.RigidContacts() + + benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("batch_size", [None, 1024]) +def test_relaxed_rigid_contact_model( + jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size +): + model = jaxsim_model_ergocub_reduced + + with model.editable(validate=False) as model: + model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts() + + benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size) From 750a34c60b5c4183cf7f29a338196dae529bcdde Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 5 Dec 2024 16:02:54 +0100 Subject: [PATCH 2/6] Add `pytest-benchmark` to testing dependencies --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index db6f8ea6a..2d1367caf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ style = [ testing = [ "idyntree >= 12.2.1", "pytest >=6.0", + "pytest-benchmark", "pytest-icdiff", "robot-descriptions" ] @@ -241,6 +242,7 @@ idyntree = "*" isort = "*" pre-commit = "*" pytest = "*" +pytest-benchmark = "*" pytest-icdiff = "*" robot_descriptions = "*" From 4bdaab98e8ca708995b05226a191747f392bb389 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 5 Dec 2024 16:03:10 +0100 Subject: [PATCH 3/6] Update pytest configuration to skip benchmark tests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2d1367caf..cfae9020d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,7 +122,7 @@ multi_line_output = 3 profile = "black" [tool.pytest.ini_options] -addopts = "-rsxX -v --strict-markers" +addopts = "-rsxX -v --strict-markers --benchmark-skip" minversion = "6.0" testpaths = [ "tests", From 10e16d2c8aadf0e6f0a19abf5a7945046d59cfd2 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 5 Dec 2024 16:03:36 +0100 Subject: [PATCH 4/6] Refactor test `pixi` tasks to include separate benchmark command --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cfae9020d..5fc968178 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -234,7 +234,8 @@ jaxsim = { path = "./", editable = true } [tool.pixi.feature.test.tasks] pipcheck = "pip check" -tests = { cmd = "pytest", depends_on = ["pipcheck"] } +test = { cmd = "pytest", depends_on = ["pipcheck"] } +benchmark = { cmd = "pytest --benchmark-only", depends_on = ["pipcheck"] } [tool.pixi.feature.test.dependencies] black = "24.*" From 64b35b8379d6821da6fdcc0d09e8f8db3244e499 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 5 Dec 2024 16:09:31 +0100 Subject: [PATCH 5/6] Add `batch_size` option for vectorized benchmarks --- tests/conftest.py | 14 ++++++++++++++ tests/test_benchmark.py | 34 ++++++++++------------------------ 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 746fdbf7b..37f5c26e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,20 @@ def pytest_addoption(parser): help="Run tests only if GPU is available and utilized", ) + parser.addoption( + "--batch-size", + action="store", + default="1", + help="Batch size for vectorized benchmarks (only applies to benchmark tests)", + ) + + +def pytest_generate_tests(metafunc): + if "batch_size" in metafunc.fixturenames: + metafunc.parametrize( + "batch_size", [1, int(metafunc.config.getoption("--batch-size"))] + ) + def check_gpu_usage(): # Set environment variable to prioritize GPU. diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 32cb8b0e7..089c78b4e 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -1,3 +1,5 @@ +from collections.abc import Callable + import jax import pytest @@ -5,7 +7,7 @@ import jaxsim.api as js -def vectorize_data(model: js.model.JaxSimModel, batch_size: str): +def vectorize_data(model: js.model.JaxSimModel, batch_size: int): key = jax.random.PRNGKey(seed=0) return jax.vmap( @@ -16,26 +18,18 @@ def vectorize_data(model: js.model.JaxSimModel, batch_size: str): )(jax.numpy.repeat(key[None, :], repeats=batch_size, axis=0)) -def benchmark_test_function(func, model, benchmark, batch_size=None): +def benchmark_test_function( + func: Callable, model: js.model.JaxSimModel, benchmark, batch_size +): """Reusability wrapper for benchmark tests.""" - if batch_size is None: - # Phase 1: Run without batch size - data = js.data.random_model_data(model=model) - - # Warm-up call to avoid including compilation time - func(model, data) - benchmark(func, model, data) - else: - # Phase 2: Run with batch size - data = vectorize_data(model=model, batch_size=batch_size) + data = vectorize_data(model=model, batch_size=batch_size) - # Warm-up call to avoid including compilation time - jax.vmap(func, in_axes=(None, 0))(model, data) - benchmark(jax.vmap(func, in_axes=(None, 0)), model, data) + # Warm-up call to avoid including compilation time + jax.vmap(func, in_axes=(None, 0))(model, data) + benchmark(jax.vmap(func, in_axes=(None, 0)), model, data) @pytest.mark.benchmark -@pytest.mark.parametrize("batch_size", [None, 1024]) def test_forward_dynamics_aba( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): @@ -45,7 +39,6 @@ def test_forward_dynamics_aba( @pytest.mark.benchmark -@pytest.mark.parametrize("batch_size", [None, 1024]) def test_free_floating_bias_forces( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): @@ -57,7 +50,6 @@ def test_free_floating_bias_forces( @pytest.mark.benchmark -@pytest.mark.parametrize("batch_size", [None, 1024]) def test_forward_kinematics( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): @@ -67,7 +59,6 @@ def test_forward_kinematics( @pytest.mark.benchmark -@pytest.mark.parametrize("batch_size", [None, 1024]) def test_free_floating_mass_matrix( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): @@ -79,7 +70,6 @@ def test_free_floating_mass_matrix( @pytest.mark.benchmark -@pytest.mark.parametrize("batch_size", [None, 1024]) def test_free_floating_jacobian( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): @@ -91,7 +81,6 @@ def test_free_floating_jacobian( @pytest.mark.benchmark -@pytest.mark.parametrize("batch_size", [None, 1024]) def test_free_floating_jacobian_derivative( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): @@ -106,7 +95,6 @@ def test_free_floating_jacobian_derivative( @pytest.mark.benchmark -@pytest.mark.parametrize("batch_size", [None, 1024]) def test_soft_contact_model( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): @@ -116,7 +104,6 @@ def test_soft_contact_model( @pytest.mark.benchmark -@pytest.mark.parametrize("batch_size", [None, 1024]) def test_rigid_contact_model( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): @@ -129,7 +116,6 @@ def test_rigid_contact_model( @pytest.mark.benchmark -@pytest.mark.parametrize("batch_size", [None, 1024]) def test_relaxed_rigid_contact_model( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): From bb45c2cb017ec19817abfdeed2220e9d5f881a26 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 6 Dec 2024 11:05:11 +0100 Subject: [PATCH 6/6] Update batch size option handling in pytest configuration --- tests/conftest.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 37f5c26e7..747e00a95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,16 +23,17 @@ def pytest_addoption(parser): parser.addoption( "--batch-size", action="store", - default="1", + default="None", help="Batch size for vectorized benchmarks (only applies to benchmark tests)", ) def pytest_generate_tests(metafunc): - if "batch_size" in metafunc.fixturenames: - metafunc.parametrize( - "batch_size", [1, int(metafunc.config.getoption("--batch-size"))] - ) + if ( + "batch_size" in metafunc.fixturenames + and (batch_size := metafunc.config.getoption("--batch-size")) != "None" + ): + metafunc.parametrize("batch_size", [1, int(batch_size)]) def check_gpu_usage(): @@ -123,6 +124,18 @@ def velocity_representation(request) -> jaxsim.VelRepr: return request.param +@pytest.fixture(scope="session") +def batch_size(request) -> int: + """ + Fixture providing the batch size for vectorized benchmarks. + + Returns: + The batch size for vectorized benchmarks. + """ + + return 1 + + # ================================ # Fixtures providing JaxSim models # ================================