From 6a333fa15d7f931ee200eb5fe5624e618676f69b Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 21 Oct 2024 11:39:45 +0300 Subject: [PATCH] update --- api.py | 112 +++++++++++++++++++++++++++++++++------------------------ 1 file changed, 65 insertions(+), 47 deletions(-) diff --git a/api.py b/api.py index 4eb4833f4..54c37f158 100644 --- a/api.py +++ b/api.py @@ -1,36 +1,25 @@ -import argparse -import codecs -import re -import tempfile -from pathlib import Path - -import numpy as np import soundfile as sf -import tomli import torch -import torchaudio import tqdm from cached_path import cached_path -from einops import rearrange -from pydub import AudioSegment, silence -from transformers import pipeline -from vocos import Vocos - -from model import CFM, DiT, MMDiT, UNetT -from model.utils import (convert_char_to_pinyin, get_tokenizer, - load_checkpoint, save_spectrogram) - -from model.utils_infer import ( - load_vocoder, - load_model, - preprocess_ref_audio_text, - infer_process, - remove_silence_for_generated_wav, - chunk_text -) + +from model import DiT, UNetT +from model.utils import save_spectrogram + +from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav + class F5TTS: - def __init__(self, model_type="F5-TTS", ckpt_file="", vocab_file="", ode_method="euler", use_ema=True, local_path=None, device=None): + def __init__( + self, + model_type="F5-TTS", + ckpt_file="", + vocab_file="", + ode_method="euler", + use_ema=True, + local_path=None, + device=None, + ): # Initialize parameters self.final_wave = None self.target_sample_rate = 24000 @@ -39,17 +28,18 @@ def __init__(self, model_type="F5-TTS", ckpt_file="", vocab_file="", ode_method= self.target_rms = 0.1 # Set device - self.device = device or ("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") - + self.device = device or ( + "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + ) + # Load models self.load_vecoder_model(local_path) self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema) def load_vecoder_model(self, local_path): - self.vocos = load_vocoder(local_path is not None,local_path,self.device) + self.vocos = load_vocoder(local_path is not None, local_path, self.device) def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema): - if model_type == "F5-TTS": if not ckpt_file: ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")) @@ -66,37 +56,65 @@ def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema) self.ema_model = self.load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema) def load_model(self, model_cls, model_cfg, ckpt_path, file_vocab, ode_method, use_ema): - return load_model(model_cls, model_cfg, ckpt_path, file_vocab,ode_method,use_ema,self.device) + return load_model(model_cls, model_cfg, ckpt_path, file_vocab, ode_method, use_ema, self.device) - def export_wav(self,wav, file_wave, remove_silence=False): + def export_wav(self, wav, file_wave, remove_silence=False): if remove_silence: remove_silence_for_generated_wav(file_wave) sf.write(file_wave, wav, self.target_sample_rate) - - def export_spectrogram(self,spect, file_spect): + + def export_spectrogram(self, spect, file_spect): save_spectrogram(spect, file_spect) - - def infer(self, ref_file, ref_text, gen_text, sway_sampling_coef=-1, cfg_strength=2, nfe_step=32, speed=1.0, fix_duration=None, remove_silence=False, file_wave=None, file_spect=None, cross_fade_duration=0.15, show_info=print, progress=tqdm): - wav,sr,spect = infer_process(ref_file, ref_text, gen_text, self.ema_model, cross_fade_duration, speed, show_info, progress,nfe_step,cfg_strength,sway_sampling_coef,fix_duration) + def infer( + self, + ref_file, + ref_text, + gen_text, + sway_sampling_coef=-1, + cfg_strength=2, + nfe_step=32, + speed=1.0, + fix_duration=None, + remove_silence=False, + file_wave=None, + file_spect=None, + cross_fade_duration=0.15, + show_info=print, + progress=tqdm, + ): + wav, sr, spect = infer_process( + ref_file, + ref_text, + gen_text, + self.ema_model, + cross_fade_duration, + speed, + show_info, + progress, + nfe_step, + cfg_strength, + sway_sampling_coef, + fix_duration, + ) if file_wave is not None: - self.export_wav(wav,file_wave,remove_silence) - + self.export_wav(wav, file_wave, remove_silence) + if file_spect is not None: - self.export_spectrogram(spect,file_spect) + self.export_spectrogram(spect, file_spect) - return wav,sr,spect + return wav, sr, spect -if __name__ == "__main__": - f5tts=F5TTS() +if __name__ == "__main__": + f5tts = F5TTS() - wav,sr,spect=f5tts.infer( + wav, sr, spect = f5tts.infer( ref_file="tests/ref_audio/test_en_1_ref_short.wav", ref_text="some call me nature, others call me mother nature.", gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""", file_wave="tests/out.wav", - file_spect="tests/out.png" - ) \ No newline at end of file + file_spect="tests/out.png", + )