Skip to content

Commit

Permalink
fix. default fp32 for ZLUDA #578
Browse files Browse the repository at this point in the history
  • Loading branch information
SWivid committed Dec 5, 2024
1 parent eea65de commit 7f7fd29
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/f5_tts/infer/utils_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
def initialize_asr_pipeline(device: str = device, dtype=None):
if dtype is None:
dtype = (
torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 6
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
global asr_pipe
asr_pipe = pipeline(
Expand Down Expand Up @@ -171,7 +175,11 @@ def transcribe(ref_audio, language=None):
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
if dtype is None:
dtype = (
torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 6
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
model = model.to(dtype)

Expand Down

0 comments on commit 7f7fd29

Please sign in to comment.