Commit 3409192

update

lpscr committed Oct 29, 2024
1 parent 37eb3b5 commit 3409192
Showing 2 changed files with 102 additions and 92 deletions.
120 changes: 28 additions & 92 deletions src/f5_tts/model/trainer.py
@@ -4,10 +4,6 @@
 import gc
 from tqdm import tqdm

-try:
-    from torch.utils.tensorboard import SummaryWriter
-except ImportError:
-    print("TensorBoard is not installed")

 import torch
 from torch.optim import AdamW
@@ -22,27 +18,11 @@
 from f5_tts.model import CFM
 from f5_tts.model.utils import exists, default
 from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
+from f5_tts.infer.utils_infer import target_sample_rate, hop_length, nfe_step, cfg_strength, sway_sampling_coef, vocos
+from f5_tts.model.utils import gen_sample

-import numpy as np
-import matplotlib.pyplot as plt
-
-# audio imports
-import torchaudio
-import soundfile as sf
-from vocos import Vocos
-import warnings
-
-warnings.filterwarnings("ignore", category=FutureWarning)
-
-# -----------------------------------------
-target_sample_rate = 24000
-hop_length = 256
-nfe_step = 16
-cfg_strength = 2.0
-sway_sampling_coef = -1.0
-# -----------------------------------------
 # trainer


 class Trainer:
     def __init__(
@@ -104,6 +84,8 @@ def __init__(
                 },
             )
         elif self.logger == "tensorboard":
+            from torch.utils.tensorboard import SummaryWriter
+
             self.accelerator = Accelerator(
                 kwargs_handlers=[ddp_kwargs],
                 gradient_accumulation_steps=grad_accumulation_steps,
@@ -124,9 +106,6 @@
         self.export_samples = export_samples
         if self.export_samples:
             self.path_ckpts_project = checkpoint_path
-
-            self.vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-            self.vocos.to("cpu")
             self.file_path_samples = os.path.join(self.path_ckpts_project, "samples")
             os.makedirs(self.file_path_samples, exist_ok=True)

Expand Down Expand Up @@ -233,72 +212,6 @@ def log(self, metrics, step):
for key, value in metrics.items():
self.writer.add_scalar(key, value, step)

def export_add_log(self, global_step, mel_org, text_inputs):
try:
generated_wave_org = self.vocos.decode(mel_org.unsqueeze(0).cpu())
generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
file_wav_org = os.path.join(self.file_path_samples, f"step_{global_step}_org.wav")
sf.write(file_wav_org, generated_wave_org, target_sample_rate)

audio, sr = torchaudio.load(file_wav_org)
audio = audio.to("cuda")

ref_audio_len = audio.shape[-1] // hop_length
text = [text_inputs[0] + [" . "] + text_inputs[0]]
duration = int((audio.shape[1] / 256) * 2.0)

with torch.inference_mode():
generated_gen, _ = self.model.sample(
cond=audio,
text=text,
duration=duration,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
)

generated_gen = generated_gen.to(torch.float32)
generated_gen = generated_gen[:, ref_audio_len:, :]
generated_mel_spec_gen = generated_gen.permute(0, 2, 1)
generated_wave_gen = self.vocos.decode(generated_mel_spec_gen.cpu())
generated_wave_gen = generated_wave_gen.squeeze().cpu().numpy()
file_wav_gen = os.path.join(self.file_path_samples, f"step_{global_step}_gen.wav")
sf.write(file_wav_gen, generated_wave_gen, target_sample_rate)

if self.logger == "tensorboard":
self.writer.add_audio("Audio/original", generated_wave_org, global_step, sample_rate=target_sample_rate)

self.writer.add_audio("Audio/generate", generated_wave_gen, global_step, sample_rate=target_sample_rate)

mel_org = mel_org
mel_min, mel_max = mel_org.min(), mel_org.max()
mel_norm = (mel_org - mel_min) / (mel_max - mel_min + 1e-8)
mel_colored = plt.get_cmap("viridis")(mel_norm.detach().cpu().numpy())[:, :, :3]
mel_colored = np.transpose(mel_colored, (2, 0, 1))

if self.logger == "tensorboard":
self.writer.add_image("Mel/oginal", mel_colored, global_step, dataformats="CHW")

mel_colored_hwc = np.transpose(mel_colored, (1, 2, 0))
file_gen_org = os.path.join(self.file_path_samples, f"step_{global_step}_org.png")
plt.imsave(file_gen_org, mel_colored_hwc)

mel_gen = generated_mel_spec_gen[0]
mel_min, mel_max = mel_gen.min(), mel_gen.max()
mel_norm = (mel_gen - mel_min) / (mel_max - mel_min + 1e-8)
mel_colored = plt.get_cmap("viridis")(mel_norm.detach().cpu().numpy())[:, :, :3]
mel_colored = np.transpose(mel_colored, (2, 0, 1))

if self.logger == "tensorboard":
self.writer.add_image("Mel/generate", mel_colored, global_step, dataformats="CHW")

mel_colored_hwc = np.transpose(mel_colored, (1, 2, 0))
file_gen_gen = os.path.join(self.file_path_samples, f"step_{global_step}_gen.png")
plt.imsave(file_gen_gen, mel_colored_hwc)

except Exception as e:
print("An error occurred:", e)

def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
if exists(resumable_with_seed):
generator = torch.Generator()
@@ -383,6 +296,7 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
             for batch in progress_bar:
                 with self.accelerator.accumulate(self.model):
                     text_inputs = batch["text"]
+
                     mel_spec = batch["mel"].permute(0, 2, 1)
                     mel_lengths = batch["mel_lengths"]

@@ -401,7 +315,29 @@ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int
                        and self.export_samples
                        and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
                    ):
-                        self.export_add_log(global_step, batch["mel"][0], text_inputs)
+                        wave_org, wave_gen, mel_org, mel_gen = gen_sample(
+                            vocos,
+                            self.model,
+                            self.file_path_samples,
+                            global_step,
+                            batch["mel"][0],
+                            text_inputs,
+                            target_sample_rate,
+                            hop_length,
+                            nfe_step,
+                            cfg_strength,
+                            sway_sampling_coef,
+                        )
+
+                        if self.logger == "tensorboard":
+                            self.writer.add_audio(
+                                "Audio/original", wave_org, global_step, sample_rate=target_sample_rate
+                            )
+                            self.writer.add_audio(
+                                "Audio/generate", wave_gen, global_step, sample_rate=target_sample_rate
+                            )
+                            self.writer.add_image("Mel/original", mel_org, global_step, dataformats="CHW")
+                            self.writer.add_image("Mel/generate", mel_gen, global_step, dataformats="CHW")

                    self.accelerator.backward(loss)

74 changes: 74 additions & 0 deletions src/f5_tts/model/utils.py
@@ -11,6 +11,10 @@
 import jieba
 from pypinyin import lazy_pinyin, Style

+import numpy as np
+import matplotlib.pyplot as plt
+import soundfile as sf
+import torchaudio

 # seed everything

@@ -183,3 +187,73 @@ def repetition_found(text, length=2, tolerance=10):
         if count > tolerance:
             return True
     return False
+
+
+def normalize_and_colorize_spectrogram(mel_org):
+    mel_min, mel_max = mel_org.min(), mel_org.max()
+    mel_norm = (mel_org - mel_min) / (mel_max - mel_min + 1e-8)
+    mel_colored = plt.get_cmap("viridis")(mel_norm.detach().cpu().numpy())[:, :, :3]
+    mel_colored = np.transpose(mel_colored, (2, 0, 1))
+    return mel_colored
+
+
+def export_audio(file_out, wav, target_sample_rate):
+    sf.write(file_out, wav, samplerate=target_sample_rate)
+
+
+def export_mel(mel_colored_hwc, file_out):
+    plt.imsave(file_out, mel_colored_hwc)
+
+
+def get_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
+    audio, sr = torchaudio.load(file_wav_org)
+    audio = audio.to("cuda")
+    ref_audio_len = audio.shape[-1] // hop_length
+    text = [text_inputs[0] + [" . "] + text_inputs[0]]
+    duration = int((audio.shape[1] / 256) * 2.0)
+    with torch.inference_mode():
+        generated_gen, _ = model.sample(
+            cond=audio,
+            text=text,
+            duration=duration,
+            steps=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+        )
+    generated_gen = generated_gen.to(torch.float32)
+    generated_gen = generated_gen[:, ref_audio_len:, :]
+    generated_mel_spec_gen = generated_gen.permute(0, 2, 1)
+    generated_wave_gen = vocos.decode(generated_mel_spec_gen.cpu())
+    generated_wave_gen = generated_wave_gen.squeeze().cpu().numpy()
+    return generated_wave_gen, generated_mel_spec_gen
+
+
+def gen_sample(
+    vocos,
+    model,
+    file_path_samples,
+    global_step,
+    mel_org,
+    text_inputs,
+    target_sample_rate,
+    hop_length,
+    nfe_step,
+    cfg_strength,
+    sway_sampling_coef,
+):
+    generated_wave_org = vocos.decode(mel_org.unsqueeze(0).cpu())
+    generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
+    file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav")
+    export_audio(file_wav_org, generated_wave_org, target_sample_rate)
+    generated_wave_gen, generated_mel_spec_gen = get_sample(
+        model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef
+    )
+    file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")
+    export_audio(file_wav_gen, generated_wave_gen, target_sample_rate)
+    mel_org = normalize_and_colorize_spectrogram(mel_org)
+    mel_gen = normalize_and_colorize_spectrogram(generated_mel_spec_gen[0])
+    file_gen_org = os.path.join(file_path_samples, f"step_{global_step}_org.png")
+    export_mel(np.transpose(mel_org, (1, 2, 0)), file_gen_org)
+    file_gen_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.png")
+    export_mel(np.transpose(mel_gen, (1, 2, 0)), file_gen_gen)
+    return generated_wave_org, generated_wave_gen, mel_org, mel_gen
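
For reference, a minimal sketch of how the relocated gen_sample helper can be driven outside the Trainer, mirroring the call added to trainer.py above. The export_training_sample wrapper and its argument names are illustrative assumptions, not part of this commit; vocos and the sampling constants come from f5_tts.infer.utils_infer, exactly as in the new trainer imports.

import os

from f5_tts.infer.utils_infer import (
    target_sample_rate,
    hop_length,
    nfe_step,
    cfg_strength,
    sway_sampling_coef,
    vocos,
)
from f5_tts.model.utils import gen_sample


def export_training_sample(model, batch, global_step, samples_dir):
    # Hypothetical wrapper: `model` is a trained f5_tts.model.CFM instance and
    # `batch` is one collated training batch with "mel" and "text" entries,
    # as produced by collate_fn and consumed in Trainer.train above.
    os.makedirs(samples_dir, exist_ok=True)
    # gen_sample decodes the reference mel with Vocos, re-synthesizes speech
    # with model.sample, writes step_{N}_org/gen .wav and .png files into
    # samples_dir, and returns both waveforms plus colorized mel images.
    return gen_sample(
        vocos,
        model,
        samples_dir,
        global_step,
        batch["mel"][0],
        batch["text"],
        target_sample_rate,
        hop_length,
        nfe_step,
        cfg_strength,
        sway_sampling_coef,
    )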
