
Commit

update infer limits
lpscr committed Oct 21, 2024
1 parent d4cb542 commit 5d02a54
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions model/utils_infer.py
@@ -38,12 +38,12 @@
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
-nfe_step = 32 # 16, 32
-cfg_strength = 2.0
-ode_method = "euler"
-sway_sampling_coef = -1.0
-speed = 1.0
-fix_duration = None
+#nfe_step = 32 # 16, 32
+#cfg_strength = 2.0
+#ode_method = "euler"
+#sway_sampling_coef = -1.0
+#speed = 1.0
+#fix_duration = None

# -----------------------------------------
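These module-level knobs are commented out because they now travel as per-call parameters on infer_process and infer_batch_process (see the signature changes further down). A minimal sketch of the resulting call style; the paths and the model object are hypothetical stand-ins:

    from model.utils_infer import infer_process

    # The former globals become keyword arguments; values below mirror the old defaults.
    final_wave, sample_rate, spectrogram = infer_process(
        "ref.wav",               # reference audio (hypothetical path)
        "reference transcript",  # ref_text
        "text to synthesize",    # gen_text
        model,                   # a model returned by load_model(...)
        nfe_step=32,
        cfg_strength=2.0,
        sway_sampling_coef=-1.0,
        fix_duration=None,
    )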

@@ -84,7 +84,7 @@ def chunk_text(text, max_chars=135):
# load vocoder


-def load_vocoder(is_local=False, local_path=""):
+def load_vocoder(is_local=False, local_path="",device=device):
    if is_local:
        print(f"Load vocos from local path {local_path}")
        vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
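load_vocoder now takes the target device explicitly instead of relying on the module-level default; presumably the function places the Vocos model on that device. A hedged usage sketch:

    import torch
    from model.utils_infer import load_vocoder

    device = "cuda" if torch.cuda.is_available() else "cpu"
    vocos = load_vocoder(is_local=False, device=device)  # or is_local=True with local_path="..." for offline weights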
@@ -100,14 +100,14 @@ def load_vocoder(is_local=False, local_path=""):
# load model for inference


-def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
+def load_model(model_cls, model_cfg, ckpt_path, vocab_file="",ode_method="euler",use_ema=True,device=device):
    if vocab_file == "":
        vocab_file = "Emilia_ZH_EN"
        tokenizer = "pinyin"
    else:
        tokenizer = "custom"

-    print("\nvocab : ", vocab_file, tokenizer)
+    print("\nvocab : ", vocab_file)
+    print("tokenizer : ", tokenizer)
    print("model : ", ckpt_path, "\n")

@@ -125,7 +125,7 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
        vocab_char_map=vocab_char_map,
    ).to(device)

-    model = load_checkpoint(model, ckpt_path, device, use_ema=True)
+    model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)

    return model
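The three new keyword arguments expose what was previously hard-coded: the ODE solver, whether to load the EMA weights from the checkpoint, and the target device. A sketch of an overriding call; the DiT class, config dict, and checkpoint path follow the repository's usual F5-TTS setup but are assumptions here:

    from model import DiT  # assumed import, as in the repo's inference scripts

    F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    ema_model = load_model(
        DiT,
        F5TTS_model_cfg,
        "ckpts/F5TTS_Base/model_1200000.safetensors",  # assumed checkpoint location
        ode_method="euler",  # same solver that used to be hard-coded
        use_ema=False,       # e.g. to evaluate a raw mid-training checkpoint
    )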

@@ -178,8 +178,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):


def infer_process(
-    ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm
-):
+    ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=1.0, show_info=print, progress=tqdm,nfe_step=32,cfg_strength=2,sway_sampling_coef=-1,fix_duration=None):
    # Split the input text into batches
    audio, sr = torchaudio.load(ref_audio)
    max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
@@ -188,14 +187,14 @@
print(f"gen_text {i}", gen_text)

show_info(f"Generating audio in {len(gen_text_batches)} batches...")
return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress)
return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress,nfe_step,cfg_strength,sway_sampling_coef,fix_duration)
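Note that the new options are forwarded to infer_batch_process positionally, so the parameter order in both signatures has to stay in sync. A keyword-argument variant, shown here only as a sketch, would make that coupling explicit:

    return infer_batch_process(
        (audio, sr), ref_text, gen_text_batches, model_obj,
        cross_fade_duration=cross_fade_duration, speed=speed, progress=progress,
        nfe_step=nfe_step, cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef, fix_duration=fix_duration,
    )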


# infer batches


def infer_batch_process(
-    ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm
+    ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm,nfe_step=32,cfg_strength=2.0,sway_sampling_coef=-1,fix_duration=None
):
    audio, sr = ref_audio
    if audio.shape[0] > 1:
@@ -219,11 +218,14 @@
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

-        # Calculate duration
-        ref_audio_len = audio.shape[-1] // hop_length
-        ref_text_len = len(ref_text.encode("utf-8"))
-        gen_text_len = len(gen_text.encode("utf-8"))
-        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
+        if fix_duration is not None:
+            duration = int(fix_duration * target_sample_rate / hop_length)
+        else:
+            # Calculate duration
+            ref_audio_len = audio.shape[-1] // hop_length
+            ref_text_len = len(ref_text.encode("utf-8"))
+            gen_text_len = len(gen_text.encode("utf-8"))
+            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

        # inference
        with torch.inference_mode():
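Durations here are counted in mel frames (hop_length samples per frame). Assuming target_sample_rate = 24000 as defined near the top of the file (not shown in this hunk), a worked example of both branches; the clip length and byte counts are made up for illustration:

    hop_length = 256
    target_sample_rate = 24000                               # assumed from the file's constants

    # fix_duration set: total length (reference + generated) of 8 seconds
    int(8.0 * target_sample_rate / hop_length)               # -> 750 frames

    # fix_duration is None: estimate from the byte-length ratio of the texts
    ref_audio_len = 4 * target_sample_rate // hop_length     # 4 s reference -> 375 frames
    ref_text_len, gen_text_len, speed = 40, 80, 1.0          # hypothetical byte counts
    ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)  # 375 + 750 -> 1125 frames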
@@ -291,10 +293,7 @@ def infer_batch_process(

    return final_wave, target_sample_rate, combined_spectrogram


# remove silence from generated wav


def remove_silence_for_generated_wav(filename):
    aseg = AudioSegment.from_file(filename)
    non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
