
Why is Gradient Checkpointing Not Implemented in Training? #399

Open
4 tasks done
kostum123 opened this issue Nov 5, 2024 · 1 comment · May be fixed by #400
Labels
question Further information is requested

Comments

@kostum123

Checks

  • This template is only for questions, not feature requests or bug reports.
  • I have thoroughly reviewed the project documentation and read the related paper(s).
  • I have searched existing issues, including closed ones, and found no similar questions.
  • I confirm that I am submitting this report in English to facilitate communication.

Question details

It appears that gradient checkpointing is not implemented in the current training pipeline. Gradient checkpointing can significantly reduce memory usage by trading extra computation for activation storage, which makes it valuable for large models and resource-limited environments. This raises two questions:

Is there a specific reason gradient checkpointing has not been implemented?
Could it be integrated in a future update, or are there known limitations that prevent it? If there is no compatibility issue, I would be open to adding it via a PR.
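For context, the memory-for-compute tradeoff described above can be demonstrated with a minimal sketch using `torch.utils.checkpoint` (a toy layer, not F5-TTS code): the checkpointed forward produces the same values as the plain forward, but intermediate activations are recomputed during backward instead of being stored.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Toy layer: with checkpointing, its intermediate activations are not
# kept alive for backward; they are recomputed when gradients are needed.
w = torch.randn(32, 32, requires_grad=True)

def layer(x):
    return torch.tanh(x @ w)

x = torch.randn(4, 32, requires_grad=True)

y_plain = layer(x)                               # plain forward
y_ckpt = checkpoint(layer, x, use_reentrant=False)  # checkpointed forward

assert torch.allclose(y_plain, y_ckpt)  # identical values
y_ckpt.sum().backward()                 # gradients still flow
```

The `use_reentrant=False` flag selects the non-reentrant implementation, which recent PyTorch versions recommend.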

@kostum123 kostum123 added the question Further information is requested label Nov 5, 2024
@ZhikangNiu
Collaborator

Yeah, I think you can explore the gradient checkpointing in F5 and add it via a PR.

kostum123 added a commit to kostum123/F5-TTS that referenced this issue Nov 5, 2024
Fixes SWivid#399

Implement gradient checkpointing in the training pipeline.

* **Model Backbones**:
  - Import `checkpoint` from `torch.utils.checkpoint` in `src/f5_tts/model/backbones/dit.py`, `src/f5_tts/model/backbones/unett.py`, and `src/f5_tts/model/backbones/mmdit.py`.
  - Add a parameter `use_checkpointing` to the constructors of `DiT`, `UNetT`, and `MMDiT` classes, defaulting to `False`.
  - Modify the `forward` methods to use `checkpoint` for each block if `use_checkpointing` is `True`.

* **Trainer**:
  - Add a parameter `use_checkpointing` to the `Trainer` class constructor in `src/f5_tts/model/trainer.py`, defaulting to `False`.
  - Modify the `train` method to enable gradient checkpointing if `use_checkpointing` is `True`.

* **Training Script**:
  - Add a parameter `use_checkpointing` to the `Trainer` instantiation in `src/f5_tts/train/train.py`, defaulting to `False`.
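The backbone-side change described above can be sketched roughly as follows. This is a hypothetical, simplified block and backbone (the real `DiT`/`UNetT`/`MMDiT` classes take many more arguments); it only illustrates the `use_checkpointing` flag and the per-block `checkpoint` call from the PR description.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Stand-in transformer block (hypothetical; not F5-TTS code)."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.net(x)

class Backbone(nn.Module):
    def __init__(self, dim=64, depth=4, use_checkpointing=False):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.use_checkpointing = use_checkpointing

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Recompute this block's activations in backward
                # instead of storing them during forward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

model = Backbone(use_checkpointing=True).train()
x = torch.randn(2, 16, 64, requires_grad=True)
model(x).sum().backward()  # gradients flow despite checkpointing
```

Gating on `self.training` keeps inference unaffected, since checkpointing only pays off when a backward pass follows.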
@kostum123 kostum123 linked a pull request Nov 5, 2024 that will close this issue