Memory Requirements of PPO Example #123
Hi Michael, yes, JAX PPO should work fine on such a card. A few things worth trying:
Let us know if that gets you unstuck.
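A minimal sketch of JAX's documented GPU memory settings. These are not necessarily the exact suggestions from this reply, but they match what the next comment mentions: the default preallocation of 90% of the card, and running with preallocation disabled. The variables must be set before JAX initializes its backend. `XLA_PYTHON_CLIENT_ALLOCATOR=platform` (allocate and free exactly on demand, slower) is a third documented option.

```python
import os

# Must be set before JAX initializes its GPU backend, i.e. before the first
# `import jax` in the process; setting them later has no effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # grow memory on demand
# Alternatively, keep preallocation but cap it at a fraction of the card:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"
# Or allocate/free exactly what is needed (slow, but minimal footprint):
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax
import jax.numpy as jnp

print(jax.devices())          # which backend/devices JAX picked up
x = jnp.ones((1024, 1024))    # small allocation to confirm the device works
print(float(x.sum()))         # 1048576.0
```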
Hi Erik, using … Do you have an idea why it wants to allocate additional memory instead of using the original 90%, which is more than enough?
With preallocation disabled, I got it to run. However, I am seeing very long compile times: 17 minutes for half-cheetah. The time is spent when the jitted reset functions (e.g. the reset of the train env and the reset of the eval env) are executed for the first time. As I understand it, compiling these functions should not take long, since they just sample a joint configuration. Furthermore, the compile time depends on the number of environments. My unscientific measurements were:
Do you have an idea why JIT-compiling the reset functions takes so much time?
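To tell compile time apart from run time when reproducing this, one can use JAX's ahead-of-time API (`jax.jit(f).lower(...).compile()`), which traces, lowers and compiles without executing. A minimal sketch; `time_compile` and `make_reset` are hypothetical helpers, and `make_reset` is only a stand-in that samples joint angles, not Brax's actual reset function.

```python
import time

import jax


def time_compile(fn, *args):
    """Return (compile_seconds, run_seconds) for jax.jit(fn) on example args."""
    jitted = jax.jit(fn)
    t0 = time.perf_counter()
    compiled = jitted.lower(*args).compile()  # trace + lower + XLA compile, no execution
    t1 = time.perf_counter()
    out = compiled(*args)
    out.block_until_ready()  # dispatch is asynchronous; wait for the actual result
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1


def make_reset(num_envs, num_joints=8):
    # Hypothetical stand-in for an env reset: sample a joint configuration per env.
    def reset(rng):
        return jax.random.uniform(rng, (num_envs, num_joints), minval=-0.1, maxval=0.1)
    return reset


for n in (1, 128, 2048):
    c, r = time_compile(make_reset(n), jax.random.PRNGKey(0))
    print(f"num_envs={n}: compile {c:.3f}s, run {r:.4f}s")
```

If the compile column grows sharply with `num_envs` for a function this simple, the traced program itself is growing with the batch size, which is worth inspecting.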
Hi, previously I was able to JIT and run a … Unfortunately, I can't find out which Brax version I was using, but it was from around a month ago.
Full listing to reproduce:
I updated Brax from version 0.0.7 to 0.0.8, and now all the compile-time and memory problems are gone: half-cheetah compiles within 20 s and the out-of-memory error no longer occurs. @erikfrey, could you comment on what changed between the versions? From the recent commit logs alone, I could not figure out what resolved these issues.
That's great to hear! We made some major changes to the … (e1a8faf#diff-d5809d1d70b284727c83d435055073c0de6aa3a6a414ca00b6e24ba8756fcd5eR83). The old code iterated through the kinematic tree using a Python for loop, which forced JAX to unroll the creation of the initial state into a set of operations over giant literal constants embedded in the generated code (scaled, as you saw, by the number of environments: more environments means larger constants embedded in the generated code). The new approach liberally uses …
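The effect described above can be reproduced in isolation. A minimal sketch, assuming hypothetical `NUM_ENVS`/`NUM_JOINTS` sizes and two made-up reset functions; it is not Brax's code and not necessarily the technique the new version uses. A per-environment array built with NumPy is concrete at trace time, so JAX bakes it into the program as a literal whose size scales with `NUM_ENVS`; building the same values from a small vector broadcast inside the traced computation keeps the program small. The length of the lowered StableHLO text serves as a rough proxy for generated-code size.

```python
import numpy as np
import jax
import jax.numpy as jnp

NUM_ENVS = 1024   # hypothetical batch size
NUM_JOINTS = 8    # hypothetical joint count


def reset_numpy_default(rng):
    # np.tile runs once at trace time and yields a concrete (NUM_ENVS, NUM_JOINTS)
    # array; JAX embeds it in the generated program as a literal constant.
    default = np.tile(np.arange(NUM_JOINTS, dtype=np.float32), (NUM_ENVS, 1))
    noise = jax.random.uniform(rng, (NUM_ENVS, NUM_JOINTS), minval=-0.1, maxval=0.1)
    return default + noise


def reset_broadcast_default(rng):
    # Same values, but built from a per-joint vector broadcast inside the traced
    # computation, so at most NUM_JOINTS values end up baked into the program.
    default = jnp.broadcast_to(
        jnp.arange(NUM_JOINTS, dtype=jnp.float32), (NUM_ENVS, NUM_JOINTS))
    noise = jax.random.uniform(rng, (NUM_ENVS, NUM_JOINTS), minval=-0.1, maxval=0.1)
    return default + noise


def program_size(fn, *args):
    """Length of the lowered StableHLO text for jax.jit(fn) on the given args."""
    return len(jax.jit(fn).lower(*args).as_text())


key = jax.random.PRNGKey(0)
print("numpy constant:", program_size(reset_numpy_default, key))
print("broadcast:     ", program_size(reset_broadcast_default, key))
```

Raising `NUM_ENVS` grows the first number roughly linearly while the second stays essentially flat, which mirrors the compile-time scaling reported earlier in the thread.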
Thanks, this greatly improved the experience and resolved all issues!
I am trying to run the Brax PPO example locally, but I am getting CUDA out-of-memory errors. For simple environments such as reacher everything works fine, but half-cheetah and ant run out of memory. I presumed that the required memory is proportional to the number of environments; however, even with the number of environments set to 1, I get an out-of-memory error.
These errors are surprising to me, as I am using an RTX 3090 with 24 GB of memory, which is at least as much as the K80 mentioned in #49 that runs ant in the Colab. Therefore, I am wondering which component affects GPU memory usage the most, and whether it is possible to reduce the GPU memory requirements.