Add gradient checkpointing to training pipeline
Fixes SWivid#399

Implement optional gradient checkpointing in the training pipeline, trading extra backward-pass compute for lower activation memory.

* **Model Backbones**:
  - Import `checkpoint` from `torch.utils.checkpoint` in `src/f5_tts/model/backbones/dit.py`, `src/f5_tts/model/backbones/unett.py`, and `src/f5_tts/model/backbones/mmdit.py`.
  - Add a `use_checkpointing` parameter, defaulting to `False`, to the `DiT`, `UNetT`, and `MMDiT` constructors.
  - Modify the `forward` methods to wrap each transformer block (or, in `UNetT`, each attention and feed-forward call) in `checkpoint` when `use_checkpointing` is `True` (a minimal sketch of this pattern follows the list).

* **Trainer**:
  - Add a `use_checkpointing` parameter, defaulting to `False`, to the `Trainer` constructor in `src/f5_tts/model/trainer.py`.
  - Modify the `train` method to wrap the model's forward pass in `checkpoint` when `use_checkpointing` is `True`.

* **Training Script**:
  - Pass `use_checkpointing=False` to the `Trainer` instantiation in `src/f5_tts/train/train.py`, leaving the default behavior unchanged.
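
For reference, a minimal sketch of the block-level pattern the backbone changes follow. `TinyBackbone` is an illustrative stand-in rather than an actual F5-TTS class (the real blocks also take `t`, `mask`, and `rope`, as the diffs below show), and the `use_reentrant=False` flag and the `self.training` guard are recommended extras, not part of this commit:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TinyBackbone(nn.Module):
    """Illustrative stand-in for DiT/UNetT/MMDiT, not an actual F5-TTS class."""

    def __init__(self, dim=64, depth=4, use_checkpointing=False):
        super().__init__()
        self.transformer_blocks = nn.ModuleList(
            nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU())
            for _ in range(depth)
        )
        self.use_checkpointing = use_checkpointing

    def forward(self, x):
        for block in self.transformer_blocks:
            if self.use_checkpointing and self.training:
                # Drop this block's activations after the forward pass and
                # recompute them during backward; use_reentrant=False is the
                # variant recent PyTorch releases recommend.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


if __name__ == "__main__":
    model = TinyBackbone(use_checkpointing=True).train()
    x = torch.randn(2, 16, 64, requires_grad=True)
    model(x).sum().backward()  # gradients flow through the checkpointed blocks
```

Checkpointed blocks discard their intermediate activations after the forward pass and recompute them during the backward pass, which is where the memory saving (and the extra compute) comes from.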
kostum123 committed Nov 5, 2024
1 parent 4a69e6b commit c51304e
Showing 5 changed files with 43 additions and 7 deletions.
9 changes: 8 additions & 1 deletion src/f5_tts/model/backbones/dit.py
@@ -12,6 +12,7 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from x_transformers.x_transformers import RotaryEmbedding

@@ -105,6 +106,7 @@ def __init__(
text_dim=None,
conv_layers=0,
long_skip_connection=False,
use_checkpointing=False, # Pf526
):
super().__init__()

@@ -127,6 +129,8 @@ def __init__(
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)

self.use_checkpointing = use_checkpointing # Pf526

def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
@@ -152,7 +156,10 @@ def forward(
residual = x

for block in self.transformer_blocks:
x = block(x, t, mask=mask, rope=rope)
if self.use_checkpointing: # P7ca4
x = checkpoint(block, x, t, mask, rope)
else:
x = block(x, t, mask=mask, rope=rope)

if self.long_skip_connection is not None:
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
9 changes: 8 additions & 1 deletion src/f5_tts/model/backbones/mmdit.py
@@ -11,6 +11,7 @@

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

from x_transformers.x_transformers import RotaryEmbedding

@@ -85,6 +86,7 @@ def __init__(
ff_mult=4,
text_num_embeds=256,
mel_dim=100,
use_checkpointing=False, # P8e54
):
super().__init__()

@@ -113,6 +115,8 @@ def __init__(
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)

self.use_checkpointing = use_checkpointing # P8e54

def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
@@ -138,7 +142,10 @@ def forward(
rope_text = self.rotary_embed.forward_from_seq_len(text_len)

for block in self.transformer_blocks:
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
if self.use_checkpointing: # P0895
c, x = checkpoint(block, x, c, t, mask, rope_audio, rope_text)
else:
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)

x = self.norm_out(x, t)
output = self.proj_out(x)
12 changes: 10 additions & 2 deletions src/f5_tts/model/backbones/unett.py
@@ -13,6 +13,7 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding
@@ -108,6 +109,7 @@ def __init__(
text_dim=None,
conv_layers=0,
skip_connect_type: Literal["add", "concat", "none"] = "concat",
use_checkpointing=False, # P2c44
):
super().__init__()
assert depth % 2 == 0, "UNet-Transformer's depth should be even."
@@ -161,6 +163,8 @@ def __init__(
self.norm_out = RMSNorm(dim)
self.proj_out = nn.Linear(dim, mel_dim)

self.use_checkpointing = use_checkpointing # P2c44

def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
@@ -209,8 +213,12 @@ def forward(
x = x + skip

# attention and feedforward blocks
x = attn(attn_norm(x), rope=rope, mask=mask) + x
x = ff(ff_norm(x)) + x
if self.use_checkpointing: # P4600
x = checkpoint(attn, attn_norm(x), rope, mask) + x
x = checkpoint(ff, ff_norm(x)) + x
else:
x = attn(attn_norm(x), rope=rope, mask=mask) + x
x = ff(ff_norm(x)) + x

assert len(skips) == 0

19 changes: 16 additions & 3 deletions src/f5_tts/model/trainer.py
@@ -47,6 +47,7 @@ def __init__(
ema_kwargs: dict = dict(),
bnb_optimizer: bool = False,
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
use_checkpointing: bool = False, # Pcdd9
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

@@ -122,6 +123,8 @@ def __init__(
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

self.use_checkpointing = use_checkpointing # Pcdd9

@property
def is_main(self):
return self.accelerator.is_main_process
@@ -295,9 +298,19 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)

loss, cond, pred = self.model(
mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
)
if self.use_checkpointing: # Pae0a
from torch.utils.checkpoint import checkpoint
loss, cond, pred = checkpoint(
self.model,
mel_spec,
text_inputs,
mel_lengths,
self.noise_scheduler,
)
else:
loss, cond, pred = self.model(
mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
)
self.accelerator.backward(loss)

if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
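One caveat about the checkpointed call in the hunk above: the reentrant form of `torch.utils.checkpoint.checkpoint` does not forward keyword arguments to the wrapped function, which is why the arguments are passed positionally and must line up exactly with the model's `forward` signature. A sketch of the same call with the keywords preserved, assuming the non-reentrant variant available in recent PyTorch releases:

```python
# Sketch only: mirrors the checkpointed call in trainer.py, but keeps the
# original keyword arguments, which use_reentrant=False supports.
from torch.utils.checkpoint import checkpoint

loss, cond, pred = checkpoint(
    self.model,
    mel_spec,
    text=text_inputs,
    lens=mel_lengths,
    noise_scheduler=self.noise_scheduler,
    use_reentrant=False,
)
```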
1 change: 1 addition & 0 deletions src/f5_tts/train/train.py
@@ -90,6 +90,7 @@ def main():
last_per_steps=last_per_steps,
log_samples=True,
mel_spec_type=mel_spec_type,
use_checkpointing=False # P050d
)

train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
