Add a gradient checkpoint feature #20720

Closed · pass-lin wants to merge 4 commits

Conversation

pass-lin commented Jan 3, 2025

Gradient checkpointing is a widely used technique for reducing memory consumption, and this PR adapts it for Keras. To keep modifications to existing models minimal, we add a parameter enable_gradient_checkpoint to the layer, set to False by default; flipping this parameter enables gradient checkpointing. The backend-specific implementations come with the following caveats (a usage sketch follows the list):

- In the PyTorch backend, ensure that the checkpointed layer contains no dropout or normalization layers (BN, LN, GN, etc.), since their forward behavior differs between the original pass and the recomputation.
- In the TensorFlow backend, this setting can only be enabled in eager mode.
- In the JAX backend, ensure that the function's inputs contain no strings or other types JAX cannot differentiate.
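
To make the intent concrete, here is a minimal sketch of the mechanism, not the PR's actual diff: the helper name checkpointed_call and the dispatch structure are assumptions, while the three backend primitives (torch.utils.checkpoint.checkpoint, tf.recompute_grad, jax.checkpoint) are the real per-backend APIs such a flag would plausibly wrap.

```python
# A minimal sketch, assuming a hypothetical helper `checkpointed_call`;
# only the backend primitives it dispatches to are real APIs.
from keras import backend


def checkpointed_call(fn, *args):
    """Run `fn(*args)` without storing intermediate activations;
    they are recomputed during the backward pass instead."""
    name = backend.backend()
    if name == "torch":
        import torch.utils.checkpoint as ckpt

        # Non-reentrant variant; dropout/normalization inside `fn`
        # would behave differently on recomputation (PyTorch caveat).
        return ckpt.checkpoint(fn, *args, use_reentrant=False)
    if name == "tensorflow":
        import tensorflow as tf

        # Only usable in eager mode (TensorFlow caveat).
        return tf.recompute_grad(fn)(*args)
    if name == "jax":
        import jax

        # All inputs must be differentiable JAX types, not e.g. str
        # (JAX caveat).
        return jax.checkpoint(fn)(*args)
    # numpy/openvino backends have no autodiff, so nothing to rematerialize.
    return fn(*args)
```

Per the PR's description, a layer constructed with enable_gradient_checkpoint=True would route its forward computation through a wrapper of this kind, and each caveat above mirrors a constraint of the corresponding backend primitive.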

codecov-commenter commented Jan 4, 2025

Codecov Report

Attention: Patch coverage is 23.07692% with 10 lines in your changes missing coverage. Please review.

Project coverage is 81.91%. Comparing base (41c429e) to head (487ac08).

Files with missing lines     Patch %   Lines
keras/src/ops/operation.py   16.66%    9 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20720      +/-   ##
==========================================
- Coverage   81.93%   81.91%   -0.02%     
==========================================
  Files         548      548              
  Lines       51190    51203      +13     
  Branches     7912     7916       +4     
==========================================
+ Hits        41942    41945       +3     
- Misses       7310     7319       +9     
- Partials     1938     1939       +1     
Flag               Coverage Δ
keras              81.74% <23.07%> (-0.02%) ⬇️
keras-jax          63.98% <15.38%> (-0.02%) ⬇️
keras-numpy        58.92% <15.38%> (-0.02%) ⬇️
keras-openvino     29.86% <15.38%> (-0.01%) ⬇️
keras-tensorflow   64.67% <23.07%> (-0.02%) ⬇️
keras-torch        64.05% <15.38%> (-0.02%) ⬇️

Flags with carried forward coverage won't be shown.


fchollet (Collaborator) commented Jan 4, 2025

Thanks for the PR! @divyashreepathihalli currently has an outstanding proposal for implementing gradient checkpointing in Keras across backends. I'll let you two figure out what to do.

pass-lin (Author) commented Jan 5, 2025

> Thanks for the PR! @divyashreepathihalli currently has an outstanding proposal for implementing gradient checkpointing in Keras across backends. I'll let you two figure out what to do.

Thank you. So my understanding is that the two efforts conflict? The absence of gradient checkpointing in Keras is a real obstacle for training LLMs. As long as Keras gets this feature as soon as possible, I'm happy to follow whatever arrangement you propose to implement it.

divyashreepathihalli (Collaborator) commented

@pass-lin thank you for the PR; we appreciate your effort to bring rematerialization support to Keras. However, we are already working on adding this feature, and we are aiming for more fine-grained control over rematerialization via a mode parameter. More details here: https://docs.google.com/document/d/199s5kaT7fdqDJ5ryJ15aJJH8QIiLvPBYPpb3ZJgPEsE/edit?tab=t.0#heading=h.lleqmh1k4q6g
Once the initial design is worked out, I can open a contribution issue. If you're interested, you can take up the issue at that point. I'll make sure to keep you posted as things progress. Closing this PR for now. Thanks again!
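
The mode semantics are specified in the linked design doc rather than here, but JAX's public rematerialization API gives a flavor of what fine-grained control means in practice: a policy decides which intermediates are saved and which are recomputed. A sketch using only real JAX calls (an analogy, not the Keras design):

```python
# Illustration of fine-grained rematerialization control with JAX's
# public API; this is NOT the Keras design from the linked doc.
import jax
import jax.numpy as jnp


def mlp(x, w1, w2):
    # Two matmuls with nonlinearities; the intermediates are what
    # rematerialization chooses to save or recompute.
    return jnp.tanh(jnp.tanh(x @ w1) @ w2)


# Coarse control: recompute all intermediates in the backward pass.
full_remat = jax.checkpoint(mlp)

# Finer control: save matmul outputs, recompute everything else.
partial_remat = jax.checkpoint(
    mlp, policy=jax.checkpoint_policies.dots_saveable
)

x = jnp.ones((8, 16))
w1 = jnp.ones((16, 32))
w2 = jnp.ones((32, 4))
loss = lambda w: partial_remat(x, w, w2).sum()
dw1 = jax.grad(loss)(w1)  # activations recomputed per the policy
```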
