diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py index 391752a4..7fe09251 100644 --- a/src/f5_tts/model/backbones/dit.py +++ b/src/f5_tts/model/backbones/dit.py @@ -12,6 +12,7 @@ import torch from torch import nn import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint from x_transformers.x_transformers import RotaryEmbedding @@ -105,6 +106,7 @@ def __init__( text_dim=None, conv_layers=0, long_skip_connection=False, + use_checkpointing=False, # Pf526 ): super().__init__() @@ -127,6 +129,8 @@ def __init__( self.norm_out = AdaLayerNormZero_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) + self.use_checkpointing = use_checkpointing # Pf526 + def forward( self, x: float["b n d"], # nosied input audio # noqa: F722 @@ -152,7 +156,10 @@ def forward( residual = x for block in self.transformer_blocks: - x = block(x, t, mask=mask, rope=rope) + if self.use_checkpointing: # P7ca4 + x = checkpoint(block, x, t, mask, rope) + else: + x = block(x, t, mask=mask, rope=rope) if self.long_skip_connection is not None: x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) diff --git a/src/f5_tts/model/backbones/mmdit.py b/src/f5_tts/model/backbones/mmdit.py index 64c7ef18..dff68fb5 100644 --- a/src/f5_tts/model/backbones/mmdit.py +++ b/src/f5_tts/model/backbones/mmdit.py @@ -11,6 +11,7 @@ import torch from torch import nn +from torch.utils.checkpoint import checkpoint from x_transformers.x_transformers import RotaryEmbedding @@ -85,6 +86,7 @@ def __init__( ff_mult=4, text_num_embeds=256, mel_dim=100, + use_checkpointing=False, # P8e54 ): super().__init__() @@ -113,6 +115,8 @@ def __init__( self.norm_out = AdaLayerNormZero_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) + self.use_checkpointing = use_checkpointing # P8e54 + def forward( self, x: float["b n d"], # nosied input audio # noqa: F722 @@ -138,7 +142,10 @@ def forward( rope_text = self.rotary_embed.forward_from_seq_len(text_len) for block in self.transformer_blocks: - c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text) + if self.use_checkpointing: # P0895 + c, x = checkpoint(block, x, c, t, mask, rope_audio, rope_text) + else: + c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text) x = self.norm_out(x, t) output = self.proj_out(x) diff --git a/src/f5_tts/model/backbones/unett.py b/src/f5_tts/model/backbones/unett.py index acf649a5..2f23d4b7 100644 --- a/src/f5_tts/model/backbones/unett.py +++ b/src/f5_tts/model/backbones/unett.py @@ -13,6 +13,7 @@ import torch from torch import nn import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint from x_transformers import RMSNorm from x_transformers.x_transformers import RotaryEmbedding @@ -108,6 +109,7 @@ def __init__( text_dim=None, conv_layers=0, skip_connect_type: Literal["add", "concat", "none"] = "concat", + use_checkpointing=False, # P2c44 ): super().__init__() assert depth % 2 == 0, "UNet-Transformer's depth should be even." @@ -161,6 +163,8 @@ def __init__( self.norm_out = RMSNorm(dim) self.proj_out = nn.Linear(dim, mel_dim) + self.use_checkpointing = use_checkpointing # P2c44 + def forward( self, x: float["b n d"], # nosied input audio # noqa: F722 @@ -209,8 +213,12 @@ def forward( x = x + skip # attention and feedforward blocks - x = attn(attn_norm(x), rope=rope, mask=mask) + x - x = ff(ff_norm(x)) + x + if self.use_checkpointing: # P4600 + x = checkpoint(attn, attn_norm(x), rope, mask) + x + x = checkpoint(ff, ff_norm(x)) + x + else: + x = attn(attn_norm(x), rope=rope, mask=mask) + x + x = ff(ff_norm(x)) + x assert len(skips) == 0 diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index bd96da7c..e34eec8e 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -47,6 +47,7 @@ def __init__( ema_kwargs: dict = dict(), bnb_optimizer: bool = False, mel_spec_type: str = "vocos", # "vocos" | "bigvgan" + use_checkpointing: bool = False, # Pcdd9 ): ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) @@ -122,6 +123,8 @@ def __init__( self.optimizer = AdamW(model.parameters(), lr=learning_rate) self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + self.use_checkpointing = use_checkpointing # Pcdd9 + @property def is_main(self): return self.accelerator.is_main_process @@ -295,9 +298,19 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations")) self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step) - loss, cond, pred = self.model( - mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler - ) + if self.use_checkpointing: # Pae0a + from torch.utils.checkpoint import checkpoint + loss, cond, pred = checkpoint( + self.model, + mel_spec, + text_inputs, + mel_lengths, + self.noise_scheduler, + ) + else: + loss, cond, pred = self.model( + mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler + ) self.accelerator.backward(loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py index fac0fe5a..21d5a33b 100644 --- a/src/f5_tts/train/train.py +++ b/src/f5_tts/train/train.py @@ -90,6 +90,7 @@ def main(): last_per_steps=last_per_steps, log_samples=True, mel_spec_type=mel_spec_type, + use_checkpointing=False # P050d ) train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)