Checks
This template is only for questions, not feature requests or bug reports.
I have thoroughly reviewed the project documentation and read the related paper(s).
I have searched for existing issues, including closed ones, and found no similar questions.
I confirm that I am submitting this report in English to facilitate communication.
Question details
It appears that gradient checkpointing is not implemented in the current training pipeline. Gradient checkpointing can significantly reduce memory usage by recomputing activations during the backward pass instead of storing them, trading extra computation for memory, which makes it valuable for large models and resource-limited environments. This raises the question:
Is there a specific reason for not implementing gradient checkpointing?
If possible, could it be integrated in a future update, or are there known limitations that prevent its integration? If there is no compatibility issue, I would be happy to explore adding it via a PR.
Fixes SWivid#399
Implement gradient checkpointing in the training pipeline.
* **Model Backbones**:
- Import `checkpoint` from `torch.utils.checkpoint` in `src/f5_tts/model/backbones/dit.py`, `src/f5_tts/model/backbones/unett.py`, and `src/f5_tts/model/backbones/mmdit.py`.
- Add a parameter `use_checkpointing` to the constructors of `DiT`, `UNetT`, and `MMDiT` classes, defaulting to `False`.
- Modify the `forward` methods to wrap each block in `checkpoint` when `use_checkpointing` is `True` (see the first sketch after this list).
* **Trainer**:
- Add a parameter `use_checkpointing` to the `Trainer` class constructor in `src/f5_tts/model/trainer.py`, defaulting to `False`.
- Modify the `train` method to enable gradient checkpointing when `use_checkpointing` is `True` (see the second sketch after this list).
* **Training Script**:
- Add a parameter `use_checkpointing` to the `Trainer` instantiation in `src/f5_tts/train/train.py`, defaulting to `False` (see the third sketch after this list).
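A minimal sketch of the backbone change, assuming a `DiT`-style module that applies a stack of transformer blocks in sequence. The block internals, dimensions, and constructor arguments below are placeholders, not the actual f5_tts code:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class DiT(nn.Module):
    def __init__(self, dim: int = 1024, depth: int = 22, use_checkpointing: bool = False):
        super().__init__()
        # Placeholder blocks standing in for the real DiT transformer blocks.
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.use_checkpointing = use_checkpointing

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Do not store this block's activations; recompute them
                # during the backward pass, trading compute for memory.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
```

Checkpointing is gated on `self.training` so inference is unaffected, and `use_reentrant=False` follows PyTorch's current recommendation for `torch.utils.checkpoint`.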
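For the `Trainer`, a sketch of threading the flag through to the model. The attribute path `model.transformer` and the simplified signature are assumptions about the real class in `src/f5_tts/model/trainer.py`:

```python
class Trainer:
    def __init__(self, model, use_checkpointing: bool = False, **kwargs):
        self.model = model
        self.use_checkpointing = use_checkpointing

    def train(self, train_dataset):
        if self.use_checkpointing:
            # Hypothetical attribute path: propagate the flag to the backbone
            # so its forward pass wraps each block in checkpoint.
            self.model.transformer.use_checkpointing = True
        # ... rest of the training loop unchanged ...
```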
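And at the call site in `src/f5_tts/train/train.py`, the flag would simply be forwarded; the other keyword arguments shown here are illustrative placeholders:

```python
trainer = Trainer(
    model,
    epochs=11,
    learning_rate=7.5e-5,
    use_checkpointing=True,  # opt in; defaults to False
)
trainer.train(train_dataset)
```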