From 989088bbbb166eb7eecfdf2021d7ef4a923a0688 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 22 Oct 2024 11:56:58 +0300 Subject: [PATCH] fix device cpu and cuda:0! --- model/utils_infer.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/model/utils_infer.py b/model/utils_infer.py index 1cc2c98df..615f82c3c 100644 --- a/model/utils_infer.py +++ b/model/utils_infer.py @@ -19,8 +19,14 @@ convert_char_to_pinyin, ) -device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" -print(f"Using {device} device") +# get device + + +def get_device(): + device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + # print(f"Using {device} device") + return device + vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") @@ -76,7 +82,9 @@ def chunk_text(text, max_chars=135): # load vocoder -def load_vocoder(is_local=False, local_path="", device=device): +def load_vocoder(is_local=False, local_path="", device=None): + if device is None: + device = get_device() if is_local: print(f"Load vocos from local path {local_path}") vocos = Vocos.from_hparams(f"{local_path}/config.yaml") @@ -94,8 +102,10 @@ def load_vocoder(is_local=False, local_path="", device=device): asr_pipe = None -def initialize_asr_pipeline(device=device): +def initialize_asr_pipeline(device=None): global asr_pipe + if device is None: + device = get_device() asr_pipe = pipeline( "automatic-speech-recognition", @@ -108,7 +118,9 @@ def initialize_asr_pipeline(device=device): # load model for inference -def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device): +def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=None): + if device is None: + device = get_device() if vocab_file == "": vocab_file = "Emilia_ZH_EN" tokenizer = "pinyin" @@ -141,7 +153,9 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me # preprocess reference audio and text -def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device): +def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=None): + device = get_device(device) + show_info("Converting audio...") with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f: aseg = AudioSegment.from_file(ref_audio_orig) @@ -243,7 +257,11 @@ def infer_batch_process( sway_sampling_coef=-1, speed=1, fix_duration=None, + device=None, ): + if device is None: + device = get_device() + audio, sr = ref_audio if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True) @@ -254,7 +272,7 @@ def infer_batch_process( if sr != target_sample_rate: resampler = torchaudio.transforms.Resample(sr, target_sample_rate) audio = resampler(audio) - audio = audio.to(device) + audio = audio.to() generated_waves = [] spectrograms = []