From 25cdc5182f4a0040e94a4b67f7c228bcc59585b7 Mon Sep 17 00:00:00 2001
From: lpscr <147736764+lpscr@users.noreply.github.com>
Date: Mon, 21 Oct 2024 11:57:24 +0300
Subject: [PATCH 1/2] add api for easy use (#186)

* add api

* update infer limits
---
 api.py               | 117 +++++++++++++++++++++++++++++++++++++++++++
 model/utils_infer.py |  74 +++++++++++++++++++--------
 2 files changed, 171 insertions(+), 20 deletions(-)
 create mode 100644 api.py

diff --git a/api.py b/api.py
new file mode 100644
index 00000000..0b639a95
--- /dev/null
+++ b/api.py
@@ -0,0 +1,117 @@
+import soundfile as sf
+import torch
+import tqdm
+from cached_path import cached_path
+
+from model import DiT, UNetT
+from model.utils import save_spectrogram
+
+from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
+
+
+class F5TTS:
+    def __init__(
+        self,
+        model_type="F5-TTS",
+        ckpt_file="",
+        vocab_file="",
+        ode_method="euler",
+        use_ema=True,
+        local_path=None,
+        device=None,
+    ):
+        # Initialize parameters
+        self.final_wave = None
+        self.target_sample_rate = 24000
+        self.n_mel_channels = 100
+        self.hop_length = 256
+        self.target_rms = 0.1
+
+        # Set device
+        self.device = device or (
+            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+        )
+
+        # Load models
+        self.load_vecoder_model(local_path)
+        self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
+
+    def load_vecoder_model(self, local_path):
+        self.vocos = load_vocoder(local_path is not None, local_path, self.device)
+
+    def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
+        if model_type == "F5-TTS":
+            if not ckpt_file:
+                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+            model_cls = DiT
+        elif model_type == "E2-TTS":
+            if not ckpt_file:
+                ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
+            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+            model_cls = UNetT
+        else:
+            raise ValueError(f"Unknown model type: {model_type}")
+
+        self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)
+
+    def export_wav(self, wav, file_wave, remove_silence=False):
+        sf.write(file_wave, wav, self.target_sample_rate)
+
+        if remove_silence:
+            remove_silence_for_generated_wav(file_wave)
+
+    def export_spectrogram(self, spect, file_spect):
+        save_spectrogram(spect, file_spect)
+
+    def infer(
+        self,
+        ref_file,
+        ref_text,
+        gen_text,
+        sway_sampling_coef=-1,
+        cfg_strength=2,
+        nfe_step=32,
+        speed=1.0,
+        fix_duration=None,
+        remove_silence=False,
+        file_wave=None,
+        file_spect=None,
+        cross_fade_duration=0.15,
+        show_info=print,
+        progress=tqdm,
+    ):
+        wav, sr, spect = infer_process(
+            ref_file,
+            ref_text,
+            gen_text,
+            self.ema_model,
+            cross_fade_duration,
+            speed,
+            show_info,
+            progress,
+            nfe_step,
+            cfg_strength,
+            sway_sampling_coef,
+            fix_duration,
+        )
+
+        if file_wave is not None:
+            self.export_wav(wav, file_wave, remove_silence)
+
+        if file_spect is not None:
+            self.export_spectrogram(spect, file_spect)
+
+        return wav, sr, spect
+
+
+if __name__ == "__main__":
+    f5tts = F5TTS()
+
+    wav, sr, spect = f5tts.infer(
+        ref_file="tests/ref_audio/test_en_1_ref_short.wav",
+        ref_text="some call me nature, others call me mother nature.",
+        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
+        file_wave="tests/out.wav",
+        file_spect="tests/out.png",
+    )
diff --git a/model/utils_infer.py b/model/utils_infer.py
index faea34cd..da87f7a4 100644
--- a/model/utils_infer.py
+++ b/model/utils_infer.py
@@ -38,12 +38,12 @@
 n_mel_channels = 100
 hop_length = 256
 target_rms = 0.1
-nfe_step = 32  # 16, 32
-cfg_strength = 2.0
-ode_method = "euler"
-sway_sampling_coef = -1.0
-speed = 1.0
-fix_duration = None
+# nfe_step = 32  # 16, 32
+# cfg_strength = 2.0
+# ode_method = "euler"
+# sway_sampling_coef = -1.0
+# speed = 1.0
+# fix_duration = None
 
 
 # -----------------------------------------
@@ -84,7 +84,7 @@ def chunk_text(text, max_chars=135):
 
 
 # load vocoder
-def load_vocoder(is_local=False, local_path=""):
+def load_vocoder(is_local=False, local_path="", device=device):
     if is_local:
         print(f"Load vocos from local path {local_path}")
         vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
@@ -100,14 +100,14 @@
 
 
 # load model for inference
-def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
+def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method="euler", use_ema=True, device=device):
     if vocab_file == "":
         vocab_file = "Emilia_ZH_EN"
         tokenizer = "pinyin"
     else:
         tokenizer = "custom"
 
