Skip to content

Commit

Permalink
add stream output
Browse files Browse the repository at this point in the history
  • Loading branch information
lpscr committed Oct 28, 2024
1 parent ae4ef3f commit 41eb33c
Showing 1 changed file with 145 additions and 19 deletions.
164 changes: 145 additions & 19 deletions src/f5_tts/train/finetune_gradio.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import threading
import queue
import re

import gc
import json
import os
Expand Down Expand Up @@ -111,7 +115,7 @@ def load_settings(project_name):
"epochs": 100,
"num_warmup_updates": 2,
"save_per_updates": 300,
"last_per_steps": 200,
"last_per_steps": 100,
"finetune": True,
"file_checkpoint_train": "",
"tokenizer_type": "pinyin",
Expand Down Expand Up @@ -369,8 +373,9 @@ def start_training(
tokenizer_type="pinyin",
tokenizer_file="",
mixed_precision="fp16",
stream=False,
):
global training_process, tts_api
global training_process, tts_api, stop_signal

if tts_api is not None:
del tts_api
Expand Down Expand Up @@ -430,6 +435,7 @@ def start_training(
f"--last_per_steps {last_per_steps} "
f"--dataset_name {dataset_name}"
)

if finetune:
cmd += f" --finetune {finetune}"

Expand Down Expand Up @@ -464,14 +470,112 @@ def start_training(
)

try:
# Start the training process
training_process = subprocess.Popen(cmd, shell=True)
if not stream:
# Start the training process
training_process = subprocess.Popen(cmd, shell=True)

time.sleep(5)
yield "train start", gr.update(interactive=False), gr.update(interactive=True)

# Wait for the training process to finish
training_process.wait()
else:

def stream_output(pipe, output_queue):
try:
for line in iter(pipe.readline, ""):
output_queue.put(line)
except Exception as e:
output_queue.put(f"Error reading pipe: {str(e)}")
finally:
pipe.close()

env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"

time.sleep(5)
yield "train start", gr.update(interactive=False), gr.update(interactive=True)
training_process = subprocess.Popen(
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
)
yield "Training started...", gr.update(interactive=False), gr.update(interactive=True)

stdout_queue = queue.Queue()
stderr_queue = queue.Queue()

stdout_thread = threading.Thread(target=stream_output, args=(training_process.stdout, stdout_queue))
stderr_thread = threading.Thread(target=stream_output, args=(training_process.stderr, stderr_queue))
stdout_thread.daemon = True
stderr_thread.daemon = True
stdout_thread.start()
stderr_thread.start()
stop_signal = False
while True:
if stop_signal:
training_process.terminate()
time.sleep(0.5)
if training_process.poll() is None:
training_process.kill()
yield "Training stopped by user.", gr.update(interactive=True), gr.update(interactive=False)
break

process_status = training_process.poll()

# Handle stdout
try:
while True:
output = stdout_queue.get_nowait()
print(output, end="")
match = re.search(
r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), step=(\d+)", output
)
if match:
current_epoch = match.group(1)
total_epochs = match.group(2)
percent_complete = match.group(3)
elapsed_time = match.group(4)
loss = match.group(5)
current_step = match.group(6)
message = (
f"Epoch: {current_epoch}/{total_epochs}, "
f"Progress: {percent_complete}%, "
f"Elapsed Time: {elapsed_time}, "
f"Loss: {loss}, "
f"Step: {current_step}"
)
yield message, gr.update(interactive=False), gr.update(interactive=True)
elif output.strip():
yield output, gr.update(interactive=False), gr.update(interactive=True)
except queue.Empty:
pass

# Handle stderr
try:
while True:
error_output = stderr_queue.get_nowait()
print(error_output, end="")
if error_output.strip():
yield f"{error_output.strip()}", gr.update(interactive=False), gr.update(interactive=True)
except queue.Empty:
pass

if process_status is not None and stdout_queue.empty() and stderr_queue.empty():
if process_status != 0:
yield (
f"Process crashed with exit code {process_status}!",
gr.update(interactive=False),
gr.update(interactive=True),
)
else:
yield "Training complete!", gr.update(interactive=False), gr.update(interactive=True)
break

# Small sleep to prevent CPU thrashing
time.sleep(0.1)

# Clean up
training_process.stdout.close()
training_process.stderr.close()
training_process.wait()

# Wait for the training process to finish
training_process.wait()
time.sleep(1)

if training_process is None:
Expand All @@ -489,11 +593,13 @@ def start_training(


def stop_training():
global training_process
global training_process, stop_signal

if training_process is None:
return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
terminate_process_tree(training_process.pid)
training_process = None
# training_process = None
stop_signal = True
return "train stop", gr.update(interactive=True), gr.update(interactive=False)


Expand Down Expand Up @@ -1202,7 +1308,11 @@ def get_combined_stats():
project_name = gr.Textbox(label="project name", value="my_speak")
bt_create = gr.Button("create new project")

cm_project = gr.Dropdown(choices=projects, value=projects_selelect, label="Project", allow_custom_value=True)
with gr.Row():
cm_project = gr.Dropdown(
choices=projects, value=projects_selelect, label="Project", allow_custom_value=True, scale=6
)
ch_refresh_project = gr.Button("refresh", scale=1)

bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])

Expand Down Expand Up @@ -1304,6 +1414,7 @@ def get_combined_stats():
bt_prepare = bt_create = gr.Button("prepare")
txt_info_prepare = gr.Text(label="info", value="")
txt_vocab_prepare = gr.Text(label="vocab", value="")

bt_prepare.click(
fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
)
Expand Down Expand Up @@ -1347,11 +1458,11 @@ def get_combined_stats():

with gr.Row():
epochs = gr.Number(label="Epochs", value=10)
num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
num_warmup_updates = gr.Number(label="Warmup Updates", value=2)

with gr.Row():
save_per_updates = gr.Number(label="Save per Updates", value=10)
last_per_steps = gr.Number(label="Last per Steps", value=50)
save_per_updates = gr.Number(label="Save per Updates", value=300)
last_per_steps = gr.Number(label="Last per Steps", value=100)

with gr.Row():
mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
Expand Down Expand Up @@ -1394,6 +1505,7 @@ def get_combined_stats():
tokenizer_file.value = tokenizer_filev
mixed_precision.value = mixed_precisionv

ch_stream = gr.Checkbox(label="stream output experiment.", value=True)
txt_info_train = gr.Text(label="info", value="")
start_button.click(
fn=start_training,
Expand All @@ -1415,6 +1527,7 @@ def get_combined_stats():
tokenizer_type,
tokenizer_file,
mixed_precision,
ch_stream,
],
outputs=[txt_info_train, start_button, stop_button],
)
Expand Down Expand Up @@ -1448,10 +1561,8 @@ def get_combined_stats():
check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
)

cm_project.change(
fn=load_settings,
inputs=[cm_project],
outputs=[
def setup_load_settings():
output_components = [
exp_name,
learning_rate,
batch_size_per_gpu,
Expand All @@ -1468,7 +1579,22 @@ def get_combined_stats():
tokenizer_type,
tokenizer_file,
mixed_precision,
],
]

return output_components

outputs = setup_load_settings()

cm_project.change(
fn=load_settings,
inputs=[cm_project],
outputs=outputs,
)

ch_refresh_project.click(
fn=load_settings,
inputs=[cm_project],
outputs=outputs,
)

with gr.TabItem("test model"):
Expand Down

0 comments on commit 41eb33c

Please sign in to comment.