Commit

Merge branch 'SWivid:main' into main
lpscr authored Oct 24, 2024
2 parents 7ab0d76 + 4d29f11 commit e0b1761
Showing 2 changed files with 168 additions and 133 deletions.
255 changes: 137 additions & 118 deletions gradio_app.py
@@ -54,13 +54,12 @@ def gpu_decorator(func):
UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
)

# Initialize Qwen model and tokenizer
model_name = "Qwen/Qwen2.5-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

chat_model_state = None
chat_tokenizer_state = None

def generate_response(messages):

def generate_response(messages, model, tokenizer):
"""Generate response using Qwen"""
text = tokenizer.apply_chat_template(
messages,
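For orientation: the hunk above replaces the eager, import-time Qwen setup with two module-level state variables and a generate_response signature that takes the model and tokenizer as arguments, so the chat model can be loaded on demand later in the app. The following is a minimal sketch of how such a function is commonly written against the transformers chat-template API; it is an illustration, not the exact body from this commit, and decoding settings such as max_new_tokens and temperature are assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_response_sketch(messages, model, tokenizer, max_new_tokens=512):
    """Generate a reply for an OpenAI-style list of chat messages."""
    # Render the conversation with the model's chat template and append the
    # assistant generation prompt.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,  # illustrative value, not taken from this diff
        )

    # Decode only the newly generated tokens, not the prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# Example usage once the model has been loaded lazily:
# reply = generate_response_sketch(
#     [{"role": "user", "content": "Say hi."}], chat_model_state, chat_tokenizer_state
# )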
@@ -528,137 +527,157 @@ def validate_speech_types(gen_text, regular_name, *args):
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
2. Record your message through your microphone.
3. The AI will respond using the reference voice.
2. Load the chat model.
3. Record your message through your microphone.
4. The AI will respond using the reference voice.
"""
)

with gr.Row():
with gr.Column():
ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")

with gr.Column():
with gr.Accordion("Advanced Settings", open=False):
model_choice_chat = gr.Radio(
choices=["F5-TTS", "E2-TTS"],
label="TTS Model",
value="F5-TTS",
)
remove_silence_chat = gr.Checkbox(
label="Remove Silences",
value=True,
)
ref_text_chat = gr.Textbox(
label="Reference Text",
info="Optional: Leave blank to auto-transcribe",
lines=2,
load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")

chat_interface_container = gr.Column(visible=False)

def load_chat_model():
global chat_model_state, chat_tokenizer_state
if chat_model_state is None:
show_info = gr.Info
show_info("Loading chat model...")
model_name = "Qwen/Qwen2.5-3B-Instruct"
chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
show_info("Chat model loaded.")
return gr.update(visible=False), gr.update(visible=True)

load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])

with chat_interface_container:
with gr.Row():
with gr.Column():
ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
with gr.Column():
with gr.Accordion("Advanced Settings", open=False):
model_choice_chat = gr.Radio(
choices=["F5-TTS", "E2-TTS"],
label="TTS Model",
value="F5-TTS",
)
remove_silence_chat = gr.Checkbox(
label="Remove Silences",
value=True,
)
ref_text_chat = gr.Textbox(
label="Reference Text",
info="Optional: Leave blank to auto-transcribe",
lines=2,
)
system_prompt_chat = gr.Textbox(
label="System Prompt",
value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
lines=2,
)

chatbot_interface = gr.Chatbot(label="Conversation")

with gr.Row():
with gr.Column():
audio_output_chat = gr.Audio(autoplay=True)
with gr.Column():
audio_input_chat = gr.Microphone(
label="Speak your message",
type="filepath",
)
system_prompt_chat = gr.Textbox(
label="System Prompt",
value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
lines=2,
)

chatbot_interface = gr.Chatbot(label="Conversation")

with gr.Row():
with gr.Column():
audio_output_chat = gr.Audio(autoplay=True)
with gr.Column():
audio_input_chat = gr.Microphone(
label="Speak your message",
type="filepath",
)

clear_btn_chat = gr.Button("Clear Conversation")

conversation_state = gr.State(
value=[
{
"role": "system",
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
}
]
)

def process_audio_input(audio_path, history, conv_state):
"""Handle audio input from user"""
if not audio_path:
return history, conv_state, ""
clear_btn_chat = gr.Button("Clear Conversation")

text = ""
text = preprocess_ref_audio_text(audio_path, text)[1]
conversation_state = gr.State(
value=[
{
"role": "system",
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
}
]
)

if not text.strip():
return history, conv_state, ""
# Modify process_audio_input to use model and tokenizer from state
def process_audio_input(audio_path, history, conv_state):
"""Handle audio input from user"""
if not audio_path:
return history, conv_state, ""

conv_state.append({"role": "user", "content": text})
history.append((text, None))
text = ""
text = preprocess_ref_audio_text(audio_path, text)[1]

response = generate_response(conv_state)
if not text.strip():
return history, conv_state, ""

conv_state.append({"role": "assistant", "content": response})
history[-1] = (text, response)
conv_state.append({"role": "user", "content": text})
history.append((text, None))

return history, conv_state, ""
response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)

def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
"""Generate TTS audio for AI response"""
if not history or not ref_audio:
return None
conv_state.append({"role": "assistant", "content": response})
history[-1] = (text, response)

last_user_message, last_ai_response = history[-1]
if not last_ai_response:
return None
return history, conv_state, ""

audio_result, _ = infer(
ref_audio,
ref_text,
last_ai_response,
model,
remove_silence,
cross_fade_duration=0.15,
speed=1.0,
def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
"""Generate TTS audio for AI response"""
if not history or not ref_audio:
return None

last_user_message, last_ai_response = history[-1]
if not last_ai_response:
return None

audio_result, _ = infer(
ref_audio,
ref_text,
last_ai_response,
model,
remove_silence,
cross_fade_duration=0.15,
speed=1.0,
)
return audio_result

def clear_conversation():
"""Reset the conversation"""
return [], [
{
"role": "system",
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
}
]

def update_system_prompt(new_prompt):
"""Update the system prompt and reset the conversation"""
new_conv_state = [{"role": "system", "content": new_prompt}]
return [], new_conv_state

# Handle audio input
audio_input_chat.stop_recording(
process_audio_input,
inputs=[audio_input_chat, chatbot_interface, conversation_state],
outputs=[chatbot_interface, conversation_state],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
outputs=audio_output_chat,
)
return audio_result

def clear_conversation():
"""Reset the conversation"""
return [], [
{
"role": "system",
"content": "You are a friendly person, and may impersonate whoever they address you as. Stay in character. Keep your responses concise since they will be spoken out loud.",
}
]

def update_system_prompt(new_prompt):
"""Update the system prompt and reset the conversation"""
new_conv_state = [{"role": "system", "content": new_prompt}]
return [], new_conv_state

# Handle audio input
audio_input_chat.stop_recording(
process_audio_input,
inputs=[audio_input_chat, chatbot_interface, conversation_state],
outputs=[chatbot_interface, conversation_state],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
outputs=audio_output_chat,
)
# Handle clear button
clear_btn_chat.click(
clear_conversation,
outputs=[chatbot_interface, conversation_state],
)

# Handle clear button
clear_btn_chat.click(
clear_conversation,
outputs=[chatbot_interface, conversation_state],
)
# Handle system prompt change and reset conversation
system_prompt_chat.change(
update_system_prompt,
inputs=system_prompt_chat,
outputs=[chatbot_interface, conversation_state],
)

# Handle system prompt change and reset conversation
system_prompt_chat.change(
update_system_prompt,
inputs=system_prompt_chat,
outputs=[chatbot_interface, conversation_state],
)

with gr.Blocks() as app:
gr.Markdown(
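The gradio_app.py side of this commit is essentially a lazy-loading refactor: instead of downloading Qwen when the script is imported, a "Load Chat Model" button fills the module-level chat_model_state / chat_tokenizer_state on first click, then hides itself and reveals the chat UI that was built inside a Column created with visible=False. Below is a self-contained sketch of that pattern with a placeholder loader standing in for the real from_pretrained calls; component names and the sleep are illustrative only.

import time

import gradio as gr

heavy_resource = None  # stands in for chat_model_state / chat_tokenizer_state


def load_resource():
    global heavy_resource
    if heavy_resource is None:
        gr.Info("Loading chat model...")
        time.sleep(2)  # placeholder for the real from_pretrained() calls
        heavy_resource = object()
        gr.Info("Chat model loaded.")
    # Hide the button, reveal the container that started out hidden.
    return gr.update(visible=False), gr.update(visible=True)


with gr.Blocks() as demo:
    load_btn = gr.Button("Load Chat Model", variant="primary")
    chat_container = gr.Column(visible=False)
    with chat_container:
        gr.Textbox(label="Your message")

    load_btn.click(load_resource, outputs=[load_btn, chat_container])

if __name__ == "__main__":
    demo.launch()

Deferring the load keeps startup fast and GPU memory free for users who never open the chat tab, at the cost of one explicit click before the first conversation.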
46 changes: 31 additions & 15 deletions model/utils_infer.py
@@ -12,13 +12,16 @@
from transformers import pipeline
from vocos import Vocos

import hashlib

from model import CFM
from model.utils import (
load_checkpoint,
get_tokenizer,
convert_char_to_pinyin,
)

_ref_audio_cache = {}

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

@@ -158,23 +161,36 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=
aseg.export(f.name, format="wav")
ref_audio = f.name

if not ref_text.strip():
global asr_pipe
if asr_pipe is None:
initialize_asr_pipeline(device=device)
show_info("No reference text provided, transcribing reference audio...")
ref_text = asr_pipe(
ref_audio,
chunk_length_s=30,
batch_size=128,
generate_kwargs={"task": "transcribe"},
return_timestamps=False,
)["text"].strip()
show_info("Finished transcription")
# Compute a hash of the reference audio file
with open(ref_audio, "rb") as audio_file:
audio_data = audio_file.read()
audio_hash = hashlib.md5(audio_data).hexdigest()

global _ref_audio_cache
if audio_hash in _ref_audio_cache:
# Use cached reference text
show_info("Using cached reference text...")
ref_text = _ref_audio_cache[audio_hash]
else:
show_info("Using custom reference text...")
if not ref_text.strip():
global asr_pipe
if asr_pipe is None:
initialize_asr_pipeline(device=device)
show_info("No reference text provided, transcribing reference audio...")
ref_text = asr_pipe(
ref_audio,
chunk_length_s=30,
batch_size=128,
generate_kwargs={"task": "transcribe"},
return_timestamps=False,
)["text"].strip()
show_info("Finished transcription")
else:
show_info("Using custom reference text...")
# Cache the transcribed text
_ref_audio_cache[audio_hash] = ref_text

# Add the functionality to ensure it ends with ". "
# Ensure ref_text ends with a proper sentence-ending punctuation
if not ref_text.endswith(". ") and not ref_text.endswith("。"):
if ref_text.endswith("."):
ref_text += " "
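The utils_infer.py change memoizes reference-audio transcription: the uploaded clip's bytes are hashed with MD5, and the transcribed (or user-supplied) reference text is stored in the module-level _ref_audio_cache dict so repeated generations with the same clip skip the Whisper pass. A standalone sketch of the idea follows; transcribe() is a stand-in for the real asr_pipe call and is purely illustrative.

import hashlib

_ref_audio_cache = {}


def transcribe(audio_path):
    """Placeholder for the ASR pipeline, e.g. asr_pipe(audio_path)["text"]."""
    return "transcribed text"


def get_ref_text(audio_path, ref_text=""):
    # Key the cache on the file contents, not the path, so a re-uploaded copy
    # of the same clip still hits the cache.
    with open(audio_path, "rb") as f:
        audio_hash = hashlib.md5(f.read()).hexdigest()

    if audio_hash in _ref_audio_cache:
        return _ref_audio_cache[audio_hash]

    if not ref_text.strip():
        ref_text = transcribe(audio_path).strip()

    _ref_audio_cache[audio_hash] = ref_text
    return ref_text

Hashing the bytes costs one extra read of the file per call, but it is robust to temporary-file paths that change between uploads of the same audio.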
