
fix device cpu and cuda:0!
lpscr committed Oct 22, 2024
1 parent a41731b commit 989088b
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions model/utils_infer.py
@@ -19,8 +19,14 @@
     convert_char_to_pinyin,
 )

-device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-print(f"Using {device} device")
+# get device
+
+
+def get_device():
+    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+    # print(f"Using {device} device")
+    return device
+

 vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

@@ -76,7 +82,9 @@ def chunk_text(text, max_chars=135):


 # load vocoder
-def load_vocoder(is_local=False, local_path="", device=device):
+def load_vocoder(is_local=False, local_path="", device=None):
+    if device is None:
+        device = get_device()
     if is_local:
         print(f"Load vocos from local path {local_path}")
         vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
@@ -94,8 +102,10 @@ def load_vocoder(is_local=False, local_path="", device=device):
 asr_pipe = None


-def initialize_asr_pipeline(device=device):
+def initialize_asr_pipeline(device=None):
     global asr_pipe
+    if device is None:
+        device = get_device()

     asr_pipe = pipeline(
         "automatic-speech-recognition",
@@ -108,7 +118,9 @@
 # load model for inference


-def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
+def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=None):
+    if device is None:
+        device = get_device()
     if vocab_file == "":
         vocab_file = "Emilia_ZH_EN"
         tokenizer = "pinyin"
@@ -141,7 +153,9 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me
 # preprocess reference audio and text


-def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
+def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=None):
+    if device is None:
+        device = get_device()
+
     show_info("Converting audio...")
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
@@ -243,7 +257,11 @@ def infer_batch_process(
     sway_sampling_coef=-1,
     speed=1,
     fix_duration=None,
+    device=None,
 ):
+    if device is None:
+        device = get_device()
+
     audio, sr = ref_audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
@@ -254,7 +272,7 @@
     if sr != target_sample_rate:
         resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
         audio = resampler(audio)
-    audio = audio.to(device)
+    audio = audio.to()

     generated_waves = []
     spectrograms = []
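
Most of the hunks above make the same change: the module-level device constant becomes a get_device() helper, and every device=device default becomes device=None resolved inside the function. The commit message does not spell out why, but a standard Python consideration applies: a default like device=device is bound once, at import time, so it freezes whatever device was detected when utils_infer was first imported and can drift out of sync with where other tensors actually live (the "cpu and cuda:0" mismatch the title refers to). A minimal sketch of the pattern, assuming only torch; load_thing is a hypothetical stand-in for the loaders in this file:

import torch


def get_device():
    # Same priority as the helper added in this commit: CUDA, then Apple MPS, then CPU.
    return "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"


def load_thing(device=None):
    # Resolve the device at call time; a default of `device=device` would have been
    # fixed once at import time instead.
    if device is None:
        device = get_device()
    return torch.zeros(1, device=device)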
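For reference, a hedged usage sketch of the updated entry points, assuming the module is importable as model.utils_infer (per the file path above) and that the pretrained Vocos weights can be fetched:

from model.utils_infer import get_device, load_vocoder

print(get_device())                       # "cuda", "mps", or "cpu"
vocoder = load_vocoder()                  # device resolved internally via get_device()
vocoder_cpu = load_vocoder(device="cpu")  # or pinned explicitly

load_model, preprocess_ref_audio_text, initialize_asr_pipeline, and infer_batch_process likewise accept an optional device argument after this commit and fall back to get_device() when it is omitted.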
