From 1ef4765df1f73b67712b61dee0e0f394abdf568d Mon Sep 17 00:00:00 2001 From: Blaise Date: Sat, 21 Dec 2024 20:18:54 +0100 Subject: [PATCH] improve models selector logic --- tabs/inference/inference.py | 14 ++++++++++++-- tabs/tts/tts.py | 7 ++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tabs/inference/inference.py b/tabs/inference/inference.py index 284f1518..2382c9ab 100644 --- a/tabs/inference/inference.py +++ b/tabs/inference/inference.py @@ -62,6 +62,8 @@ ) ] +default_weight = names[0] if names else None + indexes_list = [ os.path.join(root, name) for root, _, files in os.walk(model_root_relative, topdown=False) @@ -239,6 +241,15 @@ def get_indexes(): return indexes_list if indexes_list else "" +def extract_model_and_epoch(path): + base_name = os.path.basename(path) + match = re.match(r'(.+?)_(\d+)e_', base_name) + if match: + model, epoch = match.groups() + return model, int(epoch) + return "", 0 + + def save_to_wav(record_button): if record_button is None: pass @@ -337,13 +348,12 @@ def get_speakers_id(model): # Inference tab def inference_tab(): - default_weight = names[0] if names else None with gr.Column(): with gr.Row(): model_file = gr.Dropdown( label=i18n("Voice Model"), info=i18n("Select the voice model to use for the conversion."), - choices=sorted(names, key=lambda path: os.path.getsize(path)), + choices=sorted(names, key=lambda x: extract_model_and_epoch(x)), interactive=True, value=default_weight, allow_custom_value=True, diff --git a/tabs/tts/tts.py b/tabs/tts/tts.py index bf3f78c6..f700fbbc 100644 --- a/tabs/tts/tts.py +++ b/tabs/tts/tts.py @@ -16,13 +16,14 @@ get_indexes, get_speakers_id, match_index, - names, refresh_embedders_folders, + extract_model_and_epoch, + names, + default_weight, ) i18n = I18nAuto() -default_weight = random.choice(names) if names else "" with open( os.path.join("rvc", "lib", "tools", "tts_voices.json"), "r", encoding="utf-8" @@ -50,7 +51,7 @@ def tts_tab(): model_file = gr.Dropdown( label=i18n("Voice Model"), info=i18n("Select the voice model to use for the conversion."), - choices=sorted(names, key=lambda path: os.path.getsize(path)), + choices=sorted(names, key=lambda x: extract_model_and_epoch(x)), interactive=True, value=default_weight, allow_custom_value=True,