-    print("\nvocab : ", vocab_file, tokenizer)
+    print("\nvocab : ", vocab_file)
     print("tokenizer : ", tokenizer)
     print("model : ", ckpt_path, "\n")
 
@@ -125,7 +125,7 @@
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    model = load_checkpoint(model, ckpt_path, device, use_ema=True)
+    model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
 
     return model
 
@@ -178,7 +178,18 @@
 
 
 def infer_process(
-    ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm
+    ref_audio,
+    ref_text,
+    gen_text,
+    model_obj,
+    cross_fade_duration=0.15,
+    speed=1.0,
+    show_info=print,
+    progress=tqdm,
+    nfe_step=32,
+    cfg_strength=2,
+    sway_sampling_coef=-1,
+    fix_duration=None,
 ):
     # Split the input text into batches
     audio, sr = torchaudio.load(ref_audio)
@@ -188,14 +199,36 @@ def infer_process(
         print(f"gen_text {i}", gen_text)
 
     show_info(f"Generating audio in {len(gen_text_batches)} batches...")
-    return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress)
+    return infer_batch_process(
+        (audio, sr),
+        ref_text,
+        gen_text_batches,
+        model_obj,
+        cross_fade_duration,
+        speed,
+        progress,
+        nfe_step,
+        cfg_strength,
+        sway_sampling_coef,
+        fix_duration,
+    )
 
 
 # infer batches
 def infer_batch_process(
-    ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm
+    ref_audio,
+    ref_text,
+    gen_text_batches,
+    model_obj,
+    cross_fade_duration=0.15,
+    speed=1,
+    progress=tqdm,
+    nfe_step=32,
+    cfg_strength=2.0,
+    sway_sampling_coef=-1,
+    fix_duration=None,
 ):
     audio, sr = ref_audio
     if audio.shape[0] > 1:
@@ -219,11 +252,14 @@
         text_list = [ref_text + gen_text]
         final_text_list = convert_char_to_pinyin(text_list)
 
-        # Calculate duration
-        ref_audio_len = audio.shape[-1] // hop_length
-        ref_text_len = len(ref_text.encode("utf-8"))
-        gen_text_len = len(gen_text.encode("utf-8"))
-        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
+        if fix_duration is not None:
+            duration = int(fix_duration * target_sample_rate / hop_length)
+        else:
+            # Calculate duration
+            ref_audio_len = audio.shape[-1] // hop_length
+            ref_text_len = len(ref_text.encode("utf-8"))
+            gen_text_len = len(gen_text.encode("utf-8"))
+            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
 
         # inference
         with torch.inference_mode():
@@ -293,8 +329,6 @@
 
 
 # remove silence from generated wav
-
-
 def remove_silence_for_generated_wav(filename):
     aseg = AudioSegment.from_file(filename)
     non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)

From 795cb19e4fb37fe2c9ade85dc6a80f6d6286b775 Mon Sep 17 00:00:00 2001
From: Haitao
Date: Mon, 21 Oct 2024 17:00:48 +0800
Subject: [PATCH 2/2] allow for passing in custom mel spec module (#200)

---
 model/dataset.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/model/dataset.py b/model/dataset.py
index 03ed473f..c293fe23 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -8,8 +8,10 @@
 import torchaudio
 from datasets import load_from_disk
 from datasets import Dataset as Dataset_
+from torch import nn
 
 from model.modules import MelSpec
+from model.utils import default
 
 
 class HFDataset(Dataset):
@@ -77,15 +79,22 @@ def __init__(
         hop_length=256,
         n_mel_channels=100,
         preprocessed_mel=False,
+        mel_spec_module: nn.Module | None = None,
     ):
         self.data = custom_dataset
         self.durations = durations
         self.target_sample_rate = target_sample_rate
         self.hop_length = hop_length
         self.preprocessed_mel = preprocessed_mel
+
         if not preprocessed_mel:
-            self.mel_spectrogram = MelSpec(
-                target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels
+            self.mel_spectrogram = default(
+                mel_spec_module,
+                MelSpec(
+                    target_sample_rate=target_sample_rate,
+                    hop_length=hop_length,
+                    n_mel_channels=n_mel_channels,
+                ),
             )
 
     def get_frame_len(self, index):
@@ -201,6 +210,7 @@ def load_dataset(
     tokenizer: str = "pinyin",
     dataset_type: str = "CustomDataset",
    audio_type: str = "raw",
+    mel_spec_module: nn.Module | None = None,
     mel_spec_kwargs: dict = dict(),
 ) -> CustomDataset | HFDataset:
     """
             data_dict = json.load(f)
         durations = data_dict["duration"]
         train_dataset = CustomDataset(
-            train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
+            train_dataset,
+            durations=durations,
+            preprocessed_mel=preprocessed_mel,
+            mel_spec_module=mel_spec_module,
+            **mel_spec_kwargs,
         )
 
     elif dataset_type == "CustomDatasetPath":