Skip to content

Commit

Permalink
fix. add dtype check for asr pipeline addressing SWivid#356
Browse files Browse the repository at this point in the history
  • Loading branch information
SWivid committed Nov 2, 2024
1 parent f7e248e commit ea90244
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/f5_tts/infer/utils_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,16 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
asr_pipe = None


def initialize_asr_pipeline(device=device):
def initialize_asr_pipeline(device=device, dtype=None):
if dtype is None:
dtype = (
torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
)
global asr_pipe
asr_pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large-v3-turbo",
torch_dtype=torch.float16,
torch_dtype=dtype,
device=device,
)

Expand Down

0 comments on commit ea90244

Please sign in to comment.