0.4.2 fix trainer with grad_accum
SWivid committed Jan 15, 2025
1 parent 12d6970 commit 9e51878
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.4.1"
+version = "0.4.2"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
16 changes: 8 additions & 8 deletions src/f5_tts/model/trainer.py
@@ -56,14 +56,8 @@ def __init__(
 
         if logger == "wandb" and not wandb.api.api_key:
             logger = None
-        print(f"Using logger: {logger}")
         self.log_samples = log_samples
 
-        if grad_accumulation_steps > 1 and self.is_main:
-            print(
-                "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
-            )
-
         self.accelerator = Accelerator(
             log_with=logger if logger == "wandb" else None,
             kwargs_handlers=[ddp_kwargs],
@@ -106,6 +100,12 @@ def __init__(
             self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
             self.ema_model.to(self.accelerator.device)
 
+        print(f"Using logger: {logger}")
+        if grad_accumulation_steps > 1:
+            print(
+                "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
+            )
+
         self.epochs = epochs
         self.num_warmup_updates = num_warmup_updates
         self.save_per_updates = save_per_updates
@@ -357,7 +357,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
                         self.writer.add_scalar("loss", loss.item(), global_update)
                         self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
 
-                if global_update % self.save_per_updates == 0:
+                if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
                     self.save_checkpoint(global_update)
 
                 if self.log_samples and self.accelerator.is_local_main_process:
@@ -391,7 +391,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
                             f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
                         )
 
-                if global_update % self.last_per_updates == 0:
+                if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
                     self.save_checkpoint(global_update, last=True)
 
         self.save_checkpoint(global_update, last=True)
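Note on the fix: the substantive change is the added "and self.accelerator.sync_gradients" guard on both checkpoint conditions. With Accelerate's gradient accumulation, global_update advances only on micro-batches where gradients are actually synchronized and the optimizer steps; without the guard, the modulo check can stay true across every micro-batch of an accumulation window and the same checkpoint gets written repeatedly. Below is a minimal, self-contained sketch of that Accelerate pattern — not the project's actual training loop; the toy model, data, and save_per_updates value are invented for illustration.

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=2)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

save_per_updates = 2  # arbitrary value for this sketch
global_update = 0

for x, y in loader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()       # the prepared optimizer only truly steps on sync micro-batches
        optimizer.zero_grad()

    if accelerator.sync_gradients:
        # True once per accumulation window, i.e. once per real optimizer update
        global_update += 1

    # Without the sync_gradients guard, this condition would hold for every
    # micro-batch in the window once global_update hits a multiple of
    # save_per_updates, saving the same checkpoint several times; with the
    # guard it fires exactly once per qualifying update.
    if global_update % save_per_updates == 0 and accelerator.sync_gradients:
        print(f"would save checkpoint at update {global_update}")

Run on a single process this behaves like plain training with a 4-step accumulation window; the same guard is what the commit adds around save_checkpoint in trainer.py.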
