Importing jaxsim disables the JAX cache #322
Thanks for opening the issue @paLeziart! Due to the end-of-year holidays, replies may be a bit slow, but we will definitely get back to you in the new year.
Thank you for your quick answer @traversaro! After looking into it a bit more, it seems that a temporary solution is simply to import jaxsim after setting the local JAX config.
With this ordering, the minimal example produces cache files as expected, and the "cache is disabled/not initialized" debug message no longer appears. It also produces far more compiled files than what I listed in the issue description. In the end, I'm not sure this really qualifies as a jaxsim issue. Simply warning users to set the cache options before importing jaxsim could be a solution.
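A minimal sketch of the workaround described above: configure JAX's persistent compilation cache first, and only then import jaxsim. The cache directory and the two threshold values are assumptions (they are one way to "fully enable" the cache, not settings taken from this thread), and the jaxsim import is skipped if the package is not installed so the snippet still runs on its own.

```python
import importlib.util
import os
import tempfile

import jax
import jax.numpy as jnp

# Assumed cache location; the report in this thread uses /tmp/jax_cache.
CACHE_DIR = os.path.join(tempfile.gettempdir(), "jax_cache")

# 1. Configure the persistent compilation cache first, before loading any
#    library that might create JAX arrays at import time.
os.makedirs(CACHE_DIR, exist_ok=True)
jax.config.update("jax_compilation_cache_dir", CACHE_DIR)
# Lower the default thresholds so even small, fast-to-compile programs
# are written to the cache (assumed choice for this example).
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)

# 2. Only now import jaxsim (skipped here if it is not installed).
if importlib.util.find_spec("jaxsim") is not None:
    import jaxsim  # noqa: F401


@jax.jit
def f(x):
    # Trivial jitted function whose compilation should hit the cache setup above.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


result = f(jnp.arange(4.0))
```

The ordering is the point: per this thread, once something has been set up before the cache options are applied, later changes to the cache configuration do not seem to take effect in that process.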
Hi @paLeziart, thanks a lot for bringing this up! I was able to reproduce the issue and started a PR to solve it at #329. Apparently, it was related to the fact that some JAX arrays were allocated when JaxSim was imported. This probably made the cache invalid or impossible to initialize.
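As an illustration of the kind of change involved (a generic pattern, not necessarily what #329 does), a library can keep module-level constants as host-side NumPy arrays so that importing it does not create JAX device arrays. The constant `STANDARD_GRAVITY` and the function `integrate_velocity` are hypothetical names used only for this sketch.

```python
import numpy as np

import jax
import jax.numpy as jnp

# Hypothetical module-level constant stored as a plain NumPy array:
# defining it does not touch any JAX backend at import time.
STANDARD_GRAVITY = np.array([0.0, 0.0, -9.81], dtype=np.float32)

# By contrast, this would place an array on the default device (and so
# initialize a JAX backend) as soon as the module is imported:
# STANDARD_GRAVITY = jnp.array([0.0, 0.0, -9.81])


@jax.jit
def integrate_velocity(velocity, dt):
    # The NumPy constant is converted and embedded as a constant when the
    # function is first traced, i.e. after the user has configured JAX.
    return velocity + dt * jnp.asarray(STANDARD_GRAVITY)


new_velocity = integrate_velocity(jnp.zeros(3), 0.5)
```

The design choice is simply to defer all JAX work from import time to call time, which leaves users free to configure JAX after importing the library.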
Hello,
First, thank you for providing us with this nice simulator! 👍
As I often encounter long compilation times for some of my jitted functions, I usually fully enable the JAX caching system. My compilation times did not seem to be reduced much while working with jaxsim, so I decided to enable JAX debug mode to look into it. It looks like something done when importing the jaxsim module disables the cache system.
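For readers who want to look at the same kind of debug output, here is a minimal sketch that turns on DEBUG logging for JAX's compilation-cache code using only the standard `logging` module (JAX's `jax_debug_log_modules` option serves a similar purpose). The two logger names are an assumption about which JAX 0.4.x modules emit the cache messages.

```python
import logging

import jax
import jax.numpy as jnp

# Print records as "logger name: message" through a root stderr handler.
logging.basicConfig(format="%(name)s: %(message)s")

# Assumed to be the JAX modules that log cache initialization, hits/misses,
# and "cache is disabled"-style notices.
CACHE_LOGGERS = ("jax._src.compilation_cache", "jax._src.compiler")
for name in CACHE_LOGGERS:
    # Child loggers at DEBUG propagate to the root handler, whose own level
    # is NOTSET, so their debug records are printed.
    logging.getLogger(name).setLevel(logging.DEBUG)


@jax.jit
def step(x):
    return jnp.tanh(x) @ x.T


# The first call compiles `step`, so any cache-related messages show up here.
out = step(jnp.ones((4, 4)))
```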
Minimal reproducible example using jax 0.4.38 and jaxsim 0.5.0 with Python 3.10.12:
1. With import jaxsim
Running this code outputs the following:
Content of /tmp/jax_cache is just xla_gpu_kernel_cache_file
2. Without import jaxsim
Running this code outputs the following:
Content of /tmp/jax_cache is:

We can see that when jaxsim is imported, several initialization lines are gone from the debug output at the beginning, and the cache is disabled at the end.
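Because the cache is set up once per process, comparing the two cases requires a fresh interpreter for each run. The following is a hypothetical driver (not the original reproduction script, which is not shown here) that runs the same small JAX program in a subprocess, optionally executing a `prelude` such as `import jaxsim` before the cache is configured, and returns the files left in a temporary cache directory. The jaxsim run is skipped if the package is not installed.

```python
import importlib.util
import os
import subprocess
import sys
import tempfile
import textwrap

# Small program run in a fresh interpreter; {prelude} executes before the
# cache options are set, mimicking "import jaxsim" at the top of a script.
PROGRAM = textwrap.dedent(
    """
    {prelude}
    import jax
    import jax.numpy as jnp

    jax.config.update("jax_compilation_cache_dir", {cache_dir!r})
    jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)
    jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)

    jax.jit(lambda x: jnp.sin(x) * 2.0)(jnp.arange(8.0)).block_until_ready()
    """
)


def cache_files_after_run(prelude: str) -> list[str]:
    """Run PROGRAM in a new process and list the cache directory afterwards."""
    with tempfile.TemporaryDirectory() as cache_dir:
        code = PROGRAM.format(prelude=prelude, cache_dir=cache_dir)
        subprocess.run([sys.executable, "-c", code], check=True)
        return sorted(os.listdir(cache_dir))


if __name__ == "__main__":
    print("without jaxsim:", cache_files_after_run(""))
    if importlib.util.find_spec("jaxsim") is not None:
        print("jaxsim imported first:", cache_files_after_run("import jaxsim"))
```

Which files appear depends on the backend and JAX version, so the sketch only reports the listings rather than asserting a particular result.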
Thank you kindly for your help,
Best,