
Memory Requirements of PPO Example #123

Closed
milutter opened this issue Nov 23, 2021 · 6 comments

Labels
question Further information is requested

Comments

@milutter

I am trying to run the Brax PPO example locally, but I am experiencing CUDA out-of-memory errors. For simple environments such as reacher, everything works fine; for halfcheetah and ant, I get out-of-memory errors. I presume that the required memory is proportional to the number of environments. However, even with the number of environments set to 1, I still get an out-of-memory error.

RuntimeError: INTERNAL: Failed to launch CUDA kernel: fusion_29 with block dimensions: 128x1x1 and grid dimensions: 1x1x1: CUDA_ERROR_OUT_OF_MEMORY: out of memory

These errors are surprising to me, as I am using an RTX 3090 with 24 GB of memory, the same capacity as the K80 mentioned in #49 that runs ant in the Colab. I am therefore wondering: which component affects GPU memory the most, and is it possible to reduce the GPU memory requirements?

@erikfrey
Collaborator

Hi Michael,

Yes, JAX PPO should work fine on such a card. A few things worth trying:

  1. Please grab brax from main rather than PyPI (pip install git+https://github.com/google/brax.git@main); it contains some fixes specifically for GPU memory usage that have not made it into the PyPI release yet.
  2. I've noticed that frameworks (JAX included) try to grab big contiguous slabs of memory in a way that CUDA cannot satisfy, especially if two things are running on the GPU at once. Can you run nvidia-smi to confirm that nothing else is running on the card? If you see something else occupying memory on the card, you may want to stop that other process.
  3. Alternatively, if neither of the above fixes the problem, you can ask JAX to relax its default allocation policy (I think upon first allocation it tries to grab 90% of the free memory on the GPU) by playing with some combination of the XLA_PYTHON_CLIENT_ALLOCATOR and XLA_PYTHON_CLIENT_PREALLOCATE flags; see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html and the sketch below the list.
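
For reference, a minimal sketch of applying option 3; the flag names come from the linked JAX docs, and the values here are only illustrative:

# Set the allocator flags before importing jax; they are read at import time.
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
# ...or instead cap the preallocated fraction of GPU memory:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax  # noqa: E402  (deliberately imported after setting the flags)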

Let us know if that gets you unstuck.

@erikfrey erikfrey added the question Further information is requested label Nov 24, 2021
@milutter
Author

milutter commented Nov 25, 2021

Hi Erik,
thank you so much for your help.

Using XLA_PYTHON_CLIENT_MEM_FRACTION and XLA_PYTHON_CLIENT_PREALLOCATE, I got it to work and can now describe the problem in more detail, although I don't know why it is happening. Without preallocation and with 1024 train envs + 128 eval envs, the total GPU memory usage is 8 GB. With the default JAX settings, it initially allocates 90% but immediately 'believes' that the memory is exhausted and requests additional memory, which then triggers the out-of-memory error. When one sets MEM_FRACTION to 0.5, one can nicely observe that the 50% is reserved directly but additional memory is then immediately requested.

Do you have an idea why it wants to allocate additional memory instead of using the original 90%, which is more than enough?
Is there some multi-processing involved that might cause this?

Using no preallocation, I got it to run. However, I am experiencing very long compile times of 17 minutes for halfcheetah. The compile time is spent when executing the jitted reset functions for the first time, i.e., when resetting the train env and the eval env. From my understanding, compiling these functions should not take much time, as they just sample a joint configuration. Furthermore, the compile time depends on the number of environments. My unscientific measurements were as follows (a sketch of how such timings can be taken is below the table):

Environments     1     2     4     8     16    32
Reset Train Env  7.6s  10.0s 10.0s 16.0s 16.7s 150s
Reset Eval Env   6.1s  7.2s  8.2s  42.4s 42.4s 887s
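
For reference, a minimal sketch of how such first-call timings can be taken (not the exact script used above): the first call to a jitted function includes compilation, so timing it, with block_until_ready to defeat JAX's async dispatch, approximates the compile time.

import time
import jax
from brax import envs

env = envs.create_fn(env_name="halfcheetah")()
reset_fn = jax.jit(jax.vmap(env.reset))

keys = jax.random.split(jax.random.PRNGKey(0), 32)  # 32 parallel envs
t0 = time.perf_counter()
# Block on every array in the returned State pytree before stopping the clock.
jax.tree_map(lambda x: x.block_until_ready(), reset_fn(keys))
print(f"first (compiling) call: {time.perf_counter() - t0:.1f}s")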

Do you have an idea why the jit of the reset functions takes so much time?
Does it trigger something from JAX's lazy execution? (As far as I saw, it does not call step.)
Do you have an idea why the compile time scales so badly with the number of environments?
Could underlying installation issues with CUDA/cuDNN be causing these compile times?

@C-J-Cundy

Hi,
I'm experiencing the same issue after updating brax today.

Previously I was able to jit and run a vmapped set of 4,000 random rollouts of halfcheetah in a few tens of seconds, but now I get a memory error. The OOM is avoided if I add the PREALLOCATE and MEM_FRACTION flags, but the compilation time is incredibly long. This has slowed down my code by a huge amount.

Unfortunately I'm not able to find out what the previous version of brax was, but it was from around a month ago.
Versions of packages:

jax                       0.2.25                   pypi_0    pypi
jaxlib                    0.1.73+cuda11.cudnn805          pypi_0    pypi
brax                      0.0.7                    pypi_0    pypi
Full listing to reproduce:
from functools import partial
from typing import Tuple, cast

from jax import jit, vmap, lax
import jax.numpy as jnp
import jax.random as rnd

from brax import envs
from brax.envs.env import State

T = 200
action_lim = 1

env_name = "halfcheetah"  # @param ['ant', 'humanoid', 'fetch', 'grasp', 'halfcheetah', 'ur5e', 'reacher']
env_fn = envs.create_fn(env_name=env_name)
env = env_fn()
action_dim = env.action_size
obs_dim = env.observation_size
n_initial_trials = 8


def rollout_random(rng_key: jnp.ndarray, T: int):
    """Rolls out T steps from a fresh reset, taking uniform random actions."""
    init_s = env.reset(rng_key)
    rng_keys = rnd.split(rng_key, T)

    def inner(state, rng_key):
        rng_key_1, _ = rnd.split(rng_key, 2)
        # Sample each action uniformly in [-action_lim, action_lim].
        action = (rnd.uniform(rng_key_1, shape=(action_dim,)) - 0.5) * 2 * action_lim
        action = cast(jnp.ndarray, action)
        next_state: State = env.step(state, action)
        outputs = (next_state.obs, action, next_state.reward, next_state.done, state)  # type: ignore
        return next_state, outputs

    _, outputs = lax.scan(inner, init_s, rng_keys, unroll=1)
    return outputs


@partial(jit, static_argnames=("N", "T"))
def get_random_rollouts(
    rng_key: jnp.ndarray, N: int, T: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, State]:
    # vmap over N independent PRNG keys to batch the rollouts.
    return vmap(rollout_random, in_axes=(0, None))(rnd.split(rng_key, N), T)  # type: ignore


print(f"Rolling out {n_initial_trials} random initial trajectories")
obs, actions, rewards, _, full_states = get_random_rollouts(
    rnd.PRNGKey(0), n_initial_trials, T
)

@milutter
Author

I updated from Brax 0.0.7 to the new version 0.0.8, and now all the compile-time and memory problems are gone. Halfcheetah compiles within 20s and the memory error is gone.

@erikfrey Could you comment on what you changed between the versions? Just from looking at the recent commit logs, I could not figure out what resolved these issues.

@erikfrey
Collaborator

That's great to hear! We made some major changes to the default_qp function here:

e1a8faf#diff-d5809d1d70b284727c83d435055073c0de6aa3a6a414ca00b6e24ba8756fcd5eR83

The old code iterated through the kinematic tree using a Python for loop, which forced JAX to unroll the creation of the initial state into a set of operations over giant literal constants embedded in the generated code (scaled, as you saw, by the number of environments: more environments means larger constants embedded into the generated code).

The new approach liberally uses jax.lax.scan so that the initial state isn't emitted as giant blobs of data in the generated code but is instead constructed on device on the fly.
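
To illustrate the difference, a minimal sketch (not the actual default_qp code; link_offsets is a made-up stand-in for per-link data in the kinematic tree):

import jax
import jax.numpy as jnp

link_offsets = jnp.ones((50, 3))  # hypothetical: one entry per kinematic link

def init_unrolled(base):
    # A Python for loop is unrolled at trace time: 50 separate adds, each
    # with its slice of link_offsets baked into the program as a constant.
    pos = base
    for i in range(link_offsets.shape[0]):
        pos = pos + link_offsets[i]
    return pos

def init_scanned(base):
    # lax.scan emits a single loop op; the compiled program stays the same
    # size regardless of tree depth, and the data lives on device.
    def body(pos, offset):
        return pos + offset, None

    pos, _ = jax.lax.scan(body, base, link_offsets)
    return pos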

@milutter
Copy link
Author

Thanks, this greatly improved the experience and resolved all issues!
