diff --git a/gradio_app.py b/gradio_app.py
index 5b6a5e3e2..07fe1e736 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -54,13 +54,12 @@ def gpu_decorator(func):
     UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
 )
 
-# Initialize Qwen model and tokenizer
-model_name = "Qwen/Qwen2.5-3B-Instruct"
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+chat_model_state = None
+chat_tokenizer_state = None
 
-def generate_response(messages):
+
+def generate_response(messages, model, tokenizer):
     """Generate response using Qwen"""
     text = tokenizer.apply_chat_template(
         messages,
@@ -528,137 +527,157 @@ def validate_speech_types(gen_text, regular_name, *args):
 Have a conversation with an AI using your reference voice!
 1. Upload a reference audio clip and optionally its transcript.
-2. Record your message through your microphone.
-3. The AI will respond using the reference voice.
+2. Load the chat model.
+3. Record your message through your microphone.
+4. The AI will respond using the reference voice.
 """
     )
-    with gr.Row():
-        with gr.Column():
-            ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
-
-        with gr.Column():
-            with gr.Accordion("Advanced Settings", open=False):
-                model_choice_chat = gr.Radio(
-                    choices=["F5-TTS", "E2-TTS"],
-                    label="TTS Model",
-                    value="F5-TTS",
-                )
-                remove_silence_chat = gr.Checkbox(
-                    label="Remove Silences",
-                    value=True,
-                )
-                ref_text_chat = gr.Textbox(
-                    label="Reference Text",
-                    info="Optional: Leave blank to auto-transcribe",
-                    lines=2,
+    load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
+
+    chat_interface_container = gr.Column(visible=False)
+
+    def load_chat_model():
+        global chat_model_state, chat_tokenizer_state
+        if chat_model_state is None:
+            show_info = gr.Info
+            show_info("Loading chat model...")
+            model_name = "Qwen/Qwen2.5-3B-Instruct"
+            chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
+            chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
+            show_info("Chat model loaded.")
+        return gr.update(visible=False), gr.update(visible=True)
+
+    load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
+
+    with chat_interface_container:
+        with gr.Row():
+            with gr.Column():
+                ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
+            with gr.Column():
+                with gr.Accordion("Advanced Settings", open=False):
+                    model_choice_chat = gr.Radio(
+                        choices=["F5-TTS", "E2-TTS"],
+                        label="TTS Model",
+                        value="F5-TTS",
+                    )
+                    remove_silence_chat = gr.Checkbox(
+                        label="Remove Silences",
+                        value=True,
+                    )
+                    ref_text_chat = gr.Textbox(
+                        label="Reference Text",
+                        info="Optional: Leave blank to auto-transcribe",
+                        lines=2,
+                    )
+                    system_prompt_chat = gr.Textbox(
+                        label="System Prompt",
+                        value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                        lines=2,
+                    )
+
+        chatbot_interface = gr.Chatbot(label="Conversation")
+
+        with gr.Row():
+            with gr.Column():
+                audio_output_chat = gr.Audio(autoplay=True)
+            with gr.Column():
+                audio_input_chat = gr.Microphone(
+                    label="Speak your message",
+                    type="filepath",
                 )
-                system_prompt_chat = gr.Textbox(
-                    label="System Prompt",
-                    value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
-                    lines=2,
-                )
-
-    chatbot_interface = gr.Chatbot(label="Conversation")
-
-    with gr.Row():
-        with gr.Column():
-            audio_output_chat = gr.Audio(autoplay=True)
-        with gr.Column():
-            audio_input_chat = gr.Microphone(
-                label="Speak your message",
-                type="filepath",
-            )
-            clear_btn_chat = gr.Button("Clear Conversation")
-
-    conversation_state = gr.State(
-        value=[
-            {
-                "role": "system",
-                "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
-            }
-        ]
-    )
-
-    def process_audio_input(audio_path, history, conv_state):
-        """Handle audio input from user"""
-        if not audio_path:
-            return history, conv_state, ""
+                clear_btn_chat = gr.Button("Clear Conversation")
 
-        text = ""
-        text = preprocess_ref_audio_text(audio_path, text)[1]
+        conversation_state = gr.State(
+            value=[
+                {
+                    "role": "system",
+                    "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                }
+            ]
+        )
 
-        if not text.strip():
-            return history, conv_state, ""
+        # Modify process_audio_input to use model and tokenizer from state
+        def process_audio_input(audio_path, history, conv_state):
+            """Handle audio input from user"""
+            if not audio_path:
+                return history, conv_state, ""
 
-        conv_state.append({"role": "user", "content": text})
-        history.append((text, None))
+            text = ""
+            text = preprocess_ref_audio_text(audio_path, text)[1]
 
-        response = generate_response(conv_state)
+            if not text.strip():
+                return history, conv_state, ""
 
-        conv_state.append({"role": "assistant", "content": response})
-        history[-1] = (text, response)
+            conv_state.append({"role": "user", "content": text})
+            history.append((text, None))
 
-        return history, conv_state, ""
+            response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)
 
-    def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
-        """Generate TTS audio for AI response"""
-        if not history or not ref_audio:
-            return None
+            conv_state.append({"role": "assistant", "content": response})
+            history[-1] = (text, response)
 
-        last_user_message, last_ai_response = history[-1]
-        if not last_ai_response:
-            return None
+            return history, conv_state, ""
 
-        audio_result, _ = infer(
-            ref_audio,
-            ref_text,
-            last_ai_response,
-            model,
-            remove_silence,
-            cross_fade_duration=0.15,
-            speed=1.0,
+        def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
+            """Generate TTS audio for AI response"""
+            if not history or not ref_audio:
+                return None
+
+            last_user_message, last_ai_response = history[-1]
+            if not last_ai_response:
+                return None
+
+            audio_result, _ = infer(
+                ref_audio,
+                ref_text,
+                last_ai_response,
+                model,
+                remove_silence,
+                cross_fade_duration=0.15,
+                speed=1.0,
+            )
+            return audio_result
+
+        def clear_conversation():
+            """Reset the conversation"""
+            return [], [
+                {
+                    "role": "system",
+                    "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                }
+            ]
+
+        def update_system_prompt(new_prompt):
+            """Update the system prompt and reset the conversation"""
+            new_conv_state = [{"role": "system", "content": new_prompt}]
+            return [], new_conv_state
+
+        # Handle audio input
+        audio_input_chat.stop_recording(
+            process_audio_input,
+            inputs=[audio_input_chat, chatbot_interface, conversation_state],
+            outputs=[chatbot_interface, conversation_state],
+        ).then(
+            generate_audio_response,
+            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
+            outputs=audio_output_chat,
         )
-        return audio_result
-
-    def clear_conversation():
-        """Reset the conversation"""
-        return [], [
-            {
-                "role": "system",
-                "content": "You are a friendly person, and may impersonate whoever they address you as. Stay in character. Keep your responses concise since they will be spoken out loud.",
-            }
-        ]
-    def update_system_prompt(new_prompt):
-        """Update the system prompt and reset the conversation"""
-        new_conv_state = [{"role": "system", "content": new_prompt}]
-        return [], new_conv_state
-
-    # Handle audio input
-    audio_input_chat.stop_recording(
-        process_audio_input,
-        inputs=[audio_input_chat, chatbot_interface, conversation_state],
-        outputs=[chatbot_interface, conversation_state],
-    ).then(
-        generate_audio_response,
-        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
-        outputs=audio_output_chat,
-    )
+        # Handle clear button
+        clear_btn_chat.click(
+            clear_conversation,
+            outputs=[chatbot_interface, conversation_state],
+        )
 
-    # Handle clear button
-    clear_btn_chat.click(
-        clear_conversation,
-        outputs=[chatbot_interface, conversation_state],
-    )
+        # Handle system prompt change and reset conversation
+        system_prompt_chat.change(
+            update_system_prompt,
+            inputs=system_prompt_chat,
+            outputs=[chatbot_interface, conversation_state],
+        )
 
-    # Handle system prompt change and reset conversation
-    system_prompt_chat.change(
-        update_system_prompt,
-        inputs=system_prompt_chat,
-        outputs=[chatbot_interface, conversation_state],
-    )
 
 with gr.Blocks() as app:
     gr.Markdown(
diff --git a/model/utils_infer.py b/model/utils_infer.py
index 75f0cd39e..d2f521803 100644
--- a/model/utils_infer.py
+++ b/model/utils_infer.py
@@ -12,6 +12,8 @@
 from transformers import pipeline
 from vocos import Vocos
 
+import hashlib
+
 from model import CFM
 from model.utils import (
     load_checkpoint,
@@ -19,6 +21,7 @@
     convert_char_to_pinyin,
 )
 
+_ref_audio_cache = {}
 
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
@@ -158,23 +161,36 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=
         aseg.export(f.name, format="wav")
         ref_audio = f.name
 
-    if not ref_text.strip():
-        global asr_pipe
-        if asr_pipe is None:
-            initialize_asr_pipeline(device=device)
-        show_info("No reference text provided, transcribing reference audio...")
-        ref_text = asr_pipe(
-            ref_audio,
-            chunk_length_s=30,
-            batch_size=128,
-            generate_kwargs={"task": "transcribe"},
-            return_timestamps=False,
-        )["text"].strip()
-        show_info("Finished transcription")
+    # Compute a hash of the reference audio file
+    with open(ref_audio, "rb") as audio_file:
+        audio_data = audio_file.read()
+        audio_hash = hashlib.md5(audio_data).hexdigest()
+
+    global _ref_audio_cache
+    if audio_hash in _ref_audio_cache:
+        # Use cached reference text
+        show_info("Using cached reference text...")
+        ref_text = _ref_audio_cache[audio_hash]
     else:
-        show_info("Using custom reference text...")
+        if not ref_text.strip():
+            global asr_pipe
+            if asr_pipe is None:
+                initialize_asr_pipeline(device=device)
+            show_info("No reference text provided, transcribing reference audio...")
+            ref_text = asr_pipe(
+                ref_audio,
+                chunk_length_s=30,
+                batch_size=128,
+                generate_kwargs={"task": "transcribe"},
+                return_timestamps=False,
+            )["text"].strip()
+            show_info("Finished transcription")
+        else:
+            show_info("Using custom reference text...")
+        # Cache the transcribed text
+        _ref_audio_cache[audio_hash] = ref_text
 
-    # Add the functionality to ensure it ends with ". "
+    # Ensure ref_text ends with a proper sentence-ending punctuation
     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
        if ref_text.endswith("."):
            ref_text += " "
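For reference, the transcription cache added to `preprocess_ref_audio_text` boils down to the pattern below. This is a minimal standalone sketch, not the patched function itself: `get_ref_text` and `transcribe` are hypothetical names standing in for the inline logic and the Whisper `asr_pipe` call in `model/utils_infer.py`.

```python
import hashlib

# Module-level cache mapping audio-file hash -> reference text
# (mirrors _ref_audio_cache introduced in the diff above).
_ref_audio_cache = {}


def get_ref_text(ref_audio_path, ref_text, transcribe, show_info=print):
    """Return the reference text for an audio clip, transcribing each file at most once."""
    # Hash the raw audio bytes so identical reference clips share a cache entry.
    with open(ref_audio_path, "rb") as audio_file:
        audio_hash = hashlib.md5(audio_file.read()).hexdigest()

    if audio_hash in _ref_audio_cache:
        show_info("Using cached reference text...")
        return _ref_audio_cache[audio_hash]

    if not ref_text.strip():
        show_info("No reference text provided, transcribing reference audio...")
        ref_text = transcribe(ref_audio_path).strip()  # placeholder for the asr_pipe call
    else:
        show_info("Using custom reference text...")

    # Cache whichever text was determined (transcribed or user-provided).
    _ref_audio_cache[audio_hash] = ref_text
    return ref_text
```

In the patch this logic lives inline in `preprocess_ref_audio_text`, so repeated chats against the same reference clip skip the Whisper transcription entirely, which pairs with the lazy chat-model loading in `gradio_app.py` to cut startup and per-turn latency.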