Skip to content

Commit

Permalink
v0.3.1 fix multi-style gradio bug; add features suggested #591
Browse files (browse the repository at this point in the history)
  • Loading branch information
SWivid committed Dec 16, 2024
1 parent 61f28ee commit 9a18bbe
Showing 1 changed file with 47 additions and 53 deletions.
100 changes: 47 additions & 53 deletions src/f5_tts/infer/infer_gradio.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -120,6 +120,14 @@ def infer(
speed=1,
show_info=gr.Info,
):
if not ref_audio_orig:
gr.Warning("Please provide reference audio.")
return gr.update(), gr.update(), ref_text

if not gen_text.strip():
gr.Warning("Please enter text to generate.")
return gr.update(), gr.update(), ref_text

ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

if model == "F5-TTS":
Expand Down Expand Up @@ -240,7 +248,7 @@ def basic_tts(
nfe_step=nfe_slider,
speed=speed_slider,
)
return audio_out, spectrogram_path, gr.update(value=ref_text_out)
return audio_out, spectrogram_path, ref_text_out

generate_btn.click(
basic_tts,
Expand Down Expand Up @@ -320,7 +328,7 @@ def parse_speechtypes_text(gen_text):
)

# Regular speech type (mandatory)
with gr.Row():
with gr.Row() as regular_row:
with gr.Column():
regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
regular_insert = gr.Button("Insert Label", variant="secondary")
Expand All @@ -329,12 +337,12 @@ def parse_speechtypes_text(gen_text):

# Speech type component lists (max 100 entries; index 0 is the mandatory Regular type)
max_speech_types = 100
speech_type_rows = [] # 99
speech_type_names = [regular_name] # 100
speech_type_audios = [regular_audio] # 100
speech_type_ref_texts = [regular_ref_text] # 100
speech_type_delete_btns = [] # 99
speech_type_insert_btns = [regular_insert] # 100
speech_type_rows = [regular_row]
speech_type_names = [regular_name]
speech_type_audios = [regular_audio]
speech_type_ref_texts = [regular_ref_text]
speech_type_delete_btns = [None]
speech_type_insert_btns = [regular_insert]

# Additional speech types (99 more)
for i in range(max_speech_types - 1):
Expand All @@ -355,51 +363,32 @@ def parse_speechtypes_text(gen_text):
# Button to add speech type
add_speech_type_btn = gr.Button("Add Speech Type")

# Keep track of current number of speech types
speech_type_count = gr.State(value=1)
# Auto-incrementing counter of speech type rows revealed so far; it is never rolled back when a row is deleted
speech_type_count = 1

# Function to add a speech type
def add_speech_type_fn(speech_type_count):
def add_speech_type_fn():
row_updates = [gr.update() for _ in range(max_speech_types)]
global speech_type_count
if speech_type_count < max_speech_types:
row_updates[speech_type_count] = gr.update(visible=True)
speech_type_count += 1
# Prepare updates for the rows
row_updates = []
for i in range(1, max_speech_types):
if i < speech_type_count:
row_updates.append(gr.update(visible=True))
else:
row_updates.append(gr.update())
else:
# Optionally, show a warning
row_updates = [gr.update() for _ in range(1, max_speech_types)]
return [speech_type_count] + row_updates
gr.Warning("Exhausted maximum number of speech types. Consider restart the app.")
return row_updates

add_speech_type_btn.click(
add_speech_type_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows
)
add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows)

# Function to delete a speech type
def make_delete_speech_type_fn(index):
def delete_speech_type_fn(speech_type_count):
# Prepare updates
row_updates = []

for i in range(1, max_speech_types):
if i == index:
row_updates.append(gr.update(visible=False))
else:
row_updates.append(gr.update())

speech_type_count = max(1, speech_type_count)

return [speech_type_count] + row_updates

return delete_speech_type_fn
def delete_speech_type_fn():
return gr.update(visible=False), None, None, None

# Update delete button clicks
for i, delete_btn in enumerate(speech_type_delete_btns):
delete_fn = make_delete_speech_type_fn(i)
delete_btn.click(delete_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows)
for i in range(1, len(speech_type_delete_btns)):
speech_type_delete_btns[i].click(
delete_speech_type_fn,
outputs=[speech_type_rows[i], speech_type_names[i], speech_type_audios[i], speech_type_ref_texts[i]],
)

# Text input for the prompt
gen_text_input_multistyle = gr.Textbox(
Expand All @@ -413,7 +402,7 @@ def insert_speech_type_fn(current_text, speech_type_name):
current_text = current_text or ""
speech_type_name = speech_type_name or "None"
updated_text = current_text + f"{{{speech_type_name}}} "
return gr.update(value=updated_text)
return updated_text

return insert_speech_type_fn

Expand Down Expand Up @@ -473,10 +462,14 @@ def generate_multistyle_speech(
if style in speech_types:
current_style = style
else:
# If style not available, default to Regular
gr.Warning(f"Type {style} is not available, will use Regular as default.")
current_style = "Regular"

ref_audio = speech_types[current_style]["audio"]
try:
ref_audio = speech_types[current_style]["audio"]
except KeyError:
gr.Warning(f"Please provide reference audio for type {current_style}.")
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
ref_text = speech_types[current_style].get("ref_text", "")

# Generate speech for this segment
Expand All @@ -491,12 +484,10 @@ def generate_multistyle_speech(
# Concatenate all audio segments
if generated_audio_segments:
final_audio_data = np.concatenate(generated_audio_segments)
return [(sr, final_audio_data)] + [
gr.update(value=speech_types[style]["ref_text"]) for style in speech_types
]
return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
else:
gr.Warning("No audio generated.")
return [None] + [gr.update(value=speech_types[style]["ref_text"]) for style in speech_types]
return [None] + [speech_types[style]["ref_text"] for style in speech_types]

generate_multistyle_btn.click(
generate_multistyle_speech,
Expand All @@ -514,7 +505,7 @@ def generate_multistyle_speech(

# Validation function to disable Generate button if speech types are missing
def validate_speech_types(gen_text, regular_name, *args):
speech_type_names_list = args[:max_speech_types]
speech_type_names_list = args

# Collect the speech types names
speech_types_available = set()
Expand Down Expand Up @@ -678,7 +669,7 @@ def generate_audio_response(history, ref_audio, ref_text, remove_silence):
speed=1.0,
show_info=print, # show_info=print no pull to top when generating
)
return audio_result, gr.update(value=ref_text_out)
return audio_result, ref_text_out

def clear_conversation():
"""Reset the conversation"""
Expand Down Expand Up @@ -828,7 +819,10 @@ def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
visible=False,
)
custom_model_cfg = gr.Dropdown(
choices=[DEFAULT_TTS_MODEL_CFG[2]],
choices=[
DEFAULT_TTS_MODEL_CFG[2],
json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)),
],
value=load_last_used_custom()[2],
allow_custom_value=True,
label="Config: in a dictionary form",
Expand Down

0 comments on commit 9a18bbe

Please sign in to comment.