# Post-patch state of the hunk from src/f5_tts/infer/utils_infer.py.
asr_pipe = None


def initialize_asr_pipeline(device=device, dtype=None):
    """Create the Whisper ASR pipeline and store it in the global ``asr_pipe``.

    Args:
        device: Device the pipeline is placed on; forwarded unchanged to
            ``pipeline``. May be a string such as ``"cuda"`` / ``"cuda:0"``
            or a ``torch.device``.
        dtype: Torch dtype for the model weights. When ``None`` it is chosen
            automatically: ``torch.float16`` on a CUDA device whose compute
            capability major version is at least 6, ``torch.float32``
            otherwise.
    """
    global asr_pipe

    if dtype is None:
        # Compare on the string form so "cuda", "cuda:0" and torch.device
        # objects are all recognised; a plain `device == "cuda"` check misses
        # indexed devices and silently falls back to float32.
        on_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
        if on_cuda and torch.cuda.get_device_properties(device).major >= 6:
            dtype = torch.float16
        else:
            dtype = torch.float32

    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3-turbo",
        torch_dtype=dtype,
        device=device,
    )