Skip to content

Commit

Permalink
import transcribe from utils_infer
Browse files Browse the repository at this point in the history
  • Loading branch information
lpscr committed Nov 16, 2024
1 parent c4d7252 commit 96946f8
Showing 1 changed file with 5 additions and 29 deletions.
34 changes: 5 additions & 29 deletions src/f5_tts/train/finetune_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
from datasets.arrow_writer import ArrowWriter
from safetensors.torch import save_file
from scipy.io import wavfile
from transformers import pipeline
from cached_path import cached_path
from f5_tts.api import F5TTS
from f5_tts.model.utils import convert_char_to_pinyin
from f5_tts.infer.utils_infer import transcribe
from importlib.resources import files


training_process = None
system = platform.system()
python_executable = sys.executable or "python"
Expand All @@ -47,8 +48,6 @@

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

pipe = None


# Save settings from a JSON file
def save_settings(
Expand Down Expand Up @@ -390,17 +389,15 @@ def start_training(
logger="wandb",
ch_8bit_adam=False,
):
global training_process, tts_api, stop_signal, pipe
global training_process, tts_api, stop_signal

if tts_api is not None or pipe is not None:
if tts_api is not None:
if tts_api is not None:
del tts_api
if pipe is not None:
del pipe

gc.collect()
torch.cuda.empty_cache()
tts_api = None
pipe = None

path_project = os.path.join(path_data, dataset_name)

Expand Down Expand Up @@ -652,27 +649,6 @@ def create_data_project(name, tokenizer_type):
return gr.update(choices=project_list, value=name)


def transcribe(file_audio, language="english"):
global pipe

if pipe is None:
pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large-v3-turbo",
torch_dtype=torch.float16,
device=device,
)

text_transcribe = pipe(
file_audio,
chunk_length_s=30,
batch_size=128,
generate_kwargs={"task": "transcribe", "language": language},
return_timestamps=False,
)["text"].strip()
return text_transcribe


def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
path_project = os.path.join(path_data, name_project)
path_dataset = os.path.join(path_project, "dataset")
Expand Down

0 comments on commit 96946f8

Please sign in to comment.