Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history
  • Loading branch information
lpscr committed Oct 21, 2024
1 parent e3ece35 commit e6d9029
Showing 1 changed file with 47 additions and 12 deletions.
59 changes: 47 additions & 12 deletions model/utils_infer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,12 +38,12 @@
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
#nfe_step = 32 # 16, 32
#cfg_strength = 2.0
#ode_method = "euler"
#sway_sampling_coef = -1.0
#speed = 1.0
#fix_duration = None
# nfe_step = 32 # 16, 32
# cfg_strength = 2.0
# ode_method = "euler"
# sway_sampling_coef = -1.0
# speed = 1.0
# fix_duration = None

# -----------------------------------------

Expand Down Expand Up @@ -84,7 +84,7 @@ def chunk_text(text, max_chars=135):
# load vocoder


def load_vocoder(is_local=False, local_path="",device=device):
def load_vocoder(is_local=False, local_path="", device=device):
if is_local:
print(f"Load vocos from local path {local_path}")
vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
Expand All @@ -100,7 +100,7 @@ def load_vocoder(is_local=False, local_path="",device=device):
# load model for inference


def load_model(model_cls, model_cfg, ckpt_path, vocab_file="",ode_method="euler",use_ema=True,device=device):
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method="euler", use_ema=True, device=device):
if vocab_file == "":
vocab_file = "Emilia_ZH_EN"
tokenizer = "pinyin"
Expand Down Expand Up @@ -178,7 +178,19 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):


def infer_process(
ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=1.0, show_info=print, progress=tqdm,nfe_step=32,cfg_strength=2,sway_sampling_coef=-1,fix_duration=None):
ref_audio,
ref_text,
gen_text,
model_obj,
cross_fade_duration=0.15,
speed=1.0,
show_info=print,
progress=tqdm,
nfe_step=32,
cfg_strength=2,
sway_sampling_coef=-1,
fix_duration=None,
):
# Split the input text into batches
audio, sr = torchaudio.load(ref_audio)
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
Expand All @@ -187,14 +199,36 @@ def infer_process(
print(f"gen_text {i}", gen_text)

show_info(f"Generating audio in {len(gen_text_batches)} batches...")
return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress,nfe_step,cfg_strength,sway_sampling_coef,fix_duration)
return infer_batch_process(
(audio, sr),
ref_text,
gen_text_batches,
model_obj,
cross_fade_duration,
speed,
progress,
nfe_step,
cfg_strength,
sway_sampling_coef,
fix_duration,
)


# infer batches


def infer_batch_process(
ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm,nfe_step=32,cfg_strength=2.0,sway_sampling_coef=-1,fix_duration=None
ref_audio,
ref_text,
gen_text_batches,
model_obj,
cross_fade_duration=0.15,
speed=1,
progress=tqdm,
nfe_step=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
fix_duration=None,
):
audio, sr = ref_audio
if audio.shape[0] > 1:
Expand All @@ -220,7 +254,7 @@ def infer_batch_process(

if fix_duration is not None:
duration = int(fix_duration * target_sample_rate / hop_length)
else:
else:
# Calculate duration
ref_audio_len = audio.shape[-1] // hop_length
ref_text_len = len(ref_text.encode("utf-8"))
Expand Down Expand Up @@ -293,6 +327,7 @@ def infer_batch_process(

return final_wave, target_sample_rate, combined_spectrogram


# remove silence from generated wav
def remove_silence_for_generated_wav(filename):
aseg = AudioSegment.from_file(filename)
Expand Down

0 comments on commit e6d9029

Please sign in to comment.