Import jaxsim disable jax cache #322

Open
paLeziart opened this issue Dec 30, 2024 · 3 comments · May be fixed by #329

@paLeziart
paLeziart commented Dec 30, 2024

Hello,

First, thank you for providing us with this nice simulator! 👍

As I often encounter long compilation times for some of my jitted functions, I usually fully enable the JAX caching system. My compilation times did not seem to drop much while working with jaxsim, so I enabled JAX debug logging to look into it. It looks like something done when importing the jaxsim module disables the cache system.

Minimal reproducible example using jax 0.4.38 and jaxsim 0.5.0 with Python 3.10.12:

import jax
import jax.numpy as jnp
import jaxsim

jax.config.update("jax_logging_level", "DEBUG")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

@jax.jit
def f(x):
  return x + 1

x = jnp.zeros((2, 2))
f(x)

1. With import jaxsim

Running this code outputs the following:

DEBUG:2024-12-30 18:40:21,511:jax._src.dispatch:182: Finished tracing + transforming broadcast_in_dim for pjit in 0.000211716 sec
DEBUG:2024-12-30 18:40:21,511:jax._src.interpreters.pxla:1906: Compiling broadcast_in_dim with global shapes and types [ShapedArray(float64[])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2024-12-30 18:40:21,513:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.001822472 sec
DEBUG:2024-12-30 18:40:21,513:jax._src.compiler:167: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CudaDevice(id=0)]]
DEBUG:2024-12-30 18:40:21,513:jax._src.compiler:239: get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
DEBUG:2024-12-30 18:40:21,516:jax._src.compiler:260: Enabling XLA kernel cache at '/tmp/jax_cache/xla_gpu_kernel_cache_file'
DEBUG:2024-12-30 18:40:21,516:jax._src.compiler:265: Enabling XLA autotuning cache at '/tmp/jax_cache/xla_gpu_per_fusion_autotune_cache_dir'
[......]
DEBUG:2024-12-30 18:40:21,542:jax._src.compilation_cache:215: get_executable_and_time: cache is disabled/not initialized
DEBUG:2024-12-30 18:40:21,542:jax._src.compiler:108: PERSISTENT COMPILATION CACHE MISS for 'jit_f' with key 'jit_f-365fa6e25e2cd57ba345f1671fa408f2e3a329174163d89b16be41a02faff8ab'
DEBUG:2024-12-30 18:40:21,566:jax._src.compiler:730: 'jit_f' took at least 0.00 seconds to compile (0.02s)
DEBUG:2024-12-30 18:40:21,566:jax._src.compilation_cache:245: Not writing persistent cache entry with key 'jit_f-365fa6e25e2cd57ba345f1671fa408f2e3a329174163d89b16be41a02faff8ab' since cache is disabled/not initialized
DEBUG:2024-12-30 18:40:21,566:jax._src.dispatch:182: Finished XLA compilation of jit(f) in 0.025295734 sec
DEBUG:2024-12-30 18:40:21,568:jax._src.xla_bridge:983: Clearing JAX backend caches.

Content of /tmp/jax_cache is just xla_gpu_kernel_cache_file

2. Without import jaxsim

Running this code outputs the following:

DEBUG:2024-12-30 18:43:54,326:jax._src.dispatch:182: Finished tracing + transforming convert_element_type for pjit in 0.000194788 sec
DEBUG:2024-12-30 18:43:54,327:jax._src.xla_bridge:599: Discovered path based JAX plugin: jax_plugins.xla_cuda12
DEBUG:2024-12-30 18:43:54,334:jax._src.xla_bridge:608: Discovered entry-point based JAX plugin: jax_plugins.xla_cuda12
DEBUG:2024-12-30 18:43:54,334:jax._src.xla_bridge:614: Loading plugin module jax_plugins.xla_cuda12
DEBUG:2024-12-30 18:43:54,334:jax._src.xla_bridge:722: registering PJRT plugin cuda from /home/.../venv/lib/python3.10/site-packages/jax_plugins/xla_cuda12/xla_cuda_plugin.so
DEBUG:2024-12-30 18:43:54,341:jax._src.xla_bridge:1002: Initializing backend 'cpu'
DEBUG:2024-12-30 18:43:54,345:jax._src.xla_bridge:1014: Backend 'cpu' initialized
DEBUG:2024-12-30 18:43:54,345:jax._src.xla_bridge:1002: Initializing backend 'cuda'
2024-12-30 18:43:54.420083: I external/xla/xla/pjrt/pjrt_c_api_client.cc:127] PjRtCApiClient created.
DEBUG:2024-12-30 18:43:54,420:jax._src.xla_bridge:1014: Backend 'cuda' initialized
DEBUG:2024-12-30 18:43:54,420:jax._src.xla_bridge:1002: Initializing backend 'rocm'
INFO:2024-12-30 18:43:54,420:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:2024-12-30 18:43:54,420:jax._src.xla_bridge:1002: Initializing backend 'tpu'
INFO:2024-12-30 18:43:54,420:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
DEBUG:2024-12-30 18:43:54,421:jax._src.interpreters.pxla:1906: Compiling convert_element_type with global shapes and types [ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2024-12-30 18:43:54,427:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.005491734 sec
DEBUG:2024-12-30 18:43:54,427:jax._src.compiler:167: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CudaDevice(id=0)]]
DEBUG:2024-12-30 18:43:54,427:jax._src.compiler:239: get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
DEBUG:2024-12-30 18:43:54,438:jax._src.compiler:260: Enabling XLA kernel cache at '/tmp/jax_cache/xla_gpu_kernel_cache_file'
DEBUG:2024-12-30 18:43:54,438:jax._src.compiler:265: Enabling XLA autotuning cache at '/tmp/jax_cache/xla_gpu_per_fusion_autotune_cache_dir'
[......]
DEBUG:2024-12-30 18:43:54,586:jax._src.compiler:108: PERSISTENT COMPILATION CACHE MISS for 'jit_f' with key 'jit_f-8191a120849534e8c93d1e6c86e01fc980f4d427b90ba0b35c62d024bc516630'
DEBUG:2024-12-30 18:43:54,614:jax._src.compiler:730: 'jit_f' took at least 0.00 seconds to compile (0.03s)
DEBUG:2024-12-30 18:43:54,614:jax._src.compilation_cache:263: Writing jit_f to persistent compilation cache with key 'jit_f-8191a120849534e8c93d1e6c86e01fc980f4d427b90ba0b35c62d024bc516630'
DEBUG:2024-12-30 18:43:54,615:jax._src.dispatch:182: Finished XLA compilation of jit(f) in 0.029417753 sec
DEBUG:2024-12-30 18:43:54,617:jax._src.xla_bridge:983: Clearing JAX backend caches.

Content of /tmp/jax_cache is

jit_broadcast_in_dim-515f5de4d4be16bb303ff58dce419c1583f9fb8b5a3d068463f7a423b529866e-atime      jit_f-8191a120849534e8c93d1e6c86e01fc980f4d427b90ba0b35c62d024bc516630-atime
jit_broadcast_in_dim-515f5de4d4be16bb303ff58dce419c1583f9fb8b5a3d068463f7a423b529866e-cache      jit_f-8191a120849534e8c93d1e6c86e01fc980f4d427b90ba0b35c62d024bc516630-cache
jit_convert_element_type-393c76da8a681a68e60cd10eb0dfe57592ffc87668df5cf3d4cda9dff09ff4c9-atime  xla_gpu_kernel_cache_file
jit_convert_element_type-393c76da8a681a68e60cd10eb0dfe57592ffc87668df5cf3d4cda9dff09ff4c9-cache

We can see that when jaxsim is imported, several initialization lines are gone from the debug output at the beginning, and the cache is disabled at the end.
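
To compare the two runs without reading the whole debug log, the cache directory can be inspected directly. The snippet below is only a minimal sketch: `list_cache_entries` is a hypothetical helper (not part of jax or jaxsim), and it assumes that persistent cache entries are the files whose names end in `-cache`, as in the listing above.

```python
from pathlib import Path


def list_cache_entries(cache_dir):
    """Return the keys of persistent cache entries found in cache_dir.

    Assumption: an entry is a regular file whose name ends in "-cache",
    matching the directory listing shown in this issue.
    """
    root = Path(cache_dir)
    if not root.is_dir():
        return []
    suffix = "-cache"
    return sorted(
        p.name[: -len(suffix)]
        for p in root.iterdir()
        if p.is_file() and p.name.endswith(suffix)
    )


if __name__ == "__main__":
    # Path taken from the example above; prints nothing if no entry was written.
    for key in list_cache_entries("/tmp/jax_cache"):
        print(key)
```

With `import jaxsim` placed first, this should print nothing, since only `xla_gpu_kernel_cache_file` is present in the directory.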

Thank you kindly for your help,
Best,

@traversaro
Contributor

Thanks for opening the issue @paLeziart! Due to the end-of-year holidays, replies may be a bit slow, but we will get back to you for sure in the new year.

fyi @CarlottaSartore @xela-95 @flferretti

@paLeziart
Author

Thank you for your quick answer @traversaro !

After looking into it a bit more, it seems a temporary solution for now is simply to import jaxsim after setting the local jax config.

import jax
import jax.numpy as jnp

jax.config.update("jax_logging_level", "DEBUG")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

import jaxsim

@jax.jit
def f(x):
  return x + 1

x = jnp.zeros((2, 2))
f(x)

In this case, the minimal example produces cache files as expected and there is no "cache is disabled/not initialized" debug message. It also produces many more cache files than those listed in "2. Without import jaxsim". Since some functions are initialized and run once during import jaxsim, I assume that if jaxsim is imported first, JAX detects that the cache options are not set (yet) and considers the cache disabled for the remainder of the script (or it creates some sort of internal issue with the caching system).

In the end I'm not sure this really qualifies as a jaxsim issue. Simply warning users to set the cache options before importing jaxsim could be a solution.
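
As an illustration of what such a warning could look like on the user side, here is a minimal sketch using only the standard library. `warn_if_already_imported` is a hypothetical helper, not something jax or jaxsim provides; it simply checks `sys.modules` to see whether the module was imported before the cache options are set.

```python
import sys
import warnings


def warn_if_already_imported(module_name="jaxsim"):
    """Warn if module_name is already imported; return True in that case.

    Hypothetical helper: meant to be called right before the
    jax.config.update(...) lines that enable the persistent cache.
    """
    if module_name in sys.modules:
        warnings.warn(
            f"{module_name} is already imported; JAX cache options set "
            "after this point may not take effect.",
            RuntimeWarning,
            stacklevel=2,
        )
        return True
    return False
```

Calling `warn_if_already_imported()` just before the `jax.config.update` calls would flag the import order from the first example, while staying silent for the reordered one.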

flferretti self-assigned this Jan 7, 2025
flferretti linked a pull request Jan 7, 2025 that will close this issue
@flferretti
Collaborator

Hi @paLeziart, thanks a lot for bringing this up!

I was able to reproduce the issue and opened PR #329 to fix it.

Apparently, it was related to the fact that some JAX arrays were allocated when JaxSim was imported. This probably made the cache invalid or impossible to initialize.
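
For readers wondering how import-time allocation can be avoided in general, one common pattern is to build constants lazily on first use. The sketch below is an assumption about the general technique, not a description of what #329 does; `default_gravity`, its value, and the `allocations` list are hypothetical, and a plain tuple stands in for a JAX array so the example runs without jax installed.

```python
import functools

# Records when the stand-in "array" is actually built (for illustration only).
allocations = []


@functools.cache
def default_gravity():
    """Build the constant on first call instead of at import time.

    In a real package the body would create a JAX array; a tuple is used
    here as a stand-in so nothing touches the JAX backend.
    """
    allocations.append("gravity")
    return (0.0, 0.0, -9.81)


# Importing this module allocates nothing; the first call does,
# and later calls reuse the cached value.
```

With this layout, importing the module leaves the backend untouched, so user code can still set its cache options before the first allocation happens.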
