Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
SWivid authored Oct 21, 2024
2 parents 2fc7b9c + e80addf commit 98bc37f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 35 deletions.
26 changes: 14 additions & 12 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,18 @@ def infer(
ref_file,
ref_text,
gen_text,
show_info=print,
progress=tqdm,
target_rms=0.1,
cross_fade_duration=0.15,
sway_sampling_coef=-1,
cfg_strength=2,
nfe_step=32,
speed=1.0,
fix_duration=None,
remove_silence=False,
file_wave=None,
file_spect=None,
cross_fade_duration=0.15,
show_info=print,
progress=tqdm,
file_spect=None
seed=-1,
):
if seed == -1:
Expand All @@ -95,14 +96,15 @@ def infer(
ref_text,
gen_text,
self.ema_model,
cross_fade_duration,
speed,
show_info,
progress,
nfe_step,
cfg_strength,
sway_sampling_coef,
fix_duration,
show_info=show_info,
progress=progress,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
)

if file_wave is not None:
Expand Down
50 changes: 27 additions & 23 deletions model/utils_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
# nfe_step = 32 # 16, 32
# cfg_strength = 2.0
# ode_method = "euler"
# sway_sampling_coef = -1.0
# speed = 1.0
# fix_duration = None
cross_fade_duration = 0.15
ode_method = "euler"
nfe_step = 32 # 16, 32
cfg_strength = 2.0
sway_sampling_coef = -1.0
speed = 1.0
fix_duration = None

# -----------------------------------------

Expand Down Expand Up @@ -107,7 +108,7 @@ def initialize_asr_pipeline(device=device):
# load model for inference


def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method="euler", use_ema=True, device=device):
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
if vocab_file == "":
vocab_file = "Emilia_ZH_EN"
tokenizer = "pinyin"
Expand Down Expand Up @@ -192,14 +193,15 @@ def infer_process(
ref_text,
gen_text,
model_obj,
cross_fade_duration=0.15,
speed=1.0,
show_info=print,
progress=tqdm,
nfe_step=32,
cfg_strength=2,
sway_sampling_coef=-1,
fix_duration=None,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
):
# Split the input text into batches
audio, sr = torchaudio.load(ref_audio)
Expand All @@ -214,13 +216,14 @@ def infer_process(
ref_text,
gen_text_batches,
model_obj,
cross_fade_duration,
speed,
progress,
nfe_step,
cfg_strength,
sway_sampling_coef,
fix_duration,
progress=progress,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
)


Expand All @@ -232,12 +235,13 @@ def infer_batch_process(
ref_text,
gen_text_batches,
model_obj,
cross_fade_duration=0.15,
speed=1,
progress=tqdm,
target_rms=0.1,
cross_fade_duration=0.15,
nfe_step=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
speed=1,
fix_duration=None,
):
audio, sr = ref_audio
Expand All @@ -262,11 +266,11 @@ def infer_batch_process(
text_list = [ref_text + gen_text]
final_text_list = convert_char_to_pinyin(text_list)

ref_audio_len = audio.shape[-1] // hop_length
if fix_duration is not None:
duration = int(fix_duration * target_sample_rate / hop_length)
else:
# Calculate duration
ref_audio_len = audio.shape[-1] // hop_length
ref_text_len = len(ref_text.encode("utf-8"))
gen_text_len = len(gen_text.encode("utf-8"))
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
Expand Down

0 comments on commit 98bc37f

Please sign in to comment.