diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index 88652f103..c45b6b1d3 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -4,10 +4,6 @@ import gc

 from tqdm import tqdm

-try:
-    from torch.utils.tensorboard import SummaryWriter
-except ImportError:
-    print("TensorBoard is not installed")

 import torch
 from torch.optim import AdamW
@@ -22,27 +18,11 @@ from f5_tts.model import CFM
 from f5_tts.model.utils import exists, default
 from f5_tts.model.dataset import DynamicBatchSampler, collate_fn

+from f5_tts.infer.utils_infer import target_sample_rate, hop_length, nfe_step, cfg_strength, sway_sampling_coef, vocos
+from f5_tts.model.utils import gen_sample
-import numpy as np
-import matplotlib.pyplot as plt
-

 # trainer

-# audio imports
-import torchaudio
-import soundfile as sf
-from vocos import Vocos
-import warnings
-
-warnings.filterwarnings("ignore", category=FutureWarning)
-
-# -----------------------------------------
-target_sample_rate = 24000
-hop_length = 256
-nfe_step = 16
-cfg_strength = 2.0
-sway_sampling_coef = -1.0
-# -----------------------------------------

 class Trainer:
     def __init__(
@@ -104,6 +84,8 @@ def __init__(
                 },
             )
         elif self.logger == "tensorboard":
+            from torch.utils.tensorboard import SummaryWriter
+
             self.accelerator = Accelerator(
                 kwargs_handlers=[ddp_kwargs],
                 gradient_accumulation_steps=grad_accumulation_steps,
@@ -124,9 +106,6 @@ def __init__(
         self.export_samples = export_samples
         if self.export_samples:
            self.path_ckpts_project = checkpoint_path
-
-            self.vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-            self.vocos.to("cpu")
             self.file_path_samples = os.path.join(self.path_ckpts_project, "samples")
             os.makedirs(self.file_path_samples, exist_ok=True)

@@ -233,72 +212,6 @@ def log(self, metrics, step):
             for key, value in metrics.items():
                 self.writer.add_scalar(key, value, step)

-    def export_add_log(self, global_step, mel_org, text_inputs):
-        try:
-            generated_wave_org = self.vocos.decode(mel_org.unsqueeze(0).cpu())
-            generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
-            file_wav_org = os.path.join(self.file_path_samples, f"step_{global_step}_org.wav")
-            sf.write(file_wav_org, generated_wave_org, target_sample_rate)
-
-            audio, sr = torchaudio.load(file_wav_org)
-            audio = audio.to("cuda")
-
-            ref_audio_len = audio.shape[-1] // hop_length
-            text = [text_inputs[0] + [" . "] + text_inputs[0]]
"] + text_inputs[0]] - duration = int((audio.shape[1] / 256) * 2.0) - - with torch.inference_mode(): - generated_gen, _ = self.model.sample( - cond=audio, - text=text, - duration=duration, - steps=nfe_step, - cfg_strength=cfg_strength, - sway_sampling_coef=sway_sampling_coef, - ) - - generated_gen = generated_gen.to(torch.float32) - generated_gen = generated_gen[:, ref_audio_len:, :] - generated_mel_spec_gen = generated_gen.permute(0, 2, 1) - generated_wave_gen = self.vocos.decode(generated_mel_spec_gen.cpu()) - generated_wave_gen = generated_wave_gen.squeeze().cpu().numpy() - file_wav_gen = os.path.join(self.file_path_samples, f"step_{global_step}_gen.wav") - sf.write(file_wav_gen, generated_wave_gen, target_sample_rate) - - if self.logger == "tensorboard": - self.writer.add_audio("Audio/original", generated_wave_org, global_step, sample_rate=target_sample_rate) - - self.writer.add_audio("Audio/generate", generated_wave_gen, global_step, sample_rate=target_sample_rate) - - mel_org = mel_org - mel_min, mel_max = mel_org.min(), mel_org.max() - mel_norm = (mel_org - mel_min) / (mel_max - mel_min + 1e-8) - mel_colored = plt.get_cmap("viridis")(mel_norm.detach().cpu().numpy())[:, :, :3] - mel_colored = np.transpose(mel_colored, (2, 0, 1)) - - if self.logger == "tensorboard": - self.writer.add_image("Mel/oginal", mel_colored, global_step, dataformats="CHW") - - mel_colored_hwc = np.transpose(mel_colored, (1, 2, 0)) - file_gen_org = os.path.join(self.file_path_samples, f"step_{global_step}_org.png") - plt.imsave(file_gen_org, mel_colored_hwc) - - mel_gen = generated_mel_spec_gen[0] - mel_min, mel_max = mel_gen.min(), mel_gen.max() - mel_norm = (mel_gen - mel_min) / (mel_max - mel_min + 1e-8) - mel_colored = plt.get_cmap("viridis")(mel_norm.detach().cpu().numpy())[:, :, :3] - mel_colored = np.transpose(mel_colored, (2, 0, 1)) - - if self.logger == "tensorboard": - self.writer.add_image("Mel/generate", mel_colored, global_step, dataformats="CHW") - - mel_colored_hwc = np.transpose(mel_colored, (1, 2, 0)) - file_gen_gen = os.path.join(self.file_path_samples, f"step_{global_step}_gen.png") - plt.imsave(file_gen_gen, mel_colored_hwc) - - except Exception as e: - print("An error occurred:", e) - def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None): if exists(resumable_with_seed): generator = torch.Generator() @@ -383,6 +296,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int for batch in progress_bar: with self.accelerator.accumulate(self.model): text_inputs = batch["text"] + mel_spec = batch["mel"].permute(0, 2, 1) mel_lengths = batch["mel_lengths"] @@ -401,7 +315,29 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int and self.export_samples and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0 ): - self.export_add_log(global_step, batch["mel"][0], text_inputs) + wave_org, wave_gen, mel_org, mel_gen = gen_sample( + vocos, + self.model, + self.file_path_samples, + global_step, + batch["mel"][0], + text_inputs, + target_sample_rate, + hop_length, + nfe_step, + cfg_strength, + sway_sampling_coef, + ) + + if self.logger == "tensorboard": + self.writer.add_audio( + "Audio/original", wave_org, global_step, sample_rate=target_sample_rate + ) + self.writer.add_audio( + "Audio/generate", wave_gen, global_step, sample_rate=target_sample_rate + ) + self.writer.add_image("Mel/original", mel_org, global_step, dataformats="CHW") + self.writer.add_image("Mel/generate", mel_gen, 

                    self.accelerator.backward(loss)

diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py
index 76cfa4d0d..8877a7e15 100644
--- a/src/f5_tts/model/utils.py
+++ b/src/f5_tts/model/utils.py
@@ -11,6 +11,10 @@
 import jieba
 from pypinyin import lazy_pinyin, Style

+import numpy as np
+import matplotlib.pyplot as plt
+import soundfile as sf
+import torchaudio

 # seed everything

@@ -183,3 +187,73 @@ def repetition_found(text, length=2, tolerance=10):
         if count > tolerance:
             return True
     return False
+
+
+def normalize_and_colorize_spectrogram(mel_org):
+    mel_min, mel_max = mel_org.min(), mel_org.max()
+    mel_norm = (mel_org - mel_min) / (mel_max - mel_min + 1e-8)
+    mel_colored = plt.get_cmap("viridis")(mel_norm.detach().cpu().numpy())[:, :, :3]
+    mel_colored = np.transpose(mel_colored, (2, 0, 1))
+    return mel_colored
+
+
+def export_audio(file_out, wav, target_sample_rate):
+    sf.write(file_out, wav, samplerate=target_sample_rate)
+
+
+def export_mel(mel_colored_hwc, file_out):
+    plt.imsave(file_out, mel_colored_hwc)
+
+
+def get_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
+    audio, sr = torchaudio.load(file_wav_org)
+    audio = audio.to("cuda")
+    ref_audio_len = audio.shape[-1] // hop_length
+    text = [text_inputs[0] + [" . "] + text_inputs[0]]
+    duration = int((audio.shape[1] / 256) * 2.0)
+    with torch.inference_mode():
+        generated_gen, _ = model.sample(
+            cond=audio,
+            text=text,
+            duration=duration,
+            steps=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+        )
+        generated_gen = generated_gen.to(torch.float32)
+        generated_gen = generated_gen[:, ref_audio_len:, :]
+        generated_mel_spec_gen = generated_gen.permute(0, 2, 1)
+        generated_wave_gen = vocos.decode(generated_mel_spec_gen.cpu())
+        generated_wave_gen = generated_wave_gen.squeeze().cpu().numpy()
+    return generated_wave_gen, generated_mel_spec_gen
+
+
+def gen_sample(
+    vocos,
+    model,
+    file_path_samples,
+    global_step,
+    mel_org,
+    text_inputs,
+    target_sample_rate,
+    hop_length,
+    nfe_step,
+    cfg_strength,
+    sway_sampling_coef,
+):
+    generated_wave_org = vocos.decode(mel_org.unsqueeze(0).cpu())
+    generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
+    file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav")
+    export_audio(file_wav_org, generated_wave_org, target_sample_rate)
+    generated_wave_gen, generated_mel_spec_gen = get_sample(
+        model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef
+    )
+    file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")
+    export_audio(file_wav_gen, generated_wave_gen, target_sample_rate)
+    mel_org = normalize_and_colorize_spectrogram(mel_org)
+    mel_gen = normalize_and_colorize_spectrogram(generated_mel_spec_gen[0])
+    file_gen_org = os.path.join(file_path_samples, f"step_{global_step}_org.png")
+    export_mel(np.transpose(mel_org, (1, 2, 0)), file_gen_org)
+    file_gen_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.png")
+    export_mel(np.transpose(mel_gen, (1, 2, 0)), file_gen_gen)
+    return generated_wave_org, generated_wave_gen, mel_org, mel_gen
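
Reviewer note: a minimal usage sketch of the extracted gen_sample helper, not part of the patch. It assumes a trained CFM instance `model` and a `batch` from the existing DataLoader are in scope, and `samples_dir` is a hypothetical path; everything else is imported exactly as the patched trainer imports it. Note that get_sample moves the reference audio to "cuda", so a GPU is assumed.

import os

from f5_tts.infer.utils_infer import (
    target_sample_rate,
    hop_length,
    nfe_step,
    cfg_strength,
    sway_sampling_coef,
    vocos,
)
from f5_tts.model.utils import gen_sample

samples_dir = "ckpts/my_project/samples"  # hypothetical output directory
os.makedirs(samples_dir, exist_ok=True)

# Decode the reference mel, synthesize a continuation, and write
# step_1000_{org,gen}.wav plus step_1000_{org,gen}.png under samples_dir.
wave_org, wave_gen, mel_org, mel_gen = gen_sample(
    vocos,                # Vocos vocoder, imported above per the patch
    model,                # trained CFM instance, assumed in scope
    samples_dir,
    1000,                 # global_step, used only in the output filenames
    batch["mel"][0],      # reference mel spectrogram, assumed in scope
    batch["text"],        # token lists; gen_sample reads text_inputs[0]
    target_sample_rate,
    hop_length,
    nfe_step,
    cfg_strength,
    sway_sampling_coef,
)

# The returned arrays are ready for TensorBoard, mirroring the trainer:
# writer.add_audio("Audio/original", wave_org, 1000, sample_rate=target_sample_rate)
# writer.add_image("Mel/original", mel_org, 1000, dataformats="CHW")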