diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py
index 1c739313f..016640d9c 100644
--- a/src/f5_tts/train/finetune_gradio.py
+++ b/src/f5_tts/train/finetune_gradio.py
@@ -1,3 +1,7 @@
+import threading
+import queue
+import re
+
 import gc
 import json
 import os
@@ -111,7 +115,7 @@ def load_settings(project_name):
             "epochs": 100,
             "num_warmup_updates": 2,
             "save_per_updates": 300,
-            "last_per_steps": 200,
+            "last_per_steps": 100,
             "finetune": True,
             "file_checkpoint_train": "",
             "tokenizer_type": "pinyin",
@@ -369,8 +373,9 @@ def start_training(
     tokenizer_type="pinyin",
     tokenizer_file="",
     mixed_precision="fp16",
+    stream=False,
 ):
-    global training_process, tts_api
+    global training_process, tts_api, stop_signal
 
     if tts_api is not None:
         del tts_api
@@ -430,6 +435,7 @@ def start_training(
         f"--last_per_steps {last_per_steps} "
         f"--dataset_name {dataset_name}"
     )
+
     if finetune:
         cmd += f" --finetune {finetune}"
 
@@ -464,14 +470,112 @@ def start_training(
     )
 
     try:
-        # Start the training process
-        training_process = subprocess.Popen(cmd, shell=True)
+        if not stream:
+            # Start the training process
+            training_process = subprocess.Popen(cmd, shell=True)
+
+            time.sleep(5)
+            yield "train start", gr.update(interactive=False), gr.update(interactive=True)
+
+            # Wait for the training process to finish
+            training_process.wait()
+        else:
+
+            def stream_output(pipe, output_queue):
+                try:
+                    for line in iter(pipe.readline, ""):
+                        output_queue.put(line)
+                except Exception as e:
+                    output_queue.put(f"Error reading pipe: {str(e)}")
+                finally:
+                    pipe.close()
+
+            env = os.environ.copy()
+            env["PYTHONUNBUFFERED"] = "1"
 
-        time.sleep(5)
-        yield "train start", gr.update(interactive=False), gr.update(interactive=True)
+            training_process = subprocess.Popen(
+                cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
+            )
+            yield "Training started...", gr.update(interactive=False), gr.update(interactive=True)
+
+            stdout_queue = queue.Queue()
+            stderr_queue = queue.Queue()
+
+            stdout_thread = threading.Thread(target=stream_output, args=(training_process.stdout, stdout_queue))
+            stderr_thread = threading.Thread(target=stream_output, args=(training_process.stderr, stderr_queue))
+            stdout_thread.daemon = True
+            stderr_thread.daemon = True
+            stdout_thread.start()
+            stderr_thread.start()
+            stop_signal = False
+            while True:
+                if stop_signal:
+                    training_process.terminate()
+                    time.sleep(0.5)
+                    if training_process.poll() is None:
+                        training_process.kill()
+                    yield "Training stopped by user.", gr.update(interactive=True), gr.update(interactive=False)
+                    break
+
+                process_status = training_process.poll()
+
+                # Handle stdout
+                try:
+                    while True:
+                        output = stdout_queue.get_nowait()
+                        print(output, end="")
+                        match = re.search(
+                            r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), step=(\d+)", output
+                        )
+                        if match:
+                            current_epoch = match.group(1)
+                            total_epochs = match.group(2)
+                            percent_complete = match.group(3)
+                            elapsed_time = match.group(4)
+                            loss = match.group(5)
+                            current_step = match.group(6)
+                            message = (
+                                f"Epoch: {current_epoch}/{total_epochs}, "
+                                f"Progress: {percent_complete}%, "
+                                f"Elapsed Time: {elapsed_time}, "
+                                f"Loss: {loss}, "
+                                f"Step: {current_step}"
+                            )
+                            yield message, gr.update(interactive=False), gr.update(interactive=True)
+                        elif output.strip():
+                            yield output, gr.update(interactive=False), gr.update(interactive=True)
+                except queue.Empty:
+                    pass
+
+                # Handle stderr
+                try:
+                    while True:
+                        error_output = stderr_queue.get_nowait()
+                        print(error_output, end="")
+                        if error_output.strip():
+                            yield f"{error_output.strip()}", gr.update(interactive=False), gr.update(interactive=True)
+                except queue.Empty:
+                    pass
+
+                if process_status is not None and stdout_queue.empty() and stderr_queue.empty():
+                    if process_status != 0:
+                        yield (
+                            f"Process crashed with exit code {process_status}!",
+                            gr.update(interactive=False),
+                            gr.update(interactive=True),
+                        )
+                    else:
+                        yield "Training complete!", gr.update(interactive=False), gr.update(interactive=True)
+                    break
+
+                # Small sleep to prevent CPU thrashing
+                time.sleep(0.1)
+
+            # Clean up
+            training_process.stdout.close()
+            training_process.stderr.close()
+            training_process.wait()
 
