Skip to content

Commit

Permalink
add speed seed remove_silence
Browse files Browse the repository at this point in the history
  • Loading branch information
lpscr committed Nov 5, 2024
1 parent bbe0d4d commit 2d2452e
Showing 1 changed file with 37 additions and 7 deletions.
44 changes: 37 additions & 7 deletions src/f5_tts/train/finetune_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

path_data = str(files("f5_tts").joinpath("../../data"))
path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))

file_train = "src/f5_tts/train/finetune_cli.py"

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
Expand Down Expand Up @@ -1220,7 +1221,9 @@ def get_random_sample_infer(project_name):
)


def infer(project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema):
def infer(
project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema, speed, seed, remove_silence
):
global last_checkpoint, last_device, tts_api, last_ema

if not os.path.isfile(file_checkpoint):
Expand Down Expand Up @@ -1250,8 +1253,17 @@ def infer(project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe
print("update >> ", device_test, file_checkpoint, use_ema)

with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
return f.name, tts_api.device
tts_api.infer(
gen_text=gen_text.lower().strip(),
ref_text=ref_text.lower().strip(),
ref_file=ref_audio,
nfe_step=nfe_step,
file_wave=f.name,
speed=speed,
seed=seed,
remove_silence=remove_silence,
)
return f.name, tts_api.device, str(tts_api.seed)


def check_finetune(finetune):
Expand Down Expand Up @@ -1748,12 +1760,17 @@ def setup_load_settings():

with gr.TabItem("Test Model"):
gr.Markdown("""```plaintext
SOS: Check the use_ema setting (True or False) for your model to see what works best for you.
SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
```""")
exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)

nfe_step = gr.Number(label="NFE Step", value=32)
with gr.Row():
nfe_step = gr.Number(label="NFE Step", value=32)
speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1)
seed = gr.Number(label="Seed", value=-1, minimum=-1)
remove_silence = gr.Checkbox(label="Remove Silence")

ch_use_ema = gr.Checkbox(label="Use EMA", value=True)
with gr.Row():
cm_checkpoint = gr.Dropdown(
Expand All @@ -1773,14 +1790,27 @@ def setup_load_settings():

with gr.Row():
txt_info_gpu = gr.Textbox("", label="Device")
seed_info = gr.Text(label="Seed :")
check_button_infer = gr.Button("Infer")

gen_audio = gr.Audio(label="Audio Gen", type="filepath")

check_button_infer.click(
fn=infer,
inputs=[cm_project, cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, ch_use_ema],
outputs=[gen_audio, txt_info_gpu],
inputs=[
cm_project,
cm_checkpoint,
exp_name,
ref_text,
ref_audio,
gen_text,
nfe_step,
ch_use_ema,
speed,
seed,
remove_silence,
],
outputs=[gen_audio, txt_info_gpu, seed_info],
)

bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
Expand Down

0 comments on commit 2d2452e

Please sign in to comment.