From b899a35b88db44de6a1fbb4d107297f5c9b36364 Mon Sep 17 00:00:00 2001 From: SWivid Date: Mon, 21 Oct 2024 17:45:06 +0800 Subject: [PATCH 1/2] load asr pipeline only if needed --- api.py | 4 ++-- inference-cli.py | 4 ++-- model/utils_infer.py | 32 ++++++++++++++++++++++---------- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/api.py b/api.py index 0b639a95..d5550834 100644 --- a/api.py +++ b/api.py @@ -33,10 +33,10 @@ def __init__( ) # Load models - self.load_vecoder_model(local_path) + self.load_vocoder_model(local_path) self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema) - def load_vecoder_model(self, local_path): + def load_vocoder_model(self, local_path): self.vocos = load_vocoder(local_path is not None, local_path, self.device) def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema): diff --git a/inference-cli.py b/inference-cli.py index 3d4bd153..1e74eec2 100644 --- a/inference-cli.py +++ b/inference-cli.py @@ -104,7 +104,7 @@ exp_name = "F5TTS_Base" ckpt_step = 1200000 ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) - # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path + # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path elif model == "E2-TTS": model_cls = UNetT @@ -114,7 +114,7 @@ exp_name = "E2TTS_Base" ckpt_step = 1200000 ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) - # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path + # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path print(f"Using {model}...") ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file) diff --git a/model/utils_infer.py b/model/utils_infer.py index da87f7a4..8e625687 100644 --- a/model/utils_infer.py +++ b/model/utils_infer.py @@ -22,13 +22,6 @@ device = "cuda" if 
torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" print(f"Using {device} device") -asr_pipe = pipeline( - "automatic-speech-recognition", - model="openai/whisper-large-v3-turbo", - torch_dtype=torch.float16, - device=device, -) - vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") @@ -82,8 +75,6 @@ def chunk_text(text, max_chars=135): # load vocoder - - def load_vocoder(is_local=False, local_path="", device=device): if is_local: print(f"Load vocos from local path {local_path}") @@ -97,6 +88,22 @@ def load_vocoder(is_local=False, local_path="", device=device): return vocos +# load asr pipeline + +asr_pipe = None + + +def initialize_asr_pipeline(device=device): + global asr_pipe + + asr_pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v3-turbo", + torch_dtype=torch.float16, + device=device, + ) + + # load model for inference @@ -133,7 +140,7 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method="euler # preprocess reference audio and text -def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print): +def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device): show_info("Converting audio...") with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f: aseg = AudioSegment.from_file(ref_audio_orig) @@ -152,6 +159,9 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print): ref_audio = f.name if not ref_text.strip(): + global asr_pipe + if asr_pipe is None: + initialize_asr_pipeline(device=device) show_info("No reference text provided, transcribing reference audio...") ref_text = asr_pipe( ref_audio, @@ -329,6 +339,8 @@ def infer_batch_process( # remove silence from generated wav + + def remove_silence_for_generated_wav(filename): aseg = AudioSegment.from_file(filename) non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500) From 
d15ef3679aa572453fc1d44f66947de71fa85e72 Mon Sep 17 00:00:00 2001 From: SWivid Date: Mon, 21 Oct 2024 17:55:58 +0800 Subject: [PATCH 2/2] fix address #191 --- model/backbones/dit.py | 2 +- model/backbones/unett.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model/backbones/dit.py b/model/backbones/dit.py index 9ff53513..b8e6dc3f 100644 --- a/model/backbones/dit.py +++ b/model/backbones/dit.py @@ -45,9 +45,9 @@ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): self.extra_modeling = False def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 - batch, text_len = text.shape[0], text.shape[1] text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens + batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text_len), value=0) if drop_text: # cfg for text diff --git a/model/backbones/unett.py b/model/backbones/unett.py index c4ce2c64..ac1d3d35 100644 --- a/model/backbones/unett.py +++ b/model/backbones/unett.py @@ -48,9 +48,9 @@ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): self.extra_modeling = False def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 - batch, text_len = text.shape[0], text.shape[1] text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens + batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text_len), value=0) if drop_text: # cfg for text