-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor gradient checkpointing #10611
base: main
Are you sure you want to change the base?
Conversation
a thought i have had is that we can invert the condition so that we "dont checkpoint" instead of "do checkpoint" which allows increasing memory use more finely with larger intervals |
Yep, a number of variations to try out. To not limit how to apply checkpointing, it's best to provide that control to the user, so the idea you mention should be possible to use here too. LMK if you expected something different or want to implement with a better design |
are the numbers correct?
why is no gradient checkpoint has highest memory and lowest throughput? |
@yiyixuxu They were incorrect, my bad. I thought I was doing the warmup correctly for the first backward call, but it was not correct. Looking at the profiles revealed that the backward pass kernel launch was not warmed up. Updated the example so that it happens now before the "No Gradient Checkpointing" benchmark, and we're now seeing the correct numbers "No Gradient Checkpointing" has the highest memory usage because it has to keep ALL intermediate activation tensors in memory, whereas the gradient checkpointing ones only have to save a copy of the inputs at each layer where is is applied (and intermediate activation tensors between layers if we skipping a few blocks). The memory part was correct in previous version of the code as well. The incorrect part was the reported |
yes this looks great, and will be a very useful addition. thank you. would this approach also work for SDXL? tracking all the checkpoints there was hard for me and i ended up monkey patching the checkpoint call in a way i'm not proud of in order to "make it work" easily. |
if we just want to allow user to skip certain layers, can we go with a simpler solution? transformer.enable_gradient_checkpointing(skip_layers= []) then in the code for i, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing and i not in self.skip_layers:
... |
Skipping layers is one use case. The idea is to not limit how the checkpointing is applied and just give the control to the user (with us also providing sensible default behaviour of torch.nn.utils.checkpoint) - for example, using a different provider for checkpointing, such as deepspeed, instead of monkey patching the forward pass or using other intrusive solution, or allowing use of custom checkpoint implementations that can perform CPU offloading of stored inputs and retrieve them back when required for recomputation |
@yiyixuxu LMK if you think the current changes look good and I'll propagate and add tests for all other models to make sure this works as expected Edit: oh, just saw your message - we commented at the same time. I'll work on finishing this up |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Fixes #10124.
When finetuning, we currently apply gradient checkpointing to each transformer block. This works wonders for saving memory but can lead to slower throughput. To improve throughput, at the cost of slightly higher memory usage, an acceptable compromise can be made by only checkpointing certain blocks, or by applying a different checkpointing strategy.
This PR will try to refactor how we do gradient checkpointing to enable users to use their own checkpointing functions/strategies. Currently, only LTXVideo has been updated to gather initial feedback on the changes made. If all looks well, will update all the other modeling implementations.
Benchmark
Additonal context: #9982 (comment)
cc @bghira