diff --git a/src/jaxsim/exceptions.py b/src/jaxsim/exceptions.py index 28f644a47..587650c0a 100644 --- a/src/jaxsim/exceptions.py +++ b/src/jaxsim/exceptions.py @@ -21,9 +21,10 @@ def raise_if( **kwargs: The keyword arguments to fill the format string """ - # Disable host callback if running on TPU. - if jax.devices()[0].platform == "tpu" or os.environ.get( - "JAXSIM_DISABLE_EXCEPTIONS", 0 + # Disable host callback if running on unsupported hardware or if the user + # explicitly disabled it. + if jax.devices()[0].platform in {"tpu", "METAL"} or os.environ.get( + "JAXSIM_DISABLE_EXCEPTIONS", "0" ): return