Commit 190c1ac
Merge branch 'SWivid:main' into main
lpscr authored Oct 26, 2024
2 parents ed36f4a + 6c62344
Showing 4 changed files with 82 additions and 51 deletions.
README.md (2 changes: 1 addition & 1 deletion)

@@ -52,7 +52,7 @@ pip install -e .
docker build -t f5tts:v1 .

# Or pull from GitHub Container Registry
-docker pull ghcr.io/SWivid/F5-TTS:main
+docker pull ghcr.io/swivid/f5-tts:main
```


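A note on why this one-line change matters: OCI/Docker repository names must be lowercase, so the old uppercase path is rejected by the Docker client (typically with "invalid reference format: repository name must be lowercase"). A quick sketch of the rule:

```python
# Docker/OCI repository names must be lowercase (the tag keeps its case).
image = "ghcr.io/SWivid/F5-TTS:main"
repo, _, tag = image.partition(":")
print(f"{repo.lower()}:{tag}")  # ghcr.io/swivid/f5-tts:main
```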
src/f5_tts/infer/infer_cli.py (13 changes: 10 additions & 3 deletions)

@@ -75,6 +75,12 @@
    action="store_true",
    help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
)
+parser.add_argument(
+    "--speed",
+    type=float,
+    default=1.0,
+    help="Adjust the speed of the audio generation (default: 1.0)",
+)
args = parser.parse_args()

config = tomli.load(open(args.config, "rb"))
@@ -102,6 +108,7 @@
ckpt_file = args.ckpt_file if args.ckpt_file else ""
vocab_file = args.vocab_file if args.vocab_file else ""
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
+speed = args.speed
wave_path = Path(output_dir) / "infer_cli_out.wav"
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
@@ -134,7 +141,7 @@
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)


-def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
+def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed):
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    if "voices" not in config:
        voices = {"main": main_voice}
Expand Down Expand Up @@ -168,7 +175,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
ref_audio = voices[voice]["ref_audio"]
ref_text = voices[voice]["ref_text"]
print(f"Voice: {voice}")
audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj, speed=speed)
generated_audio_segments.append(audio)

if generated_audio_segments:
@@ -186,7 +193,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):


def main():
-    main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
+    main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence, speed)


if __name__ == "__main__":
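For context on the new `--speed` flag wired through above: in F5-TTS-style inference, speed usually acts as an inverse scale on the frame budget allotted to the generated text, so values above 1.0 produce faster speech and values below 1.0 slower speech. A minimal sketch under that assumption (`estimate_duration` is a hypothetical helper; the project's actual heuristic lives in `src/f5_tts/infer/utils_infer.py` and may differ):

```python
# Hedged sketch: how a speed factor typically enters duration estimation.
def estimate_duration(ref_audio_len, ref_text_len, gen_text_len, speed=1.0):
    # Frames for the generated part scale with text length and inversely
    # with speed: speed=2.0 halves the budget (faster), 0.5 doubles it.
    gen_frames = int(ref_audio_len / ref_text_len * gen_text_len / speed)
    return ref_audio_len + gen_frames

print(estimate_duration(500, 50, 100, speed=1.0))  # 1500 frames
print(estimate_duration(500, 50, 100, speed=2.0))  # 1000 frames
```

With the plumbing above, something like `python src/f5_tts/infer/infer_cli.py --speed 0.8` should yield slower, more deliberate output.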
src/f5_tts/infer/infer_gradio.py (116 changes: 70 additions & 46 deletions)

@@ -55,6 +55,7 @@ def gpu_decorator(func):
chat_tokenizer_state = None


+@gpu_decorator
def generate_response(messages, model, tokenizer):
    """Generate response using Qwen"""
    text = tokenizer.apply_chat_template(
@@ -78,8 +79,10 @@ def generate_response(messages, model, tokenizer):


@gpu_decorator
-def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
-    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=gr.Info)
+def infer(
+    ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1, show_info=gr.Info
+):
+    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    if model == "F5-TTS":
        ema_model = F5TTS_ema_model
@@ -93,7 +96,7 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_
        ema_model,
        cross_fade_duration=cross_fade_duration,
        speed=speed,
-        show_info=gr.Info,
+        show_info=show_info,
        progress=gr.Progress(),
    )
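The refactor above stops hard-coding `gr.Info` and instead threads `show_info` through as a parameter, so the same `infer` can report progress as a Gradio toast in the UI or as plain stdout in headless calls. A standalone sketch of the dependency-injection pattern (not the app's code):

```python
# The caller chooses the reporting channel; print is a safe default.
def synthesize(text, show_info=print):
    show_info(f"Generating speech for {len(text)} characters...")
    audio = b"..."  # stand-in for the actual model call
    show_info("Done.")
    return audio

synthesize("Hello world")                       # CLI/batch: messages go to stdout
# synthesize("Hello world", show_info=gr.Info)  # UI: messages pop as Gradio toasts
```

Later hunks exploit this by passing `show_info=print`; per the inline comments, that keeps Gradio from popping notifications that scroll the page to the top mid-generation.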

@@ -182,24 +185,24 @@ def parse_speechtypes_text(gen_text):

    segments = []

-    current_emotion = "Regular"
+    current_style = "Regular"

    for i in range(len(tokens)):
        if i % 2 == 0:
            # This is text
            text = tokens[i].strip()
            if text:
-                segments.append({"emotion": current_emotion, "text": text})
+                segments.append({"style": current_style, "text": text})
        else:
-            # This is emotion
-            emotion = tokens[i].strip()
-            current_emotion = emotion
+            # This is style
+            style = tokens[i].strip()
+            current_style = style

    return segments


with gr.Blocks() as app_multistyle:
-    # New section for emotional generation
+    # New section for multistyle generation
    gr.Markdown(
        """
    # Multiple Speech-Type Generation
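The "emotion" to "style" rename is mechanical, but it helps to see what `parse_speechtypes_text` now returns. A self-contained approximation, assuming the same regex split on `{Style}` markers that the loop above implies:

```python
import re

def parse_speechtypes_text(gen_text):
    # Odd-indexed tokens from the split are {Style} names; even-indexed
    # tokens are the text that follows them.
    tokens = re.split(r"\{(.*?)\}", gen_text)
    segments = []
    current_style = "Regular"
    for i, token in enumerate(tokens):
        if i % 2 == 0:
            text = token.strip()
            if text:
                segments.append({"style": current_style, "text": text})
        else:
            current_style = token.strip()
    return segments

print(parse_speechtypes_text("{Regular} Hello there. {Whisper} Don't tell anyone."))
# [{'style': 'Regular', 'text': 'Hello there.'},
#  {'style': 'Whisper', 'text': "Don't tell anyone."}]
```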
@@ -312,29 +315,29 @@ def delete_speech_type_fn(speech_type_count):
        delete_btn.click(delete_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows)

    # Text input for the prompt
-    gen_text_input_emotional = gr.Textbox(
+    gen_text_input_multistyle = gr.Textbox(
        label="Text to Generate",
        lines=10,
        placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
    )

    # Model choice
-    model_choice_emotional = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
+    model_choice_multistyle = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")

    with gr.Accordion("Advanced Settings", open=False):
-        remove_silence_emotional = gr.Checkbox(
+        remove_silence_multistyle = gr.Checkbox(
            label="Remove Silences",
            value=False,
        )

    # Generate button
-    generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
+    generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")

    # Output audio
-    audio_output_emotional = gr.Audio(label="Synthesized Audio")
+    audio_output_multistyle = gr.Audio(label="Synthesized Audio")

    @gpu_decorator
-    def generate_emotional_speech(
+    def generate_multistyle_speech(
        regular_audio,
        regular_ref_text,
        gen_text,
@@ -361,23 +364,25 @@

        # For each segment, generate speech
        generated_audio_segments = []
-        current_emotion = "Regular"
+        current_style = "Regular"

        for segment in segments:
-            emotion = segment["emotion"]
+            style = segment["style"]
            text = segment["text"]

-            if emotion in speech_types:
-                current_emotion = emotion
+            if style in speech_types:
+                current_style = style
            else:
-                # If emotion not available, default to Regular
-                current_emotion = "Regular"
+                # If style not available, default to Regular
+                current_style = "Regular"

-            ref_audio = speech_types[current_emotion]["audio"]
-            ref_text = speech_types[current_emotion].get("ref_text", "")
+            ref_audio = speech_types[current_style]["audio"]
+            ref_text = speech_types[current_style].get("ref_text", "")

            # Generate speech for this segment
-            audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
+            audio, _ = infer(
+                ref_audio, ref_text, text, model_choice, remove_silence, 0, show_info=print
+            )  # show_info=print no pull to top when generating
            sr, audio_data = audio

            generated_audio_segments.append(audio_data)
@@ -390,21 +395,21 @@
            gr.Warning("No audio generated.")
            return None

-    generate_emotional_btn.click(
-        generate_emotional_speech,
+    generate_multistyle_btn.click(
+        generate_multistyle_speech,
        inputs=[
            regular_audio,
            regular_ref_text,
-            gen_text_input_emotional,
+            gen_text_input_multistyle,
        ]
        + speech_type_names
        + speech_type_audios
        + speech_type_ref_texts
        + [
-            model_choice_emotional,
-            remove_silence_emotional,
+            model_choice_multistyle,
+            remove_silence_multistyle,
        ],
-        outputs=audio_output_emotional,
+        outputs=audio_output_multistyle,
    )

    # Validation function to disable Generate button if speech types are missing

@@ -422,7 +427,7 @@ def validate_speech_types(gen_text, regular_name, *args):

        # Parse the gen_text to get the speech types used
        segments = parse_speechtypes_text(gen_text)
-        speech_types_in_text = set(segment["emotion"] for segment in segments)
+        speech_types_in_text = set(segment["style"] for segment in segments)

        # Check if all speech types in text are available
        missing_speech_types = speech_types_in_text - speech_types_available
@@ -434,10 +439,10 @@ def validate_speech_types(gen_text, regular_name, *args):
            # Enable the generate button
            return gr.update(interactive=True)

-    gen_text_input_emotional.change(
+    gen_text_input_multistyle.change(
        validate_speech_types,
-        inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
-        outputs=generate_emotional_btn,
+        inputs=[gen_text_input_multistyle, regular_name] + speech_type_names,
+        outputs=generate_multistyle_btn,
    )
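The renamed wiring keeps the same gate: every `{Style}` referenced in the script must have a configured voice before the button becomes clickable. The core check is plain set arithmetic (data below is hypothetical; the real handler returns `gr.update(interactive=...)`):

```python
segments = [{"style": "Regular", "text": "Hi."}, {"style": "Whisper", "text": "Psst."}]
speech_types_available = {"Regular"}

speech_types_in_text = {seg["style"] for seg in segments}
missing = speech_types_in_text - speech_types_available
print(missing)  # {'Whisper'} -> Generate stays disabled until a Whisper voice exists
```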


@@ -453,23 +458,35 @@ def validate_speech_types(gen_text, regular_name, *args):
        """
    )

-    load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
+    if not USING_SPACES:
+        load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
+
+        chat_interface_container = gr.Column(visible=False)
+
+        @gpu_decorator
+        def load_chat_model():
+            global chat_model_state, chat_tokenizer_state
+            if chat_model_state is None:
+                show_info = gr.Info
+                show_info("Loading chat model...")
+                model_name = "Qwen/Qwen2.5-3B-Instruct"
+                chat_model_state = AutoModelForCausalLM.from_pretrained(
+                    model_name, torch_dtype="auto", device_map="auto"
+                )
+                chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
+                show_info("Chat model loaded.")

-    chat_interface_container = gr.Column(visible=False)
+            return gr.update(visible=False), gr.update(visible=True)
+
+        load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
+
+    else:
+        chat_interface_container = gr.Column()

-    def load_chat_model():
-        global chat_model_state, chat_tokenizer_state
-        if chat_model_state is None:
-            show_info = gr.Info
-            show_info("Loading chat model...")
-            model_name = "Qwen/Qwen2.5-3B-Instruct"
-            chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
-            chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
-            show_info("Chat model loaded.")
-
-        return gr.update(visible=False), gr.update(visible=True)
-
-    load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])

    with chat_interface_container:
        with gr.Row():
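The `USING_SPACES` branch above defers downloading the 3B-parameter Qwen checkpoint on local runs: the chat column stays hidden until the user explicitly clicks the button, and the handler then hides the button and reveals the column via `gr.update`. A runnable sketch of that lazy-load pattern (component names are illustrative):

```python
import gradio as gr

model = None  # heavy model, loaded only on demand

def load_model():
    global model
    if model is None:
        model = object()  # stand-in for AutoModelForCausalLM.from_pretrained(...)
    # Hide the trigger button, reveal the gated UI.
    return gr.update(visible=False), gr.update(visible=True)

with gr.Blocks() as demo:
    load_btn = gr.Button("Load Chat Model", variant="primary")
    with gr.Column(visible=False) as panel:
        gr.Markdown("Chat UI appears once the model is ready.")
    load_btn.click(load_model, outputs=[load_btn, panel])
```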
@@ -520,6 +537,7 @@ def load_chat_model():
        )

    # Modify process_audio_input to use model and tokenizer from state
+    @gpu_decorator
    def process_audio_input(audio_path, history, conv_state):
        """Handle audio input from user"""
        if not audio_path:
@@ -541,6 +559,7 @@ def process_audio_input(audio_path, history, conv_state):

        return history, conv_state, ""

+    @gpu_decorator
    def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
        """Generate TTS audio for AI response"""
        if not history or not ref_audio:
@@ -558,6 +577,7 @@ def generate_audio_response(history, ref_audio, ref_text, model, remove_silence)
            remove_silence,
            cross_fade_duration=0.15,
            speed=1.0,
+            show_info=print,  # show_info=print no pull to top when generating
        )
        return audio_result

@@ -583,7 +603,11 @@ def update_system_prompt(new_prompt):
    ).then(
        generate_audio_response,
        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
-        outputs=audio_output_chat,
+        outputs=[audio_output_chat],
+    ).then(
+        lambda: None,
+        None,
+        audio_input_chat,
    )

    # Handle clear button
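The last hunk chains a second `.then()` onto the event so that, once the TTS response has been produced, the microphone input is cleared for the next turn; in Gradio, returning `None` to an output component resets its value. A minimal standalone demonstration of the same chaining (components here are illustrative):

```python
import gradio as gr

with gr.Blocks() as demo:
    box = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Echo")
    btn = gr.Button("Send")
    # Step 1 consumes the input; the chained step returns None to the input
    # component, clearing it, just like audio_input_chat above.
    btn.click(lambda s: s, inputs=box, outputs=out).then(lambda: None, None, box)

demo.launch()
```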
src/f5_tts/train/finetune_gradio.py (2 changes: 1 addition & 1 deletion)

@@ -38,7 +38,7 @@
path_basic = os.path.abspath(os.path.join(__file__, "../../../.."))
path_data = os.path.join(path_basic, "data")
path_project_ckpts = os.path.join(path_basic, "ckpts")
file_train = "f5_tts/train/finetune_cli.py"
file_train = "src/f5_tts/train/finetune_cli.py"

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

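This one-line path fix matters because `finetune_gradio.py` launches the trainer as a subprocess using a path relative to the repository root; without the `src/` prefix the script is not found when the UI is started from the root. A sketch of the kind of invocation the corrected path enables (the flags are hypothetical, not the app's exact command):

```python
import subprocess

file_train = "src/f5_tts/train/finetune_cli.py"  # now resolves from the repo root

# Illustrative launch; finetune_gradio.py assembles a similar accelerate command.
cmd = ["accelerate", "launch", file_train, "--exp_name", "my_finetune"]
print(" ".join(cmd))
# subprocess.run(cmd, check=True)  # uncomment to actually start training
```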
