A Full-Duplex Open-Domain Dialogue Agent with Continuous Turn-Taking Behavior
To reproduce results from the paper, see Reproduce Paper Results.
Inspired by Google Duplex, this bot aims to provide an experience as close as possible to a live phone call or face-to-face conversation. Unlike Google Duplex, which was designed for specific tasks, this is a completely open-domain system intended to converse about anything. Importantly, there are no pre-defined turn-taking rules - the agent is free to speak whenever it chooses and learns coordination behavior directly from the training data.
- Whisper is used for Automatic Speech Recognition (ASR).
- Llama-2-7b fine-tuned on transcribed spoken dialogue from TalkBank is used for the dialogue agent. See the model card for more details.
- FastSpeech2 (trained on Common Voice v4) or Bark (whose training data has not yet been published) is used for Text to Speech (TTS).
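At a high level, these three components are chained in a continuous loop. The sketch below illustrates that flow only; every name and stub here is a hypothetical stand-in, not this repo's actual code:

```python
# A minimal structural sketch of the ASR -> agent -> TTS loop.
# All names here are illustrative stand-ins, not the repo's real API.

def transcribe(audio_chunk):
    """Stand-in for Whisper ASR on a chunk of microphone audio."""
    return "partial transcript"

def agent_step(transcript):
    """Stand-in for the fine-tuned Llama-2-7b agent. It decides *whether*
    to speak as well as *what* to say; there is no fixed turn-taking rule."""
    return "agent response" if transcript else None

def synthesize(text):
    """Stand-in for FastSpeech2 / Bark TTS; returns placeholder audio."""
    return b"\x00" * 16000

def run_loop(audio_chunks):
    for chunk in audio_chunks:
        response = agent_step(transcribe(chunk))
        if response is not None:
            _audio = synthesize(response)  # would be played back in real time
```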
Python 3.8 or greater is required. If PyTorch is not already installed in your environment, please install the appropriate configuration of PyTorch for your environment (OS, CUDA version) before proceeding - see https://pytorch.org/get-started/locally/.
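For example, on Linux with pip and CUDA 11.8, the install command currently looks like this (always verify against the PyTorch site, since the index URL depends on your CUDA version):

```bash
pip install torch --index-url https://download.pytorch.org/whl/cu118
```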
If you wish to use Bark for TTS, PyTorch 2.1.x offers additional performance improvements. See the Bark Readme for more details.
To clone the repo and install dependencies, run:
```bash
git clone https://github.com/AbrahamSanders/realtime-chatbot.git
cd realtime-chatbot
pip install -r requirements.txt
```
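Before launching anything, it can help to confirm that PyTorch sees your GPU(s):

```python
import torch

print(torch.__version__)
print(torch.cuda.is_available())  # should print True on a CUDA machine
print(torch.cuda.device_count())  # number of visible GPUs
```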
The agent model is a fine-tuned LoRA adapter for `meta-llama/Llama-2-7b-hf`, which requires all users to fill out an access request form before it becomes available to download. Make sure you have done this and run

```bash
huggingface-cli login
```

before attempting to run any of the interfaces below. For more information, see https://huggingface.co/docs/hub/models-gated.
To launch the Gradio web interface, run the following. When prompted, navigate to http://127.0.0.1:7860:
```bash
python run_gradio.py
```
By default, FastSpeech2 is used for TTS. To use Bark instead, run:
```bash
python run_gradio.py --tts-engine=bark
```
Running this interface will use between 12GB and 24GB of GPU RAM, depending on the selected Whisper model size. Under default settings, it should run smoothly on a machine with a single 16GB GPU with either FastSpeech2 or Bark; however, you may experience larger floor transfer offsets (response latencies) on this minimal hardware configuration.
If you have multiple GPUs, the system will attempt to distribute the models across devices for added performance, as sketched after this list:
- If two GPUs are available, one will run the agent (Llama-2-7b) and the other will run Whisper and FastSpeech2 / Bark.
- On a machine with three or more GPUs, Llama-2, Whisper, and FastSpeech2 / Bark will each run on their own dedicated GPU to maximize performance.
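As a rough sketch of this kind of placement logic (illustrative only; the device assignments and fallbacks here are assumptions, not the repo's actual code):

```python
import torch

def assign_devices():
    """Choose devices for the agent, Whisper, and TTS models by GPU count.
    Illustrative only; the repo's real placement logic may differ."""
    n = torch.cuda.device_count()
    if n >= 3:
        return "cuda:0", "cuda:1", "cuda:2"  # each model gets its own GPU
    if n == 2:
        return "cuda:0", "cuda:1", "cuda:1"  # agent alone; ASR + TTS share
    if n == 1:
        return "cuda:0", "cuda:0", "cuda:0"  # everything on one GPU
    return "cpu", "cpu", "cpu"               # CPU fallback

agent_device, asr_device, tts_device = assign_devices()
```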
Audio input and output devices (microphone + speakers) are required. There is currently no built-in echo cancellation functionality, so for the best experience it is recommended to use:
- A high-quality headset.
- Alternatively, headphones and an external microphone.
After the interface loads:
- Click Record to allow Gradio to begin recording audio from your microphone.
- [Optional] Use the `Dialogue Summary Prompt` textbox to provide a short script to help guide the topic and structure of the conversation.
  - e.g., `"S1 and S2 are talking about what's new in their lives. S2 got a new dog."`
  - If set to a blank string, the conversation will be completely open-ended.
- [Optional] Use the `Agent Starts` checkbox to determine whether the agent will start the conversation or wait for the user to speak first.
  - If `Agent Starts` is checked, use the `Opening Utterance` textbox to provide the agent's initial utterance. If set to a blank string, the agent will be free to start the conversation however it chooses.
- [Optional] Use the `Agent Voice` dropdown (scroll to bottom of page) to select the voice used by the agent.
  - Other options exist nearby to customize the agent's persona, such as `Agent Name`, `Agent Age`, and `Agent Gender`.
- Uncheck `Reset` to begin the conversation.
- To reset the conversation at any time, check and then uncheck `Reset`.
To launch the terminal interface, run:
```bash
python run_chat.py
```
The purpose of the terminal interface is to provide a simple way to test the agent model in a text-only environment without the added complexity of ASR and TTS.
Keyboard input in the terminal is processed in real time to emulate continuous speech input: as you type, each word is submitted to the agent as soon as space or enter is pressed (see the sketch after this list).

- Type `--reset` to clear the dialogue history and start over.
- Type `--exit` to quit.
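A minimal sketch of that incremental-submission behavior (not the repo's implementation; `submit_word` is a hypothetical callback, and the real interface uses non-blocking keyboard reads):

```python
def stream_words(chars, submit_word):
    """Buffer characters and submit a word whenever space or enter arrives."""
    buffer = []
    for ch in chars:
        if ch in (" ", "\n"):
            if buffer:
                submit_word("".join(buffer))  # word goes to the agent immediately
                buffer.clear()
        else:
            buffer.append(ch)

# Example: feeds three words to the callback, one at a time, as they are "typed".
stream_words("hello there agent\n", print)
```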
To reproduce the results in Tables 4 and 5 of the paper:

- Ensure `data/dataset_test.txt` exists (details on distributing this file are TBD due to TalkBank corpora licenses).
- Run the evaluation script:

```bash
python run_evals.py --num-examples=150 --use-bf16 > eval_results_all.txt
```
This will run evaluation on all available GPUs using multiprocessing. On 4 GPUs with 48GB of memory each, expect it to take roughly 12-24 hours. On smaller GPUs, lower `--batch-size` and `--contrastive-batch-size` as needed.
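For example, a run with reduced batch sizes might look like this (the values shown are illustrative; tune them to your GPU memory):

```bash
python run_evals.py --num-examples=150 --use-bf16 --batch-size=4 --contrastive-batch-size=2 > eval_results_all.txt
```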
The results from Table 4 will be saved to `evals_output_ppl_all.csv` and the results from Table 5 to `evals_output_pred_all_all.csv`.
To train an agent model, first prepare the dataset and then run the HuggingFace trainer. Scripts are provided for both.
This script downloads, pre-processes, and formats TalkBank conversational corpora into text files for training, and also handles separation into train, dev, and test splits. Simply run:

```bash
python prep_dataset.py --standardize-pauses
```

The dataset files will be placed in the `data` folder.
It is also possible to specify individual TalkBank corpora or to change the default train/dev/test split. To do this, check the command-line options:

```bash
python prep_dataset.py --help
```
The `train.py` script is a modified copy of HuggingFace's `run_clm.py` script, adapted for use with line-by-line text file datasets that require padding each example instead of chunking examples into fixed-size blocks.
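To illustrate the difference, here is a hedged sketch using the HuggingFace tokenizer API (not the repo's actual preprocessing code; `facebook/opt-350m` is used only as a conveniently small tokenizer, and the lengths are arbitrary):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
lines = ["S1: hi there (0.4) S2: hey!", "S1: how have you been?"]

# run_clm.py default: concatenate all text, then chunk into fixed-size blocks,
# so a single dialogue can be split across block boundaries.
joined = tokenizer(" ".join(lines))["input_ids"]
block_size = 16
blocks = [joined[i:i + block_size] for i in range(0, len(joined), block_size)]

# Line-by-line alternative (what train.py needs): keep each line as one example
# and pad it to a fixed length, so no dialogue straddles two training examples.
padded = tokenizer(lines, padding="max_length", truncation=True, max_length=32)
```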
The provided shell script `train_large.sh` is pre-configured to fine-tune a LoRA adapter for `meta-llama/Llama-2-7b-hf` using `train.py`.
To fine-tune a different model, simply modify this script. For example, to train `facebook/opt-350m` instead, change it as follows:

```bash
python train.py \
    --model_name_or_path=facebook/opt-350m \
    ...
```
Currently, fine-tuning has been tested with `meta-llama/Llama-2-*` and `facebook/opt-*` models.