-        # Wait for the training process to finish
-        training_process.wait()
         time.sleep(1)
 
         if training_process is None:
@@ -489,11 +593,13 @@ def start_training(
 
 
 def stop_training():
-    global training_process
+    global training_process, stop_signal
+
     if training_process is None:
         return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
     terminate_process_tree(training_process.pid)
-    training_process = None
+    # training_process = None
+    stop_signal = True
     return "train stop", gr.update(interactive=True), gr.update(interactive=False)
 
 
@@ -1202,7 +1308,11 @@ def get_combined_stats():
         project_name = gr.Textbox(label="project name", value="my_speak")
         bt_create = gr.Button("create new project")
 
-    cm_project = gr.Dropdown(choices=projects, value=projects_selelect, label="Project", allow_custom_value=True)
+    with gr.Row():
+        cm_project = gr.Dropdown(
+            choices=projects, value=projects_selelect, label="Project", allow_custom_value=True, scale=6
+        )
+        ch_refresh_project = gr.Button("refresh", scale=1)
 
     bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])
 
@@ -1304,6 +1414,7 @@ def get_combined_stats():
             bt_prepare = bt_create = gr.Button("prepare")
             txt_info_prepare = gr.Text(label="info", value="")
             txt_vocab_prepare = gr.Text(label="vocab", value="")
+
             bt_prepare.click(
                 fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
             )
@@ -1347,11 +1458,11 @@ def get_combined_stats():
 
             with gr.Row():
                 epochs = gr.Number(label="Epochs", value=10)
-                num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
+                num_warmup_updates = gr.Number(label="Warmup Updates", value=2)
 
             with gr.Row():
-                save_per_updates = gr.Number(label="Save per Updates", value=10)
-                last_per_steps = gr.Number(label="Last per Steps", value=50)
+                save_per_updates = gr.Number(label="Save per Updates", value=300)
+                last_per_steps = gr.Number(label="Last per Steps", value=100)
 
             with gr.Row():
                 mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
@@ -1394,6 +1505,7 @@ def get_combined_stats():
                 tokenizer_file.value = tokenizer_filev
                 mixed_precision.value = mixed_precisionv
 
+            ch_stream = gr.Checkbox(label="stream output experiment.", value=True)
             txt_info_train = gr.Text(label="info", value="")
             start_button.click(
                 fn=start_training,
@@ -1415,6 +1527,7 @@ def get_combined_stats():
                     tokenizer_type,
                     tokenizer_file,
                     mixed_precision,
+                    ch_stream,
                 ],
                 outputs=[txt_info_train, start_button, stop_button],
             )
@@ -1448,10 +1561,8 @@ def get_combined_stats():
                 check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
             )
 
-            cm_project.change(
-                fn=load_settings,
-                inputs=[cm_project],
-                outputs=[
+            def setup_load_settings():
+                output_components = [
                     exp_name,
                     learning_rate,
                     batch_size_per_gpu,
@@ -1468,7 +1579,22 @@ def get_combined_stats():
                     tokenizer_type,
                     tokenizer_file,
                     mixed_precision,
-                ],
+                ]
+
+                return output_components
+
+            outputs = setup_load_settings()
+
+            cm_project.change(
+                fn=load_settings,
+                inputs=[cm_project],
+                outputs=outputs,
+            )
+
+            ch_refresh_project.click(
+                fn=load_settings,
+                inputs=[cm_project],
+                outputs=outputs,
             )
 
         with gr.TabItem("test model"):