From 96946f85fa0c1eecb971f9f30f2e5b922d270e5c Mon Sep 17 00:00:00 2001
From: unknown
Date: Sat, 16 Nov 2024 18:39:52 +0200
Subject: [PATCH] import transcribe from utils_infer

---
 src/f5_tts/train/finetune_gradio.py | 34 +++++------------------------
 1 file changed, 5 insertions(+), 29 deletions(-)

diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py
index 71665497..8b6d8a54 100644
--- a/src/f5_tts/train/finetune_gradio.py
+++ b/src/f5_tts/train/finetune_gradio.py
@@ -26,12 +26,13 @@ from datasets.arrow_writer import ArrowWriter
 from safetensors.torch import save_file
 from scipy.io import wavfile
-from transformers import pipeline
 from cached_path import cached_path
 
 from f5_tts.api import F5TTS
 from f5_tts.model.utils import convert_char_to_pinyin
+from f5_tts.infer.utils_infer import transcribe
 from importlib.resources import files
 
+
 training_process = None
 system = platform.system()
 python_executable = sys.executable or "python"
@@ -47,8 +48,6 @@
 
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
-pipe = None
-
 
 # Save settings from a JSON file
 def save_settings(
@@ -390,17 +389,15 @@ def start_training(
     logger="wandb",
     ch_8bit_adam=False,
 ):
-    global training_process, tts_api, stop_signal, pipe
+    global training_process, tts_api, stop_signal
 
-    if tts_api is not None or pipe is not None:
+    if tts_api is not None:
         if tts_api is not None:
             del tts_api
-        if pipe is not None:
-            del pipe
+
         gc.collect()
         torch.cuda.empty_cache()
         tts_api = None
-        pipe = None
 
     path_project = os.path.join(path_data, dataset_name)
 
@@ -652,27 +649,6 @@ def create_data_project(name, tokenizer_type):
     return gr.update(choices=project_list, value=name)
 
 
-def transcribe(file_audio, language="english"):
-    global pipe
-
-    if pipe is None:
-        pipe = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-large-v3-turbo",
-            torch_dtype=torch.float16,
-            device=device,
-        )
-
-    text_transcribe = pipe(
-        file_audio,
-        chunk_length_s=30,
-        batch_size=128,
-        generate_kwargs={"task": "transcribe", "language": language},
-        return_timestamps=False,
-    )["text"].strip()
-    return text_transcribe
-
-
 def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
     path_project = os.path.join(path_data, name_project)
     path_dataset = os.path.join(path_project, "dataset")
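
Note (not part of the patch): a minimal usage sketch of the shared helper this change switches to. It assumes f5_tts.infer.utils_infer.transcribe keeps the same (file_audio, language) signature as the local Whisper-pipeline helper removed above; the audio path and language value below are placeholders.

    # Hypothetical standalone call, mirroring how transcribe_all used the removed local helper.
    from f5_tts.infer.utils_infer import transcribe

    text = transcribe("sample.wav", language="english")  # "sample.wav" is a placeholder path
    print(text)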