diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py index 1de9048a..76f08e67 100644 --- a/src/f5_tts/infer/utils_infer.py +++ b/src/f5_tts/infer/utils_infer.py @@ -150,6 +150,22 @@ def initialize_asr_pipeline(device=device, dtype=None): ) +# transcribe + + +def transcribe(ref_audio, language=None): + global asr_pipe + if asr_pipe is None: + initialize_asr_pipeline(device=device) + return asr_pipe( + ref_audio, + chunk_length_s=30, + batch_size=128, + generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"}, + return_timestamps=False, + )["text"].strip() + + # load model checkpoint for inference @@ -306,17 +322,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in show_info("Using cached reference text...") ref_text = _ref_audio_cache[audio_hash] else: - global asr_pipe - if asr_pipe is None: - initialize_asr_pipeline(device=device) show_info("No reference text provided, transcribing reference audio...") - ref_text = asr_pipe( - ref_audio, - chunk_length_s=30, - batch_size=128, - generate_kwargs={"task": "transcribe"}, - return_timestamps=False, - )["text"].strip() + ref_text = transcribe(ref_audio) # Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak) _ref_audio_cache[audio_hash] = ref_text else: