Add a gradient checkpoint feature #20720
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##           master   #20720      +/-   ##
==========================================
- Coverage   81.93%   81.91%   -0.02%
==========================================
  Files         548      548
  Lines       51190    51203      +13
  Branches     7912     7916       +4
==========================================
+ Hits        41942    41945       +3
- Misses       7310     7319       +9
- Partials     1938     1939       +1
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Thanks for the PR! @divyashreepathihalli has a currently outstanding proposal for implementing gradient checkpointing in Keras across backends. I will let you two figure out what to do.
Thank you. So my understanding is that we have a conflict? The absence of gradient checkpointing in Keras is a serious obstacle for training LLMs. As long as Keras can get this feature as soon as possible, I am happy to follow your plan for implementing it.
@pass-lin thank you for the PR, we appreciate your effort to bring rematerialization support to Keras. However, we are already working on adding this feature, and we are trying to provide more fine-grained control for enabling rematerialization with a mode parameter. More details here: https://docs.google.com/document/d/199s5kaT7fdqDJ5ryJ15aJJH8QIiLvPBYPpb3ZJgPEsE/edit?tab=t.0#heading=h.lleqmh1k4q6g
Gradient checkpointing is a widely used technique for reducing memory consumption. This PR adapts it for Keras. To require minimal modifications to existing models, we add a parameter, enable_gradient_checkpoint, to the layer, set to False by default. Simply changing this parameter enables gradient checkpointing. However, the backend-specific implementations need to account for the following points (see the sketch after this list):

- In the Torch backend, you should ensure that the layer you are checkpointing contains no dropout layers or normalization layers (such as BN, LN, GN, etc.) whose forward and backward behaviors are inconsistent.
- In the TensorFlow backend, you can only enable this setting in eager mode.
- In the JAX backend, you should ensure that the inputs of your function contain no strings or other types that are not valid, differentiable JAX inputs (such as str).
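To make the backend constraints above concrete, here is a minimal sketch (not the actual code in this PR) of how an enable_gradient_checkpoint-style flag could route a layer's forward computation through each backend's native rematerialization primitive. The checkpointed_call helper, its signature, and the _forward method in the usage note are hypothetical; torch.utils.checkpoint.checkpoint, tf.recompute_grad, and jax.checkpoint are the real underlying APIs.

```python
import keras  # Keras 3: keras.backend.backend() reports the active backend


def checkpointed_call(fn, *args):
    """Hypothetical helper: run fn(*args) so that its intermediate
    activations are recomputed during the backward pass instead of stored."""
    backend = keras.backend.backend()
    if backend == "torch":
        # Torch: fn must not contain dropout / BN / other layers whose
        # forward and backward behaviors differ between the two passes.
        from torch.utils.checkpoint import checkpoint
        return checkpoint(fn, *args, use_reentrant=False)
    if backend == "tensorflow":
        # TensorFlow: rematerialization via tf.recompute_grad; per the
        # notes above, only enable this in eager mode.
        import tensorflow as tf
        return tf.recompute_grad(fn)(*args)
    if backend == "jax":
        # JAX: every element of args must be a valid JAX type (arrays or
        # pytrees of arrays), not strings or other Python objects.
        import jax
        return jax.checkpoint(fn)(*args)
    # Fallback: no rematerialization for other backends.
    return fn(*args)
```

A layer could then, for example, call checkpointed_call(self._forward, inputs) inside call() when the flag is enabled, and fall back to self._forward(inputs) otherwise.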