From 7384b63b1d5cfdd6b29cc4440c5e5147278cf2f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 15 May 2024 23:33:15 -0700 Subject: [PATCH 01/25] initial interruptions support --- CHANGELOG.md | 6 ++ src/pipecat/frames/frames.py | 2 +- src/pipecat/pipeline/task.py | 6 +- .../processors/aggregators/llm_response.py | 5 ++ .../processors/aggregators/user_response.py | 5 ++ src/pipecat/processors/frame_processor.py | 5 +- src/pipecat/transports/base_input.py | 49 +++++++++++---- src/pipecat/transports/base_output.py | 60 ++++++++++++++----- src/pipecat/transports/local/audio.py | 9 +-- src/pipecat/transports/local/tk.py | 10 ++-- src/pipecat/transports/services/daily.py | 12 ++-- 11 files changed, 124 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ee98f3b02..5f5049a57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to **pipecat** will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- Added initial interruptions support. + ## [0.0.16] - 2024-05-16 ### Fixed diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 1278a7de1..7fcb0b6c2 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -187,7 +187,7 @@ class SystemFrame(Frame): @dataclass class StartFrame(SystemFrame): """This is the first frame that should be pushed down a pipeline.""" - pass + allow_interruptions: bool = False @dataclass diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 2ecbe84ca..b693c0d58 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -31,11 +31,12 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): class PipelineTask: - def __init__(self, pipeline: FrameProcessor): + def __init__(self, pipeline: FrameProcessor, allow_interruptions=False): self.id: int = obj_id() self.name: str = f"{self.__class__.__name__}#{obj_count(self)}" self._pipeline = pipeline + self._allow_interruptions = allow_interruptions self._task_queue = asyncio.Queue() self._up_queue = asyncio.Queue() @@ -70,7 +71,8 @@ async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]): raise Exception("Frames must be an iterable or async iterable") async def _process_task_queue(self): - await self._source.process_frame(StartFrame(), FrameDirection.DOWNSTREAM) + await self._source.process_frame( + StartFrame(allow_interruptions=self._allow_interruptions), FrameDirection.DOWNSTREAM) running = True while running: frame = await self._task_queue.get() diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 37f8dab2e..82288a4e9 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -72,6 +72,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, self._start_frame): self._seen_start_frame = True self._aggregating = True + + await self.push_frame(frame, direction) elif isinstance(frame, self._end_frame): self._seen_end_frame = True @@ -83,6 +85,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # Send the aggregation if we are not aggregating anymore (i.e. no # more interim results received). 
send_aggregation = not self._aggregating + + await self.push_frame(frame, direction) elif isinstance(frame, self._accumulator_frame): if self._aggregating: self._aggregation += f" {frame.text}" @@ -109,6 +113,7 @@ async def _push_aggregation(self): # Reset self._aggregation = "" + self._aggregating = False self._seen_start_frame = False self._seen_end_frame = False self._seen_interim_results = False diff --git a/src/pipecat/processors/aggregators/user_response.py b/src/pipecat/processors/aggregators/user_response.py index e112fd3e4..ae063b3cd 100644 --- a/src/pipecat/processors/aggregators/user_response.py +++ b/src/pipecat/processors/aggregators/user_response.py @@ -89,6 +89,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, self._start_frame): self._seen_start_frame = True self._aggregating = True + + await self.push_frame(frame, direction) elif isinstance(frame, self._end_frame): self._seen_end_frame = True @@ -100,6 +102,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # Send the aggregation if we are not aggregating anymore (i.e. no # more interim results received). send_aggregation = not self._aggregating + + await self.push_frame(frame, direction) elif isinstance(frame, self._accumulator_frame): if self._aggregating: self._aggregation += f" {frame.text}" @@ -124,6 +128,7 @@ async def _push_aggregation(self): # Reset self._aggregation = "" + self._aggregating = False self._seen_start_frame = False self._seen_end_frame = False self._seen_interim_results = False diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 3bb750218..a7d330680 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -8,7 +8,7 @@ from asyncio import AbstractEventLoop from enum import Enum -from pipecat.frames.frames import ErrorFrame, Frame +from pipecat.frames.frames import AudioRawFrame, ErrorFrame, Frame from pipecat.utils.utils import obj_count, obj_id from loguru import logger @@ -47,7 +47,8 @@ async def push_error(self, error: ErrorFrame): async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): if direction == FrameDirection.DOWNSTREAM and self._next: - logger.trace(f"Pushing {frame} from {self} to {self._next}") + if not isinstance(frame, AudioRawFrame): + logger.trace(f"Pushing {frame} from {self} to {self._next}") await self._next.process_frame(frame, direction) elif direction == FrameDirection.UPSTREAM and self._prev: logger.trace(f"Pushing {frame} upstream from {self} to {self._prev}") diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 265c8e6c4..a4f2a4cc6 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -30,23 +30,23 @@ def __init__(self, params: TransportParams): self._params = params self._running = False + self._allow_interruptions = False # Start media threads. if self._params.audio_in_enabled or self._params.vad_enabled: self._audio_in_queue = queue.Queue() - # Start push frame task. This is the task that will push frames in + # Create push frame task. This is the task that will push frames in # order. So, a transport guarantees that all frames are pushed in the # same task. 
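The ordering guarantee described in the comment above comes from a single-consumer queue: every producer (audio thread, VAD, transcription callbacks) puts frames onto one asyncio.Queue, and exactly one task drains it, so frames reach the rest of the pipeline in the order they were queued. Later in this series, interruptions are handled by cancelling that task and creating a fresh one, which discards anything still queued. A minimal, self-contained sketch of the pattern (illustrative names only, not the exact pipecat internals):

    import asyncio

    class OrderedPusher:
        """One queue + one consumer task keeps frames in arrival order."""

        def __init__(self):
            self._create_push_task()

        def _create_push_task(self):
            self._queue: asyncio.Queue = asyncio.Queue()
            self._task = asyncio.create_task(self._handler())

        async def push(self, frame):
            await self._queue.put(frame)

        async def _handler(self):
            while True:
                frame = await self._queue.get()
                print("pushing", frame)  # stand-in for pushing downstream

        def interrupt(self):
            # Cancel the in-flight task and start over with an empty queue,
            # dropping frames that were queued before the interruption.
            self._task.cancel()
            self._create_push_task()

    async def main():
        pusher = OrderedPusher()
        await pusher.push("frame-1")
        await pusher.push("frame-2")
        await asyncio.sleep(0.1)
        pusher.interrupt()

    asyncio.run(main())
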
- loop = self.get_event_loop() - self._push_frame_task = loop.create_task(self._push_frame_task_handler()) - self._push_queue = asyncio.Queue() + self._create_push_task() - async def start(self): + async def start(self, frame: StartFrame): if self._running: return self._running = True + self._allow_interruptions = frame.allow_interruptions if self._params.audio_in_enabled or self._params.vad_enabled: loop = self.get_event_loop() @@ -65,6 +65,9 @@ async def stop(self): await self._audio_in_thread await self._audio_out_thread + await self._internal_push_frame(None, None) + await self._push_frame_task + def vad_analyze(self, audio_frames: bytes) -> VADState: pass @@ -79,10 +82,15 @@ async def cleanup(self): pass async def process_frame(self, frame: Frame, direction: FrameDirection): - if isinstance(frame, StartFrame): - await self.start() + if isinstance(frame, CancelFrame): + await self.stop() + # We don't queue a CancelFrame since we want to stop ASAP. + await self.push_frame(frame, direction) + elif isinstance(frame, StartFrame): + self._allow_interruption = frame.allow_interruptions + await self.start(frame) await self._internal_push_frame(frame, direction) - elif isinstance(frame, CancelFrame) or isinstance(frame, EndFrame): + elif isinstance(frame, EndFrame): await self.stop() await self._internal_push_frame(frame, direction) else: @@ -92,10 +100,15 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # Push frames task # + def _create_push_task(self): + loop = self.get_event_loop() + self._push_frame_task = loop.create_task(self._push_frame_task_handler()) + self._push_queue = asyncio.Queue() + async def _internal_push_frame( self, - frame: Frame, - direction: FrameDirection = FrameDirection.DOWNSTREAM): + frame: Frame | None, + direction: FrameDirection | None = FrameDirection.DOWNSTREAM): await self._push_queue.put((frame, direction)) async def _push_frame_task_handler(self): @@ -106,6 +119,16 @@ async def _push_frame_task_handler(self): await self.push_frame(frame, direction) running = frame is not None + # + # Handle interruptions + # + + async def _handle_interruptions(self, frame: Frame): + if self._allow_interruptions and isinstance(frame, UserStartedSpeakingFrame): + self._push_frame_task.cancel() + self._create_push_task() + await self._internal_push_frame(frame) + # # Audio input # @@ -118,11 +141,13 @@ def _handle_vad(self, audio_frames: bytes, vad_state: VADState): frame = UserStartedSpeakingFrame() elif new_vad_state == VADState.QUIET: frame = UserStoppedSpeakingFrame() + if frame: future = asyncio.run_coroutine_threadsafe( - self._internal_push_frame(frame), self.get_event_loop()) + self._handle_interruptions(frame), self.get_event_loop()) future.result() - vad_state = new_vad_state + + vad_state = new_vad_state return vad_state def _audio_in_thread_handler(self): diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index f46d629f8..520372803 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -7,9 +7,9 @@ import asyncio import itertools -from multiprocessing.context import _force_start_method import queue import time +import threading from PIL import Image from typing import List @@ -23,7 +23,9 @@ EndFrame, Frame, ImageRawFrame, - TransportMessageFrame) + TransportMessageFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame) from pipecat.transports.base_transport import TransportParams from loguru import logger @@ -37,6 +39,7 @@ def __init__(self, params: 
TransportParams): self._params = params self._running = False + self._allow_interruptions = False # These are the images that we should send to the camera at our desired # framerate. @@ -48,12 +51,14 @@ def __init__(self, params: TransportParams): self._sink_queue = queue.Queue() self._stopped_event = asyncio.Event() + self._is_interrupted = threading.Event() - async def start(self): + async def start(self, frame: StartFrame): if self._running: return self._running = True + self._allow_interruptions = frame.allow_interruptions loop = self.get_event_loop() @@ -93,12 +98,15 @@ async def cleanup(self): async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, StartFrame): - await self.start() + await self.start(frame) await self.push_frame(frame, direction) # EndFrame is managed in the queue handler. elif isinstance(frame, CancelFrame): await self.stop() await self.push_frame(frame, direction) + elif isinstance(frame, UserStartedSpeakingFrame) or isinstance(frame, UserStoppedSpeakingFrame): + await self._handle_interruptions(frame) + await self.push_frame(frame, direction) elif self._frame_managed_by_sink(frame): self._sink_queue.put(frame) else: @@ -117,28 +125,50 @@ def _frame_managed_by_sink(self, frame: Frame): or isinstance(frame, TransportMessageFrame) or isinstance(frame, EndFrame)) + async def _handle_interruptions(self, frame: Frame): + if not self._allow_interruptions: + return + + if isinstance(frame, UserStartedSpeakingFrame): + self._is_interrupted.set() + elif isinstance(frame, UserStoppedSpeakingFrame): + self._is_interrupted.clear() + def _sink_thread_handler(self): - buffer = bytearray() + # 10ms bytes bytes_size_10ms = int(self._params.audio_out_sample_rate / 100) * \ self._params.audio_out_channels * 2 + + # We will send at least 100ms bytes. + smallest_write_size = bytes_size_10ms * 10 + + # Audio accumlation buffer + buffer = bytearray() while self._running: try: frame = self._sink_queue.get(timeout=1) + if isinstance(frame, EndFrame): # Send all remaining audio before stopping (multiple of 10ms of audio). 
self._send_audio_truncated(buffer, bytes_size_10ms) future = asyncio.run_coroutine_threadsafe(self.stop(), self.get_event_loop()) future.result() - elif isinstance(frame, AudioRawFrame): - if self._params.audio_out_enabled: - buffer.extend(frame.audio) - buffer = self._send_audio_truncated(buffer, bytes_size_10ms) - elif isinstance(frame, ImageRawFrame) and self._params.camera_out_enabled: - self._set_camera_image(frame) - elif isinstance(frame, SpriteFrame) and self._params.camera_out_enabled: - self._set_camera_images(frame.images) - elif isinstance(frame, TransportMessageFrame): - self.send_message(frame) + + if not self._is_interrupted.is_set(): + if isinstance(frame, AudioRawFrame): + if self._params.audio_out_enabled: + buffer.extend(frame.audio) + buffer = self._send_audio_truncated(buffer, smallest_write_size) + elif isinstance(frame, ImageRawFrame) and self._params.camera_out_enabled: + self._set_camera_image(frame) + elif isinstance(frame, SpriteFrame) and self._params.camera_out_enabled: + self._set_camera_images(frame.images) + elif isinstance(frame, TransportMessageFrame): + self.send_message(frame) + else: + # Send any remaining audio + self._send_audio_truncated(buffer, bytes_size_10ms) + buffer = bytearray() except queue.Empty: pass except BaseException as e: diff --git a/src/pipecat/transports/local/audio.py b/src/pipecat/transports/local/audio.py index ee85bcf61..771715111 100644 --- a/src/pipecat/transports/local/audio.py +++ b/src/pipecat/transports/local/audio.py @@ -6,6 +6,7 @@ import asyncio +from pipecat.frames.frames import StartFrame from pipecat.processors.frame_processor import FrameProcessor from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport @@ -37,8 +38,8 @@ def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams): def read_raw_audio_frames(self, frame_count: int) -> bytes: return self._in_stream.read(frame_count, exception_on_overflow=False) - async def start(self): - await super().start() + async def start(self, frame: StartFrame): + await super().start(frame) self._in_stream.start_stream() async def stop(self): @@ -68,8 +69,8 @@ def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams): def write_raw_audio_frames(self, frames: bytes): self._out_stream.write(frames) - async def start(self): - await super().start() + async def start(self, frame: StartFrame): + await super().start(frame) self._out_stream.start_stream() async def stop(self): diff --git a/src/pipecat/transports/local/tk.py b/src/pipecat/transports/local/tk.py index 4165f941c..782c01dae 100644 --- a/src/pipecat/transports/local/tk.py +++ b/src/pipecat/transports/local/tk.py @@ -9,7 +9,7 @@ import numpy as np import tkinter as tk -from pipecat.frames.frames import ImageRawFrame +from pipecat.frames.frames import ImageRawFrame, StartFrame from pipecat.processors.frame_processor import FrameProcessor from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport @@ -48,8 +48,8 @@ def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams): def read_raw_audio_frames(self, frame_count: int) -> bytes: return self._in_stream.read(frame_count, exception_on_overflow=False) - async def start(self): - await super().start() + async def start(self, frame: StartFrame): + await super().start(frame) self._in_stream.start_stream() async def stop(self): @@ -89,8 +89,8 @@ def write_raw_audio_frames(self, frames: bytes): def 
write_frame_to_camera(self, frame: ImageRawFrame): self.get_event_loop().call_soon(self._write_frame_to_tk, frame) - async def start(self): - await super().start() + async def start(self, frame: StartFrame): + await super().start(frame) self._out_stream.start_stream() async def stop(self): diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index cca69a284..506c962ea 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -29,6 +29,7 @@ ImageRawFrame, InterimTranscriptionFrame, SpriteFrame, + StartFrame, TranscriptionFrame, TransportMessageFrame, UserImageRawFrame, @@ -283,6 +284,7 @@ def _handle_join_response(self): error_msg = f"Error joining {self._room_url}: {error}" logger.error(error_msg) self._callbacks.on_error(error_msg) + self._sync_response["join"].task_done() except queue.Empty: error_msg = f"Time out joining {self._room_url}" logger.error(error_msg) @@ -320,6 +322,7 @@ def _handle_leave_response(self): error_msg = f"Error leaving {self._room_url}: {error}" logger.error(error_msg) self._callbacks.on_error(error_msg) + self._sync_response["leave"].task_done() except queue.Empty: error_msg = f"Time out leaving {self._room_url}" logger.error(error_msg) @@ -432,13 +435,13 @@ def __init__(self, client: DailyTransportClient, params: DailyParams): self._video_renderers = {} self._camera_in_queue = queue.Queue() - async def start(self): + async def start(self, frame: StartFrame): if self._running: return # Join the room. await self._client.join() # This will set _running=True - await super().start() + await super().start(frame) # Create camera in thread (runs if _running is true). loop = asyncio.get_running_loop() self._camera_in_thread = loop.run_in_executor(None, self._camera_in_thread_handler) @@ -547,6 +550,7 @@ def _camera_in_thread_handler(self): future = asyncio.run_coroutine_threadsafe( self._internal_push_frame(frame), self.get_event_loop()) future.result() + self._camera_in_queue.task_done() except queue.Empty: pass except BaseException as e: @@ -560,11 +564,11 @@ def __init__(self, client: DailyTransportClient, params: DailyParams): self._client = client - async def start(self): + async def start(self, frame: StartFrame): if self._running: return # This will set _running=True - await super().start() + await super().start(frame) # Join the room. 
await self._client.join() From dc9377fb92940d0a7eae1ea4a7ff24b929865a63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 15 May 2024 23:35:15 -0700 Subject: [PATCH 02/25] add missing queue task_done() --- src/pipecat/frames/frames.py | 6 +++--- src/pipecat/transports/base_input.py | 2 ++ src/pipecat/transports/base_output.py | 3 +++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 7fcb0b6c2..ed7ac75e8 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -119,7 +119,7 @@ class TextFrame(DataFrame): text: str def __str__(self): - return f'{self.name}: "{self.text}"' + return f"{self.name}(text: {self.text})" @dataclass @@ -132,7 +132,7 @@ class TranscriptionFrame(TextFrame): timestamp: str def __str__(self): - return f"{self.name}(user: {self.user_id}, timestamp: {self.timestamp})" + return f"{self.name}(user: {self.user_id}, text: {self.text}, timestamp: {self.timestamp})" @dataclass @@ -143,7 +143,7 @@ class InterimTranscriptionFrame(TextFrame): timestamp: str def __str__(self): - return f"{self.name}(user: {self.user_id}, timestamp: {self.timestamp})" + return f"{self.name}(user: {self.user_id}, text: {self.text}, timestamp: {self.timestamp})" @dataclass diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index a4f2a4cc6..6000127b4 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -185,6 +185,8 @@ def _audio_out_thread_handler(self): future = asyncio.run_coroutine_threadsafe( self._internal_push_frame(frame), self.get_event_loop()) future.result() + + self._audio_in_queue.task_done() except queue.Empty: pass except BaseException as e: diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index 520372803..ff9040406 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -169,6 +169,8 @@ def _sink_thread_handler(self): # Send any remaining audio self._send_audio_truncated(buffer, bytes_size_10ms) buffer = bytearray() + + self._sink_queue.task_done() except queue.Empty: pass except BaseException as e: @@ -208,6 +210,7 @@ def _camera_out_thread_handler(self): if self._params.camera_out_is_live: image = self._camera_out_queue.get(timeout=1) self._draw_image(image) + self._camera_out_queue.task_done() elif self._camera_images: image = next(self._camera_images) self._draw_image(image) From 8c877d7d8effc15f9f31cff90d3d35419f8d07d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 15 May 2024 23:35:52 -0700 Subject: [PATCH 03/25] examples: update 07-interruptible --- .../foundational/06-listen-and-respond.py | 2 +- examples/foundational/06a-image-sync.py | 2 +- examples/foundational/07-interruptible.py | 81 +++++++++++-------- examples/moondream-chatbot/bot.py | 2 +- examples/simple-chatbot/bot.py | 2 +- 5 files changed, 50 insertions(+), 39 deletions(-) diff --git a/examples/foundational/06-listen-and-respond.py b/examples/foundational/06-listen-and-respond.py index 4e5d0758f..3ba220912 100644 --- a/examples/foundational/06-listen-and-respond.py +++ b/examples/foundational/06-listen-and-respond.py @@ -65,7 +65,7 @@ async def main(room_url: str, token): messages = [ { "role": "system", - "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. 
Your output will be converted to audio so it should not contain special characters. Respond to what the user said in a creative and helpful way.", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so never use special characters in your answers. Respond to what the user said in a creative and helpful way.", }, ] tma_in = LLMUserResponseAggregator(messages) diff --git a/examples/foundational/06a-image-sync.py b/examples/foundational/06a-image-sync.py index 73878976f..77278f21d 100644 --- a/examples/foundational/06a-image-sync.py +++ b/examples/foundational/06a-image-sync.py @@ -83,7 +83,7 @@ async def main(room_url: str, token): messages = [ { "role": "system", - "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so it should not contain special characters. Respond to what the user said in a creative and helpful way.", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so never use special characters in your answers. Respond to what the user said in a creative and helpful way.", }, ] diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py index fd0c2f842..de7ceb8c5 100644 --- a/examples/foundational/07-interruptible.py +++ b/examples/foundational/07-interruptible.py @@ -1,26 +1,33 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import asyncio import aiohttp -import logging import os -from pipecat.pipeline.aggregators import ( - LLMAssistantResponseAggregator, - LLMUserResponseAggregator, -) +import sys +from pipecat.frames.frames import LLMMessagesFrame from pipecat.pipeline.pipeline import Pipeline -from pipecat.services.ai_services import FrameLogger -from pipecat.transports.daily_transport import DailyTransport -from pipecat.services.open_ai_services import OpenAILLMService -from pipecat.services.elevenlabs_ai_services import ElevenLabsTTSService +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineTask +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantResponseAggregator, LLMUserResponseAggregator) +from pipecat.services.elevenlabs import ElevenLabsTTSService +from pipecat.services.openai import OpenAILLMService +from pipecat.transports.services.daily import DailyParams, DailyTransport from runner import configure +from loguru import logger + from dotenv import load_dotenv load_dotenv(override=True) -logging.basicConfig(format=f"%(levelno)s %(asctime)s %(message)s") -logger = logging.getLogger("pipecat") -logger.setLevel(logging.DEBUG) +logger.remove(0) +logger.add(sys.stderr, level="TRACE") async def main(room_url: str, token): @@ -29,12 +36,12 @@ async def main(room_url: str, token): room_url, token, "Respond bot", - duration_minutes=5, - start_transcription=True, - mic_enabled=True, - mic_sample_rate=16000, - camera_enabled=False, - vad_enabled=True, + DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + transcription_enabled=True, + vad_enabled=True, + ) ) tts = ElevenLabsTTSService( @@ -47,27 +54,31 @@ async def main(room_url: str, token): api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4-turbo-preview") - pipeline = Pipeline([FrameLogger(), llm, FrameLogger(), tts]) + messages = [ + { + "role": "system", + 
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so never use special characters. Respond to what the user said in a creative and helpful way.", + }, + ] - @transport.event_handler("on_first_other_participant_joined") - async def on_first_other_participant_joined(transport, participant): - await transport.say("Hi, I'm listening!", tts) + tma_in = LLMUserResponseAggregator(messages) + tma_out = LLMAssistantResponseAggregator(messages) - async def run_conversation(): - messages = [ - { - "role": "system", - "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio. Respond to what the user said in a creative and helpful way.", - }, - ] + pipeline = Pipeline([transport.input(), tma_in, llm, tts, tma_out, transport.output()]) - await transport.run_interruptible_pipeline( - pipeline, - post_processor=LLMAssistantResponseAggregator(messages), - pre_processor=LLMUserResponseAggregator(messages), - ) + task = PipelineTask(pipeline, allow_interruptions=True) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + transport.capture_participant_transcription(participant["id"]) + # Kick off the conversation. + messages.append( + {"role": "system", "content": "Please introduce yourself to the user."}) + await task.queue_frames([LLMMessagesFrame(messages)]) + + runner = PipelineRunner() - await asyncio.gather(transport.run(), run_conversation()) + await runner.run(task) if __name__ == "__main__": diff --git a/examples/moondream-chatbot/bot.py b/examples/moondream-chatbot/bot.py index 238a05f67..4a731d379 100644 --- a/examples/moondream-chatbot/bot.py +++ b/examples/moondream-chatbot/bot.py @@ -163,7 +163,7 @@ async def main(room_url: str, token): messages = [ { "role": "system", - "content": f"You are Chatbot, a friendly, helpful robot. Let the user know that you are capable of chatting or describing what you see. Your goal is to demonstrate your capabilities in a succinct way. Reply with only '{user_request_answer}' if the user asks you to describe what you see. Your output will be converted to audio so never include special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself.", + "content": f"You are Chatbot, a friendly, helpful robot. Let the user know that you are capable of chatting or describing what you see. Your goal is to demonstrate your capabilities in a succinct way. Reply with only '{user_request_answer}' if the user asks you to describe what you see. Your output will be converted to audio so never use special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself.", }, ] diff --git a/examples/simple-chatbot/bot.py b/examples/simple-chatbot/bot.py index e7be4732d..bde15aee8 100644 --- a/examples/simple-chatbot/bot.py +++ b/examples/simple-chatbot/bot.py @@ -126,7 +126,7 @@ async def main(room_url: str, token): # # English # - "content": "You are Chatbot, a friendly, helpful robot. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. 
Start by introducing yourself.", + "content": "You are Chatbot, a friendly, helpful robot. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so never use special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself.", # # Spanish From f432e2b17e9665c4d14a6b7a24d9a1d727b8330e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 10:25:00 -0700 Subject: [PATCH 04/25] transports: allow adding a vad analyzer to BaseInputTransport --- examples/foundational/07-interruptible.py | 2 ++ src/pipecat/transports/base_transport.py | 5 +++++ src/pipecat/transports/services/daily.py | 23 ++++++++++------------- src/pipecat/vad/silero.py | 19 +++++++++++++------ 4 files changed, 30 insertions(+), 19 deletions(-) diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py index de7ceb8c5..c1cba90a7 100644 --- a/examples/foundational/07-interruptible.py +++ b/examples/foundational/07-interruptible.py @@ -18,6 +18,7 @@ from pipecat.services.elevenlabs import ElevenLabsTTSService from pipecat.services.openai import OpenAILLMService from pipecat.transports.services.daily import DailyParams, DailyTransport +from pipecat.vad.silero import SileroVADAnalyzer from runner import configure @@ -41,6 +42,7 @@ async def main(room_url: str, token): audio_out_enabled=True, transcription_enabled=True, vad_enabled=True, + vad_analyzer=SileroVADAnalyzer() ) ) diff --git a/src/pipecat/transports/base_transport.py b/src/pipecat/transports/base_transport.py index 7b3561394..7f22d2c2c 100644 --- a/src/pipecat/transports/base_transport.py +++ b/src/pipecat/transports/base_transport.py @@ -6,12 +6,16 @@ from abc import ABC, abstractmethod +from pydantic import ConfigDict from pydantic.main import BaseModel from pipecat.processors.frame_processor import FrameProcessor +from pipecat.vad.vad_analyzer import VADAnalyzer class TransportParams(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + camera_out_enabled: bool = False camera_out_is_live: bool = False camera_out_width: int = 1024 @@ -27,6 +31,7 @@ class TransportParams(BaseModel): audio_in_channels: int = 1 vad_enabled: bool = False vad_audio_passthrough: bool = False + vad_analyzer: VADAnalyzer | None = None class BaseTransport(ABC): diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index 506c962ea..e0d406d8e 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -161,12 +161,6 @@ def __init__( "speaker", sample_rate=self._params.audio_in_sample_rate, channels=self._params.audio_in_channels) Daily.select_speaker_device("speaker") - self._vad_analyzer = None - if self._params.vad_enabled: - self._vad_analyzer = WebRTCVADAnalyzer( - sample_rate=self._params.audio_in_sample_rate, - num_channels=self._params.audio_in_channels) - @property def participant_id(self) -> str: return self._participant_id @@ -174,12 +168,6 @@ def participant_id(self) -> str: def set_callbacks(self, callbacks: DailyCallbacks): self._callbacks = callbacks - def vad_analyze(self, audio_frames: bytes) -> VADState: - state = VADState.QUIET - if self._vad_analyzer: - state = self._vad_analyzer.analyze_audio(audio_frames) - return state - def send_message(self, frame: DailyTransportMessageFrame): self._client.send_app_message(frame.message, frame.participant_id) @@ 
-435,6 +423,12 @@ def __init__(self, client: DailyTransportClient, params: DailyParams): self._video_renderers = {} self._camera_in_queue = queue.Queue() + self._vad_analyzer = params.vad_analyzer + if params.vad_enabled and not params.vad_analyzer: + self._vad_analyzer = WebRTCVADAnalyzer( + sample_rate=self._params.audio_in_sample_rate, + num_channels=self._params.audio_in_channels) + async def start(self, frame: StartFrame): if self._running: return @@ -461,7 +455,10 @@ async def cleanup(self): await self._client.cleanup() def vad_analyze(self, audio_frames: bytes) -> VADState: - return self._client.vad_analyze(audio_frames) + state = VADState.QUIET + if self._vad_analyzer: + state = self._vad_analyzer.analyze_audio(audio_frames) + return state def read_raw_audio_frames(self, frame_count: int) -> bytes: return self._client.read_raw_audio_frames(frame_count) diff --git a/src/pipecat/vad/silero.py b/src/pipecat/vad/silero.py index cddfb9bf6..a6bb17e89 100644 --- a/src/pipecat/vad/silero.py +++ b/src/pipecat/vad/silero.py @@ -39,11 +39,10 @@ def int2float(sound): return sound -class SileroVAD(FrameProcessor, VADAnalyzer): +class SileroVADAnalyzer(VADAnalyzer): - def __init__(self, sample_rate=16000, audio_passthrough=False): - FrameProcessor.__init__(self) - VADAnalyzer.__init__(self, sample_rate=sample_rate, num_channels=1) + def __init__(self, sample_rate=16000): + super().__init__(sample_rate=sample_rate, num_channels=1) logger.debug("Loading Silero VAD model...") @@ -52,7 +51,6 @@ def __init__(self, sample_rate=16000, audio_passthrough=False): ) self._processor_vad_state: VADState = VADState.QUIET - self._audio_passthrough = audio_passthrough logger.debug("Loaded Silero VAD") @@ -74,6 +72,15 @@ def voice_confidence(self, buffer) -> float: logger.error(f"Error analyzing audio with Silero VAD: {e}") return 0 + +class SileroVAD(FrameProcessor): + + def __init__(self, sample_rate=16000, audio_passthrough=False): + super().__init__() + + self._vad_analyzer = SileroVADAnalyzer(sample_rate=sample_rate) + self._audio_passthrough = audio_passthrough + # # FrameProcessor # @@ -89,7 +96,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): async def _analyze_audio(self, frame: AudioRawFrame): # Check VAD and push event if necessary. We just care about changes # from QUIET to SPEAKING and vice versa. - new_vad_state = self.analyze_audio(frame.audio) + new_vad_state = self._vad_analyzer.analyze_audio(frame.audio) if new_vad_state != self._processor_vad_state and new_vad_state != VADState.STARTING and new_vad_state != VADState.STOPPING: new_frame = None From f62fe059b1d04aab93cb0ac7a80228a12be20345 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 10:43:52 -0700 Subject: [PATCH 05/25] fix issues with Ctrl-C tasks cancellation --- CHANGELOG.md | 4 +++ src/pipecat/pipeline/task.py | 46 ++++++++++++++++------------ src/pipecat/transports/base_input.py | 13 ++++---- 3 files changed, 37 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f5049a57..b1c9876ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added initial interruptions support. +### Fixed + +- Fixed issues with Ctrl-C program termination. 
+ ## [0.0.16] - 2024-05-16 ### Fixed diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index b693c0d58..221d588f5 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -38,7 +38,7 @@ def __init__(self, pipeline: FrameProcessor, allow_interruptions=False): self._pipeline = pipeline self._allow_interruptions = allow_interruptions - self._task_queue = asyncio.Queue() + self._down_queue = asyncio.Queue() self._up_queue = asyncio.Queue() self._source = Source(self._up_queue) @@ -50,15 +50,22 @@ async def stop_when_done(self): async def cancel(self): logger.debug(f"Canceling pipeline task {self}") - await self.queue_frame(CancelFrame()) + # Make sure everything is cleaned up downstream. This is sent + # out-of-band from the main streaming task which is what we want since + # we want to cancel right away. + await self._source.process_frame(CancelFrame(), FrameDirection.DOWNSTREAM) + self._process_down_task.cancel() + self._process_up_task.cancel() async def run(self): - await asyncio.gather(self._process_task_queue(), self._process_up_queue()) + self._process_up_task = asyncio.create_task(self._process_up_queue()) + self._process_down_task = asyncio.create_task(self._process_down_queue()) + await asyncio.gather(self._process_up_task, self._process_down_task) await self._source.cleanup() await self._pipeline.cleanup() async def queue_frame(self, frame: Frame): - await self._task_queue.put(frame) + await self._down_queue.put(frame) async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]): if isinstance(frames, AsyncIterable): @@ -70,30 +77,31 @@ async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]): else: raise Exception("Frames must be an iterable or async iterable") - async def _process_task_queue(self): + async def _process_down_queue(self): await self._source.process_frame( StartFrame(allow_interruptions=self._allow_interruptions), FrameDirection.DOWNSTREAM) running = True while running: - frame = await self._task_queue.get() - await self._source.process_frame(frame, FrameDirection.DOWNSTREAM) - self._task_queue.task_done() - running = not (isinstance(frame, StopTaskFrame) or - isinstance(frame, CancelFrame) or - isinstance(frame, EndFrame)) - # We just enqueue None to terminate the task. - await self._up_queue.put(None) + try: + frame = await self._down_queue.get() + await self._source.process_frame(frame, FrameDirection.DOWNSTREAM) + self._down_queue.task_done() + running = not (isinstance(frame, StopTaskFrame) or isinstance(frame, EndFrame)) + except asyncio.CancelledError: + break + # We just enqueue None to terminate the task gracefully. 
+ self._process_up_task.cancel() async def _process_up_queue(self): - running = True - while running: - frame = await self._up_queue.get() - if frame: + while True: + try: + frame = await self._up_queue.get() if isinstance(frame, ErrorFrame): logger.error(f"Error running app: {frame.error}") await self.queue_frame(CancelFrame()) - self._up_queue.task_done() - running = frame is not None + self._up_queue.task_done() + except asyncio.CancelledError: + break def __str__(self): return self.name diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 6000127b4..92473c769 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -65,8 +65,7 @@ async def stop(self): await self._audio_in_thread await self._audio_out_thread - await self._internal_push_frame(None, None) - await self._push_frame_task + self._push_frame_task.cancel() def vad_analyze(self, audio_frames: bytes) -> VADState: pass @@ -112,12 +111,12 @@ async def _internal_push_frame( await self._push_queue.put((frame, direction)) async def _push_frame_task_handler(self): - running = True - while running: - (frame, direction) = await self._push_queue.get() - if frame: + while True: + try: + (frame, direction) = await self._push_queue.get() await self.push_frame(frame, direction) - running = frame is not None + except asyncio.CancelledError: + break # # Handle interruptions From 0bef44c2ff395afa0abc41340305d0c95eda91c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 11:07:12 -0700 Subject: [PATCH 06/25] introduce StartInterruptionFrame and StopInterruptionFrame --- src/pipecat/frames/frames.py | 22 ++++++++++++++++++++++ src/pipecat/transports/base_input.py | 13 ++++++++++--- src/pipecat/transports/base_output.py | 12 ++++++------ 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index ed7ac75e8..9d5b15b36 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -216,6 +216,28 @@ class StopTaskFrame(SystemFrame): pass +@dataclass +class StartInterruptionFrame(SystemFrame): + """Emitted by VAD to indicate that a user has started speaking (i.e. is + interruption). This is similar to UserStartedSpeakingFrame except that it + should be pushed concurrently with other frames (so the order is not + guaranteed). + + """ + pass + + +@dataclass +class StopInterruptionFrame(SystemFrame): + """Emitted by VAD to indicate that a user has stopped speaking (i.e. no more + interruptions). This is similar to UserStoppedSpeakingFrame except that it + should be pushed concurrently with other frames (so the order is not + guaranteed). 
+ + """ + pass + + # # Control frames # diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 92473c769..f3ee492d3 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -14,6 +14,8 @@ StartFrame, EndFrame, Frame, + StartInterruptionFrame, + StopInterruptionFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame) from pipecat.transports.base_transport import TransportParams @@ -123,9 +125,14 @@ async def _push_frame_task_handler(self): # async def _handle_interruptions(self, frame: Frame): - if self._allow_interruptions and isinstance(frame, UserStartedSpeakingFrame): - self._push_frame_task.cancel() - self._create_push_task() + if self._allow_interruptions: + # Make sure we notify about interruptions quickly out-of-band + if isinstance(frame, UserStartedSpeakingFrame): + self._push_frame_task.cancel() + self._create_push_task() + await self.push_frame(StartInterruptionFrame()) + elif isinstance(frame, UserStoppedSpeakingFrame): + await self.push_frame(StopInterruptionFrame()) await self._internal_push_frame(frame) # diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index ff9040406..7d18f1265 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -23,9 +23,9 @@ EndFrame, Frame, ImageRawFrame, - TransportMessageFrame, - UserStartedSpeakingFrame, - UserStoppedSpeakingFrame) + StartInterruptionFrame, + StopInterruptionFrame, + TransportMessageFrame) from pipecat.transports.base_transport import TransportParams from loguru import logger @@ -104,7 +104,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): elif isinstance(frame, CancelFrame): await self.stop() await self.push_frame(frame, direction) - elif isinstance(frame, UserStartedSpeakingFrame) or isinstance(frame, UserStoppedSpeakingFrame): + elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame): await self._handle_interruptions(frame) await self.push_frame(frame, direction) elif self._frame_managed_by_sink(frame): @@ -129,9 +129,9 @@ async def _handle_interruptions(self, frame: Frame): if not self._allow_interruptions: return - if isinstance(frame, UserStartedSpeakingFrame): + if isinstance(frame, StartInterruptionFrame): self._is_interrupted.set() - elif isinstance(frame, UserStoppedSpeakingFrame): + elif isinstance(frame, StopInterruptionFrame): self._is_interrupted.clear() def _sink_thread_handler(self): From efa5a061d7b051c9809e62b2c7492c134ada98eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 11:37:38 -0700 Subject: [PATCH 07/25] silero: simplify int16 -> float32 conversion --- src/pipecat/vad/silero.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/pipecat/vad/silero.py b/src/pipecat/vad/silero.py index a6bb17e89..bfe13affe 100644 --- a/src/pipecat/vad/silero.py +++ b/src/pipecat/vad/silero.py @@ -26,19 +26,6 @@ raise Exception(f"Missing module(s): {e}") -# Provided by Alexander Veysov -def int2float(sound): - try: - abs_max = np.abs(sound).max() - sound = sound.astype("float32") - if abs_max > 0: - sound *= 1 / 32768 - sound = sound.squeeze() # depends on the use case - return sound - except ValueError: - return sound - - class SileroVADAnalyzer(VADAnalyzer): def __init__(self, sample_rate=16000): @@ -64,7 +51,8 @@ def num_frames_required(self) -> int: def voice_confidence(self, buffer) -> float: try: audio_int16 = 
np.frombuffer(buffer, np.int16) - audio_float32 = int2float(audio_int16) + # Divide by 32768 because we have signed 16-bit data. + audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0 new_confidence = self._model(torch.from_numpy(audio_float32), self.sample_rate).item() return new_confidence except BaseException as e: From 537e72a05f0c7bf618093cb45d3e63d3e62abc30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 13:01:43 -0700 Subject: [PATCH 08/25] vad: introduce VADParams so you can tweak things --- CHANGELOG.md | 2 ++ src/pipecat/transports/services/daily.py | 6 ++--- src/pipecat/vad/silero.py | 14 +++++++----- src/pipecat/vad/vad_analyzer.py | 28 +++++++++++++----------- 4 files changed, 29 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b1c9876ee..c381ac630 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added initial interruptions support. +- Added `VADParams` so you can control voice confidence level and others. + ### Fixed - Fixed issues with Ctrl-C program termination. diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index e0d406d8e..47a1b9925 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -38,7 +38,7 @@ from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport from pipecat.transports.base_transport import BaseTransport, TransportParams -from pipecat.vad.vad_analyzer import VADAnalyzer, VADState +from pipecat.vad.vad_analyzer import VADAnalyzer, VADParams, VADState from loguru import logger @@ -60,8 +60,8 @@ class DailyTransportMessageFrame(TransportMessageFrame): class WebRTCVADAnalyzer(VADAnalyzer): - def __init__(self, sample_rate=16000, num_channels=1): - super().__init__(sample_rate, num_channels) + def __init__(self, sample_rate=16000, num_channels=1, params: VADParams = VADParams()): + super().__init__(sample_rate, num_channels, params) self._webrtc_vad = Daily.create_native_vad( reset_period_ms=VAD_RESET_PERIOD_MS, diff --git a/src/pipecat/vad/silero.py b/src/pipecat/vad/silero.py index bfe13affe..ab7cf36df 100644 --- a/src/pipecat/vad/silero.py +++ b/src/pipecat/vad/silero.py @@ -8,7 +8,7 @@ from pipecat.frames.frames import AudioRawFrame, Frame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.vad.vad_analyzer import VADAnalyzer, VADState +from pipecat.vad.vad_analyzer import VADAnalyzer, VADParams, VADState from loguru import logger @@ -28,8 +28,8 @@ class SileroVADAnalyzer(VADAnalyzer): - def __init__(self, sample_rate=16000): - super().__init__(sample_rate=sample_rate, num_channels=1) + def __init__(self, sample_rate=16000, params: VADParams = VADParams()): + super().__init__(sample_rate=sample_rate, num_channels=1, params=params) logger.debug("Loading Silero VAD model...") @@ -63,10 +63,14 @@ def voice_confidence(self, buffer) -> float: class SileroVAD(FrameProcessor): - def __init__(self, sample_rate=16000, audio_passthrough=False): + def __init__( + self, + sample_rate: int = 16000, + vad_params: VADParams = VADParams(), + audio_passthrough: bool = False): super().__init__() - self._vad_analyzer = SileroVADAnalyzer(sample_rate=sample_rate) + self._vad_analyzer = 
SileroVADAnalyzer(sample_rate=sample_rate, params=vad_params) self._audio_passthrough = audio_passthrough # diff --git a/src/pipecat/vad/vad_analyzer.py b/src/pipecat/vad/vad_analyzer.py index 58bec3b9a..6c4afceba 100644 --- a/src/pipecat/vad/vad_analyzer.py +++ b/src/pipecat/vad/vad_analyzer.py @@ -7,6 +7,10 @@ from abc import abstractmethod from enum import Enum +from pydantic.main import BaseModel + +from pipecat.utils.utils import exp_smoothing + class VADState(Enum): QUIET = 1 @@ -15,26 +19,24 @@ class VADState(Enum): STOPPING = 4 +class VADParams(BaseModel): + confidence: float = 0.8 + start_secs: float = 0.2 + stop_secs: float = 0.8 + + class VADAnalyzer: - def __init__( - self, - sample_rate: int, - num_channels: int, - vad_confidence: float = 0.5, - vad_start_secs: float = 0.2, - vad_stop_secs: float = 0.8): + def __init__(self, sample_rate: int, num_channels: int, params: VADParams): self._sample_rate = sample_rate - self._vad_confidence = vad_confidence - self._vad_start_secs = vad_start_secs - self._vad_stop_secs = vad_stop_secs + self._params = params self._vad_frames = self.num_frames_required() self._vad_frames_num_bytes = self._vad_frames * num_channels * 2 vad_frames_per_sec = self._vad_frames / self._sample_rate - self._vad_start_frames = round(self._vad_start_secs / vad_frames_per_sec) - self._vad_stop_frames = round(self._vad_stop_secs / vad_frames_per_sec) + self._vad_start_frames = round(self._params.start_secs / vad_frames_per_sec) + self._vad_stop_frames = round(self._params.stop_secs / vad_frames_per_sec) self._vad_starting_count = 0 self._vad_stopping_count = 0 self._vad_state: VADState = VADState.QUIET @@ -64,7 +66,7 @@ def analyze_audio(self, buffer) -> VADState: self._vad_buffer = self._vad_buffer[num_required_bytes:] confidence = self.voice_confidence(audio_frames) - speaking = confidence >= self._vad_confidence + speaking = confidence >= self._params.confidence if speaking: match self._vad_state: From f2cefeeedc7c60e99f0917529f37ddbdfa5e656e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 13:02:05 -0700 Subject: [PATCH 09/25] utils: move exp_smoothing to utils module --- src/pipecat/services/ai_services.py | 6 ++---- src/pipecat/utils/utils.py | 4 ++++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index ffb9aeee8..a3cb5cde5 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -22,6 +22,7 @@ VisionImageRawFrame, ) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.utils.utils import exp_smoothing class AIService(FrameProcessor): @@ -115,16 +116,13 @@ def _new_wave(self): ww.setframerate(self._sample_rate) return (content, ww) - def _exp_smoothing(self, value: float, prev_value: float, factor: float) -> float: - return prev_value + factor * (value - prev_value) - def _get_smoothed_volume(self, audio: bytes, prev_rms: float, factor: float) -> float: # https://docs.python.org/3/library/array.html audio_array = array.array('h', audio) squares = [sample**2 for sample in audio_array] mean = sum(squares) / len(audio_array) rms = math.sqrt(mean) - return self._exp_smoothing(rms, prev_rms, factor) + return exp_smoothing(rms, prev_rms, factor) async def _append_audio(self, frame: AudioRawFrame): # Try to filter out empty background noise diff --git a/src/pipecat/utils/utils.py b/src/pipecat/utils/utils.py index a72f7234e..0be73191f 100644 --- 
a/src/pipecat/utils/utils.py +++ b/src/pipecat/utils/utils.py @@ -29,3 +29,7 @@ def obj_count(obj) -> int: else: _COUNTS[name] += 1 return _COUNTS[name] + + +def exp_smoothing(value: float, prev_value: float, factor: float) -> float: + return prev_value + factor * (value - prev_value) From a5d246ec0c286002adcde4f103594d866a2f3db8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 13:11:53 -0700 Subject: [PATCH 10/25] vad: use exponential smoothing to avoid sudden changes --- CHANGELOG.md | 2 ++ examples/foundational/07-interruptible.py | 2 +- src/pipecat/vad/vad_analyzer.py | 16 ++++++++++++++-- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c381ac630..f6debb3b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `VADParams` so you can control voice confidence level and others. +- `VADAnalyzer` now uses an exponential smoothing to avoid sudden changes. + ### Fixed - Fixed issues with Ctrl-C program termination. diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py index c1cba90a7..67c2eddc9 100644 --- a/examples/foundational/07-interruptible.py +++ b/examples/foundational/07-interruptible.py @@ -28,7 +28,7 @@ load_dotenv(override=True) logger.remove(0) -logger.add(sys.stderr, level="TRACE") +logger.add(sys.stderr, level="DEBUG") async def main(room_url: str, token): diff --git a/src/pipecat/vad/vad_analyzer.py b/src/pipecat/vad/vad_analyzer.py index 6c4afceba..23f7263ab 100644 --- a/src/pipecat/vad/vad_analyzer.py +++ b/src/pipecat/vad/vad_analyzer.py @@ -20,7 +20,7 @@ class VADState(Enum): class VADParams(BaseModel): - confidence: float = 0.8 + confidence: float = 0.6 start_secs: float = 0.2 stop_secs: float = 0.8 @@ -43,6 +43,10 @@ def __init__(self, sample_rate: int, num_channels: int, params: VADParams): self._vad_buffer = b"" + # Exponential smoothing + self._smoothing_factor = 0.6 + self._prev_confidence = 1 - self._smoothing_factor + @property def sample_rate(self): return self._sample_rate @@ -55,6 +59,11 @@ def num_frames_required(self) -> int: def voice_confidence(self, buffer) -> float: pass + def _smoothed_confidence(self, audio_frames, prev_confidence, factor): + confidence = self.voice_confidence(audio_frames) + smoothed = exp_smoothing(confidence, prev_confidence, factor) + return smoothed + def analyze_audio(self, buffer) -> VADState: self._vad_buffer += buffer @@ -65,7 +74,10 @@ def analyze_audio(self, buffer) -> VADState: audio_frames = self._vad_buffer[:num_required_bytes] self._vad_buffer = self._vad_buffer[num_required_bytes:] - confidence = self.voice_confidence(audio_frames) + confidence = self._smoothed_confidence( + audio_frames, self._prev_confidence, self._smoothing_factor) + self._prev_confidence = confidence + speaking = confidence >= self._params.confidence if speaking: From 57121338b1d0d6c1734ce10d5a12d8637237e5a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 13:47:46 -0700 Subject: [PATCH 11/25] pipeline(task): cleanup processors only if we need to --- CHANGELOG.md | 3 +++ src/pipecat/pipeline/task.py | 10 +++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f6debb3b3..f80a6cf3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ and this project adheres to [Semantic 
Versioning](https://semver.org/spec/v2.0.0 - Fixed issues with Ctrl-C program termination. +- Fixed an issue that was causing `StopTaskFrame` to actually not exit the + `PipelineTask`. + ## [0.0.16] - 2024-05-16 ### Fixed diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 221d588f5..400adcbfd 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -61,8 +61,6 @@ async def run(self): self._process_up_task = asyncio.create_task(self._process_up_queue()) self._process_down_task = asyncio.create_task(self._process_down_queue()) await asyncio.gather(self._process_up_task, self._process_down_task) - await self._source.cleanup() - await self._pipeline.cleanup() async def queue_frame(self, frame: Frame): await self._down_queue.put(frame) @@ -81,14 +79,20 @@ async def _process_down_queue(self): await self._source.process_frame( StartFrame(allow_interruptions=self._allow_interruptions), FrameDirection.DOWNSTREAM) running = True + should_cleanup = True while running: try: frame = await self._down_queue.get() await self._source.process_frame(frame, FrameDirection.DOWNSTREAM) - self._down_queue.task_done() running = not (isinstance(frame, StopTaskFrame) or isinstance(frame, EndFrame)) + should_cleanup = not isinstance(frame, StopTaskFrame) + self._down_queue.task_done() except asyncio.CancelledError: break + # Cleanup only if we need to. + if should_cleanup: + await self._source.cleanup() + await self._pipeline.cleanup() # We just enqueue None to terminate the task gracefully. self._process_up_task.cancel() From 34762bf604ea7270501ff111c985153285101878 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 14:15:37 -0700 Subject: [PATCH 12/25] transports: allows update allow_interruptinos when receiving StartFrame --- src/pipecat/transports/base_input.py | 6 +++++- src/pipecat/transports/base_output.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index f3ee492d3..1aa3ceaab 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -44,11 +44,15 @@ def __init__(self, params: TransportParams): self._create_push_task() async def start(self, frame: StartFrame): + # Make sure we have the latest params. Note that this transport might + # have been started on another task that might not need interruptions, + # for example. + self._allow_interruptions = frame.allow_interruptions + if self._running: return self._running = True - self._allow_interruptions = frame.allow_interruptions if self._params.audio_in_enabled or self._params.vad_enabled: loop = self.get_event_loop() diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index 7d18f1265..9d6c567b1 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -54,11 +54,15 @@ def __init__(self, params: TransportParams): self._is_interrupted = threading.Event() async def start(self, frame: StartFrame): + # Make sure we have the latest params. Note that this transport might + # have been started on another task that might not need interruptions, + # for example. 
+ self._allow_interruptions = frame.allow_interruptions + if self._running: return self._running = True - self._allow_interruptions = frame.allow_interruptions loop = self.get_event_loop() From d66a79541367161cd6c22d22c9e2469e673cfd0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 14:16:05 -0700 Subject: [PATCH 13/25] examples: use SileroVADAnalyzer instead of SileroVAD --- examples/foundational/06-listen-and-respond.py | 11 +++++------ examples/foundational/07-interruptible.py | 1 - examples/foundational/12-describe-video.py | 11 +++++------ examples/moondream-chatbot/bot.py | 11 +++++------ examples/simple-chatbot/bot.py | 11 +++++------ 5 files changed, 20 insertions(+), 25 deletions(-) diff --git a/examples/foundational/06-listen-and-respond.py b/examples/foundational/06-listen-and-respond.py index 3ba220912..e561c3c64 100644 --- a/examples/foundational/06-listen-and-respond.py +++ b/examples/foundational/06-listen-and-respond.py @@ -21,7 +21,7 @@ from pipecat.services.elevenlabs import ElevenLabsTTSService from pipecat.services.openai import OpenAILLMService from pipecat.transports.services.daily import DailyParams, DailyTransport -from pipecat.vad.silero import SileroVAD +from pipecat.vad.silero import SileroVADAnalyzer from runner import configure @@ -41,14 +41,13 @@ async def main(room_url: str, token): token, "Respond bot", DailyParams( - audio_in_enabled=True, # This is so Silero VAD can get audio data audio_out_enabled=True, - transcription_enabled=True + transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer() ) ) - vad = SileroVAD() - tts = ElevenLabsTTSService( aiohttp_session=session, api_key=os.getenv("ELEVENLABS_API_KEY"), @@ -71,7 +70,7 @@ async def main(room_url: str, token): tma_in = LLMUserResponseAggregator(messages) tma_out = LLMAssistantResponseAggregator(messages) - pipeline = Pipeline([fl_in, transport.input(), vad, tma_in, llm, + pipeline = Pipeline([fl_in, transport.input(), tma_in, llm, fl_out, tts, tma_out, transport.output()]) task = PipelineTask(pipeline) diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py index 67c2eddc9..70c74f5e2 100644 --- a/examples/foundational/07-interruptible.py +++ b/examples/foundational/07-interruptible.py @@ -38,7 +38,6 @@ async def main(room_url: str, token): token, "Respond bot", DailyParams( - audio_in_enabled=True, audio_out_enabled=True, transcription_enabled=True, vad_enabled=True, diff --git a/examples/foundational/12-describe-video.py b/examples/foundational/12-describe-video.py index feef343fc..68033b77f 100644 --- a/examples/foundational/12-describe-video.py +++ b/examples/foundational/12-describe-video.py @@ -19,7 +19,7 @@ from pipecat.services.elevenlabs import ElevenLabsTTSService from pipecat.services.moondream import MoondreamService from pipecat.transports.services.daily import DailyParams, DailyTransport -from pipecat.vad.silero import SileroVAD +from pipecat.vad.silero import SileroVADAnalyzer from runner import configure @@ -54,14 +54,13 @@ async def main(room_url: str, token): token, "Describe participant video", DailyParams( - audio_in_enabled=True, # This is so Silero VAD can get audio data audio_out_enabled=True, - transcription_enabled=True + transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer() ) ) - vad = SileroVAD() - tts = ElevenLabsTTSService( aiohttp_session=session, api_key=os.getenv("ELEVENLABS_API_KEY"), @@ -90,7 +89,7 @@ async def 
on_first_participant_joined(transport, participant): transport.capture_participant_transcription(participant["id"]) image_requester.set_participant_id(participant["id"]) - pipeline = Pipeline([transport.input(), vad, user_response, image_requester, + pipeline = Pipeline([transport.input(), user_response, image_requester, vision_aggregator, moondream, tts, transport.output()]) task = PipelineTask(pipeline) diff --git a/examples/moondream-chatbot/bot.py b/examples/moondream-chatbot/bot.py index 4a731d379..6a43f617e 100644 --- a/examples/moondream-chatbot/bot.py +++ b/examples/moondream-chatbot/bot.py @@ -29,7 +29,7 @@ from pipecat.services.moondream import MoondreamService from pipecat.services.openai import OpenAILLMService from pipecat.transports.services.daily import DailyParams, DailyTransport -from pipecat.vad.silero import SileroVAD +from pipecat.vad.silero import SileroVADAnalyzer from runner import configure @@ -127,17 +127,16 @@ async def main(room_url: str, token): token, "Chatbot", DailyParams( - audio_in_enabled=True, audio_out_enabled=True, camera_out_enabled=True, camera_out_width=1024, camera_out_height=576, - transcription_enabled=True + transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer() ) ) - vad = SileroVAD() - tts = ElevenLabsTTSService( aiohttp_session=session, api_key=os.getenv("ELEVENLABS_API_KEY"), @@ -169,7 +168,7 @@ async def main(room_url: str, token): ura = LLMUserResponseAggregator(messages) - pipeline = Pipeline([transport.input(), vad, ura, llm, + pipeline = Pipeline([transport.input(), ura, llm, ParallelPipeline( [sa, ir, va, moondream], [tf, imgf]), diff --git a/examples/simple-chatbot/bot.py b/examples/simple-chatbot/bot.py index bde15aee8..e379ae049 100644 --- a/examples/simple-chatbot/bot.py +++ b/examples/simple-chatbot/bot.py @@ -21,7 +21,7 @@ from pipecat.services.elevenlabs import ElevenLabsTTSService from pipecat.services.openai import OpenAILLMService from pipecat.transports.services.daily import DailyParams, DailyTranscriptionSettings, DailyTransport -from pipecat.vad.silero import SileroVAD +from pipecat.vad.silero import SileroVADAnalyzer from runner import configure @@ -82,11 +82,12 @@ async def main(room_url: str, token): token, "Chatbot", DailyParams( - audio_in_enabled=True, audio_out_enabled=True, camera_out_enabled=True, camera_out_width=1024, camera_out_height=576, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), transcription_enabled=True, # # Spanish @@ -99,8 +100,6 @@ async def main(room_url: str, token): ) ) - vad = SileroVAD() - tts = ElevenLabsTTSService( aiohttp_session=session, api_key=os.getenv("ELEVENLABS_API_KEY"), @@ -139,10 +138,10 @@ async def main(room_url: str, token): ta = TalkingAnimation() - pipeline = Pipeline([transport.input(), vad, user_response, + pipeline = Pipeline([transport.input(), user_response, llm, tts, ta, transport.output()]) - task = PipelineTask(pipeline) + task = PipelineTask(pipeline, allow_interruptions=True) await task.queue_frame(quiet_frame) @transport.event_handler("on_first_participant_joined") From de65028061e8c30eb98ad3e148d30dd7d2109493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 14:39:40 -0700 Subject: [PATCH 14/25] vad: reduce default confidence back to 0.5 --- src/pipecat/vad/vad_analyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pipecat/vad/vad_analyzer.py b/src/pipecat/vad/vad_analyzer.py index 23f7263ab..1d68e2af6 100644 --- a/src/pipecat/vad/vad_analyzer.py +++ 
b/src/pipecat/vad/vad_analyzer.py @@ -20,7 +20,7 @@ class VADState(Enum): class VADParams(BaseModel): - confidence: float = 0.6 + confidence: float = 0.5 start_secs: float = 0.2 stop_secs: float = 0.8 From c77db79447a9030009da771b9222e25c764b25e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 14:52:51 -0700 Subject: [PATCH 15/25] examples: pipelines readability and add LLM assistants after transport --- .../05a-local-sync-speech-and-image.py | 10 +++++++--- examples/foundational/06-listen-and-respond.py | 12 ++++++++++-- examples/foundational/06a-image-sync.py | 11 +++++++++-- examples/foundational/07-interruptible.py | 8 +++++++- examples/foundational/10-wake-word.py | 12 ++++++++++-- examples/foundational/11-sound-effects.py | 14 ++++++++++++-- examples/foundational/12-describe-video.py | 11 +++++++++-- examples/moondream-chatbot/bot.py | 16 +++++++++++----- examples/simple-chatbot/bot.py | 10 ++++++++-- examples/storytelling-chatbot/src/bot.py | 4 ++-- examples/translation-chatbot/bot.py | 11 ++++++++++- 11 files changed, 95 insertions(+), 24 deletions(-) diff --git a/examples/foundational/05a-local-sync-speech-and-image.py b/examples/foundational/05a-local-sync-speech-and-image.py index a5629745b..bfbd453e2 100644 --- a/examples/foundational/05a-local-sync-speech-and-image.py +++ b/examples/foundational/05a-local-sync-speech-and-image.py @@ -98,9 +98,13 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): image_grabber = ImageGrabber() - pipeline = Pipeline([llm, aggregator, description, - ParallelPipeline([tts, audio_grabber], - [imagegen, image_grabber])]) + pipeline = Pipeline([ + llm, + aggregator, + description, + ParallelPipeline([tts, audio_grabber], + [imagegen, image_grabber]) + ]) task = PipelineTask(pipeline) await task.queue_frame(LLMMessagesFrame(messages)) diff --git a/examples/foundational/06-listen-and-respond.py b/examples/foundational/06-listen-and-respond.py index e561c3c64..b7e0d13bc 100644 --- a/examples/foundational/06-listen-and-respond.py +++ b/examples/foundational/06-listen-and-respond.py @@ -70,8 +70,16 @@ async def main(room_url: str, token): tma_in = LLMUserResponseAggregator(messages) tma_out = LLMAssistantResponseAggregator(messages) - pipeline = Pipeline([fl_in, transport.input(), tma_in, llm, - fl_out, tts, tma_out, transport.output()]) + pipeline = Pipeline([ + fl_in, + transport.input(), + tma_in, + llm, + fl_out, + tts, + transport.output(), + tma_out + ]) task = PipelineTask(pipeline) diff --git a/examples/foundational/06a-image-sync.py b/examples/foundational/06a-image-sync.py index 77278f21d..30f2eea95 100644 --- a/examples/foundational/06a-image-sync.py +++ b/examples/foundational/06a-image-sync.py @@ -95,8 +95,15 @@ async def main(room_url: str, token): os.path.join(os.path.dirname(__file__), "assets", "waiting.png"), ) - pipeline = Pipeline([transport.input(), image_sync_aggregator, - tma_in, llm, tma_out, tts, transport.output()]) + pipeline = Pipeline([ + transport.input(), + image_sync_aggregator, + tma_in, + llm, + tts, + transport.output(), + tma_out + ]) task = PipelineTask(pipeline) diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py index 70c74f5e2..71f9a22ab 100644 --- a/examples/foundational/07-interruptible.py +++ b/examples/foundational/07-interruptible.py @@ -65,7 +65,13 @@ async def main(room_url: str, token): tma_in = LLMUserResponseAggregator(messages) tma_out = LLMAssistantResponseAggregator(messages) - 
pipeline = Pipeline([transport.input(), tma_in, llm, tts, tma_out, transport.output()]) + pipeline = Pipeline([ + transport.input(), + tma_in, + llm, + tts, + transport.output(), + tma_out]) task = PipelineTask(pipeline, allow_interruptions=True) diff --git a/examples/foundational/10-wake-word.py b/examples/foundational/10-wake-word.py index 430f76e8c..cc1829046 100644 --- a/examples/foundational/10-wake-word.py +++ b/examples/foundational/10-wake-word.py @@ -157,8 +157,16 @@ async def main(room_url: str, token): tma_out = LLMAssistantContextAggregator(messages) ncf = NameCheckFilter(["Santa Cat", "Santa"]) - pipeline = Pipeline([transport.input(), isa, ncf, tma_in, - llm, tma_out, tts, transport.output()]) + pipeline = Pipeline([ + transport.input(), + isa, + ncf, + tma_in, + llm, + tts, + transport.output(), + tma_out + ]) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): diff --git a/examples/foundational/11-sound-effects.py b/examples/foundational/11-sound-effects.py index 2515a4418..292a89052 100644 --- a/examples/foundational/11-sound-effects.py +++ b/examples/foundational/11-sound-effects.py @@ -111,8 +111,18 @@ async def main(room_url: str, token): fl = FrameLogger("LLM Out") fl2 = FrameLogger("Transcription In") - pipeline = Pipeline([transport.input(), tma_in, in_sound, fl2, llm, - tma_out, fl, tts, out_sound, transport.output()]) + pipeline = Pipeline([ + transport.input(), + tma_in, + in_sound, + fl2, + llm, + fl, + tts, + out_sound, + transport.output(), + tma_out + ]) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): diff --git a/examples/foundational/12-describe-video.py b/examples/foundational/12-describe-video.py index 68033b77f..256580c07 100644 --- a/examples/foundational/12-describe-video.py +++ b/examples/foundational/12-describe-video.py @@ -89,8 +89,15 @@ async def on_first_participant_joined(transport, participant): transport.capture_participant_transcription(participant["id"]) image_requester.set_participant_id(participant["id"]) - pipeline = Pipeline([transport.input(), user_response, image_requester, - vision_aggregator, moondream, tts, transport.output()]) + pipeline = Pipeline([ + transport.input(), + user_response, + image_requester, + vision_aggregator, + moondream, + tts, + transport.output() + ]) task = PipelineTask(pipeline) diff --git a/examples/moondream-chatbot/bot.py b/examples/moondream-chatbot/bot.py index 6a43f617e..b09f4345d 100644 --- a/examples/moondream-chatbot/bot.py +++ b/examples/moondream-chatbot/bot.py @@ -168,11 +168,17 @@ async def main(room_url: str, token): ura = LLMUserResponseAggregator(messages) - pipeline = Pipeline([transport.input(), ura, llm, - ParallelPipeline( - [sa, ir, va, moondream], - [tf, imgf]), - tts, ta, transport.output()]) + pipeline = Pipeline([ + transport.input(), + ura, + llm, + ParallelPipeline( + [sa, ir, va, moondream], + [tf, imgf]), + tts, + ta, + transport.output() + ]) task = PipelineTask(pipeline) await task.queue_frame(quiet_frame) diff --git a/examples/simple-chatbot/bot.py b/examples/simple-chatbot/bot.py index e379ae049..c605414eb 100644 --- a/examples/simple-chatbot/bot.py +++ b/examples/simple-chatbot/bot.py @@ -138,8 +138,14 @@ async def main(room_url: str, token): ta = TalkingAnimation() - pipeline = Pipeline([transport.input(), user_response, - llm, tts, ta, transport.output()]) + pipeline = Pipeline([ + transport.input(), + user_response, + llm, + tts, + ta, + 
transport.output() + ]) task = PipelineTask(pipeline, allow_interruptions=True) await task.queue_frame(quiet_frame) diff --git a/examples/storytelling-chatbot/src/bot.py b/examples/storytelling-chatbot/src/bot.py index 24f02ccac..c5a75e949 100644 --- a/examples/storytelling-chatbot/src/bot.py +++ b/examples/storytelling-chatbot/src/bot.py @@ -133,8 +133,8 @@ async def on_first_participant_joined(transport, participant): story_processor, image_processor, tts_service, - llm_responses, - transport.output() + transport.output(), + llm_responses ]) main_task = PipelineTask(main_pipeline) diff --git a/examples/translation-chatbot/bot.py b/examples/translation-chatbot/bot.py index 376b3570e..7cd2cbd83 100644 --- a/examples/translation-chatbot/bot.py +++ b/examples/translation-chatbot/bot.py @@ -103,7 +103,16 @@ async def main(room_url: str, token): lfra = LLMFullResponseAggregator() ts = TranslationSubtitles("spanish") - pipeline = Pipeline([transport.input(), sa, tp, llm, lfra, ts, tts, transport.output()]) + pipeline = Pipeline([ + transport.input(), + sa, + tp, + llm, + lfra, + ts, + tts, + transport.output() + ]) task = PipelineTask(pipeline) From 8dc81042c324cf71b62c0038261bc3d9a979e0d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 15:36:30 -0700 Subject: [PATCH 16/25] examples: use DailyTranscriptionSettings in translation-chatbot --- examples/translation-chatbot/bot.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/translation-chatbot/bot.py b/examples/translation-chatbot/bot.py index 7cd2cbd83..89ca461b1 100644 --- a/examples/translation-chatbot/bot.py +++ b/examples/translation-chatbot/bot.py @@ -3,7 +3,7 @@ import os import sys -from pipecat.frames.frames import Frame, InterimTranscriptionFrame, LLMMessagesFrame, TextFrame, TranscriptionFrame, TransportMessageFrame +from pipecat.frames.frames import Frame, LLMMessagesFrame, TextFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineTask @@ -12,7 +12,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.services.azure import AzureTTSService from pipecat.services.openai import OpenAILLMService -from pipecat.transports.services.daily import DailyParams, DailyTransport, DailyTransportMessageFrame +from pipecat.transports.services.daily import DailyParams, DailyTranscriptionSettings, DailyTransport, DailyTransportMessageFrame from runner import configure @@ -84,7 +84,9 @@ async def main(room_url: str, token): DailyParams( audio_out_enabled=True, transcription_enabled=True, - transcription_interim_results=False, + transcription_settings=DailyTranscriptionSettings(extra={ + "interim_results": False + }) ) ) From 455ec4f1fd9f9725b3f747eb59f62ca2a02fb25e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 17:11:11 -0700 Subject: [PATCH 17/25] services(tts): always send received TextFrame downstream --- .../processors/aggregators/llm_response.py | 4 ---- .../processors/aggregators/user_response.py | 4 ---- src/pipecat/services/ai_services.py | 15 +++++++++++++-- src/pipecat/services/elevenlabs.py | 4 +--- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 82288a4e9..5f23a9ed1 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ 
b/src/pipecat/processors/aggregators/llm_response.py @@ -72,8 +72,6 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, self._start_frame): self._seen_start_frame = True self._aggregating = True - - await self.push_frame(frame, direction) elif isinstance(frame, self._end_frame): self._seen_end_frame = True @@ -85,8 +83,6 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # Send the aggregation if we are not aggregating anymore (i.e. no # more interim results received). send_aggregation = not self._aggregating - - await self.push_frame(frame, direction) elif isinstance(frame, self._accumulator_frame): if self._aggregating: self._aggregation += f" {frame.text}" diff --git a/src/pipecat/processors/aggregators/user_response.py b/src/pipecat/processors/aggregators/user_response.py index ae063b3cd..5c6520f1f 100644 --- a/src/pipecat/processors/aggregators/user_response.py +++ b/src/pipecat/processors/aggregators/user_response.py @@ -89,8 +89,6 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, self._start_frame): self._seen_start_frame = True self._aggregating = True - - await self.push_frame(frame, direction) elif isinstance(frame, self._end_frame): self._seen_end_frame = True @@ -102,8 +100,6 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # Send the aggregation if we are not aggregating anymore (i.e. no # more interim results received). send_aggregation = not self._aggregating - - await self.push_frame(frame, direction) elif isinstance(frame, self._accumulator_frame): if self._aggregating: self._aggregation += f" {frame.text}" diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index a3cb5cde5..696c61a73 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -18,6 +18,8 @@ EndFrame, ErrorFrame, Frame, + TTSStartedFrame, + TTSStoppedFrame, TextFrame, VisionImageRawFrame, ) @@ -69,14 +71,20 @@ async def _process_text_frame(self, frame: TextFrame): self._current_sentence = "" if text: - await self.process_generator(self.run_tts(text)) + await self._push_tts_frames(text) + + async def _push_tts_frames(self, text: str): + await self.push_frame(TextFrame(text)) + await self.push_frame(TTSStartedFrame()) + await self.process_generator(self.run_tts(text)) + await self.push_frame(TTSStoppedFrame()) async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, TextFrame): await self._process_text_frame(frame) elif isinstance(frame, EndFrame): if self._current_sentence: - await self.process_generator(self.run_tts(self._current_sentence)) + await self._push_tts_frames(self._current_sentence) await self.push_frame(frame) else: await self.push_frame(frame, direction) @@ -154,6 +162,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): self._wave.close() await self.push_frame(frame, direction) elif isinstance(frame, AudioRawFrame): + # In this service we accumulate audio internally and at the end we + # push a TextFrame. We don't really want to push audio frames down. 
await self._append_audio(frame) else: await self.push_frame(frame, direction) @@ -171,6 +181,7 @@ async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, TextFrame): + await self.push_frame(frame, direction) await self.process_generator(self.run_image_gen(frame.text)) else: await self.push_frame(frame, direction) diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index 0bc207aef..53660a80e 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -8,7 +8,7 @@ from typing import AsyncGenerator -from pipecat.frames.frames import AudioRawFrame, ErrorFrame, Frame, TTSStartedFrame, TTSStoppedFrame +from pipecat.frames.frames import AudioRawFrame, ErrorFrame, Frame, TTSStartedFrame, TTSStoppedFrame, TextFrame from pipecat.services.ai_services import TTSService from loguru import logger @@ -53,9 +53,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: yield ErrorFrame(f"Audio fetch status code: {r.status}, error: {r.text}") return - yield TTSStartedFrame() async for chunk in r.content: if len(chunk) > 0: frame = AudioRawFrame(chunk, 16000, 1) yield frame - yield TTSStoppedFrame() From 3e13678f238810e32d06dd3ea443cfef8246834f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 17:13:31 -0700 Subject: [PATCH 18/25] vad: use exponential smoothed volume to improve speech detection --- CHANGELOG.md | 6 +++++- src/pipecat/services/ai_services.py | 6 +++--- src/pipecat/vad/vad_analyzer.py | 31 ++++++++++++++++++----------- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f80a6cf3d..69c9e9091 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,10 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `VADParams` so you can control voice confidence level and others. -- `VADAnalyzer` now uses an exponential smoothing to avoid sudden changes. +- `VADAnalyzer` now uses an exponential smoothed volume to improve speech + detection. This is useful when voice confidence is high (because there's + someone talking near you) but volume is low. ### Fixed +- Fixed an issue where TTSService was not pushing TextFrames downstream. + - Fixed issues with Ctrl-C program termination. 
- Fixed an issue that was causing `StopTaskFrame` to actually not exit the diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 696c61a73..bb697a84f 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -94,7 +94,7 @@ class STTService(AIService): """STTService is a base class for speech-to-text services.""" def __init__(self, - min_rms: int = 75, + min_rms: int = 100, max_silence_secs: float = 0.3, max_buffer_secs: float = 1.5, sample_rate: int = 16000, @@ -107,8 +107,8 @@ def __init__(self, self._num_channels = num_channels (self._content, self._wave) = self._new_wave() self._silence_num_frames = 0 - # Exponential smoothing - self._smoothing_factor = 0.08 + # Volume exponential smoothing + self._smoothing_factor = 0.5 self._prev_rms = 1 - self._smoothing_factor @abstractmethod diff --git a/src/pipecat/vad/vad_analyzer.py b/src/pipecat/vad/vad_analyzer.py index 1d68e2af6..15f036387 100644 --- a/src/pipecat/vad/vad_analyzer.py +++ b/src/pipecat/vad/vad_analyzer.py @@ -4,6 +4,9 @@ # SPDX-License-Identifier: BSD 2-Clause License # +import array +import math + from abc import abstractmethod from enum import Enum @@ -20,9 +23,10 @@ class VADState(Enum): class VADParams(BaseModel): - confidence: float = 0.5 + confidence: float = 0.6 start_secs: float = 0.2 stop_secs: float = 0.8 + min_rms: int = 1000 class VADAnalyzer: @@ -43,9 +47,9 @@ def __init__(self, sample_rate: int, num_channels: int, params: VADParams): self._vad_buffer = b"" - # Exponential smoothing - self._smoothing_factor = 0.6 - self._prev_confidence = 1 - self._smoothing_factor + # Volume exponential smoothing + self._smoothing_factor = 0.5 + self._prev_rms = 1 - self._smoothing_factor @property def sample_rate(self): @@ -59,10 +63,13 @@ def num_frames_required(self) -> int: def voice_confidence(self, buffer) -> float: pass - def _smoothed_confidence(self, audio_frames, prev_confidence, factor): - confidence = self.voice_confidence(audio_frames) - smoothed = exp_smoothing(confidence, prev_confidence, factor) - return smoothed + def _get_smoothed_volume(self, audio: bytes, prev_rms: float, factor: float) -> float: + # https://docs.python.org/3/library/array.html + audio_array = array.array('h', audio) + squares = [sample**2 for sample in audio_array] + mean = sum(squares) / len(audio_array) + rms = math.sqrt(mean) + return exp_smoothing(rms, prev_rms, factor) def analyze_audio(self, buffer) -> VADState: self._vad_buffer += buffer @@ -74,11 +81,11 @@ def analyze_audio(self, buffer) -> VADState: audio_frames = self._vad_buffer[:num_required_bytes] self._vad_buffer = self._vad_buffer[num_required_bytes:] - confidence = self._smoothed_confidence( - audio_frames, self._prev_confidence, self._smoothing_factor) - self._prev_confidence = confidence + confidence = self.voice_confidence(audio_frames) + rms = self._get_smoothed_volume(audio_frames, self._prev_rms, self._smoothing_factor) + self._prev_rms = rms - speaking = confidence >= self._params.confidence + speaking = confidence >= self._params.confidence and rms >= self._params.min_rms if speaking: match self._vad_state: From 0e8c7a9b285a44c86e5436221880a70cd2106f13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 17:46:58 -0700 Subject: [PATCH 19/25] transports(output): create a downstream push frame task --- src/pipecat/transports/base_input.py | 3 +- src/pipecat/transports/base_output.py | 63 +++++++++++++++++++-------- 2 files changed, 47 insertions(+), 19
deletions(-) diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 1aa3ceaab..c670bb757 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -39,8 +39,7 @@ def __init__(self, params: TransportParams): self._audio_in_queue = queue.Queue() # Create push frame task. This is the task that will push frames in - # order. So, a transport guarantees that all frames are pushed in the - # same task. + # order. We also guarantee that all frames are pushed in the same task. self._create_push_task() async def start(self, frame: StartFrame): diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index 9d6c567b1..a960f1534 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -71,6 +71,10 @@ async def start(self, frame: StartFrame): self._sink_thread = loop.run_in_executor(None, self._sink_thread_handler) + # Create push frame task. This is the task that will push frames in + # order. We also guarantee that all frames are pushed in the same task. + self._create_push_task() + async def stop(self): if not self._running: return @@ -101,9 +105,14 @@ async def cleanup(self): await self._sink_thread async def process_frame(self, frame: Frame, direction: FrameDirection): + # + # Out-of-band frames like (CancelFrame or StartInterruptionFrame) are + # pushed immediately. Other frames require order so they are put in the + # sink queue. + # if isinstance(frame, StartFrame): await self.start(frame) - await self.push_frame(frame, direction) + self._sink_queue.put(frame) # EndFrame is managed in the queue handler. elif isinstance(frame, CancelFrame): await self.stop() @@ -111,10 +120,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame): await self._handle_interruptions(frame) await self.push_frame(frame, direction) - elif self._frame_managed_by_sink(frame): - self._sink_queue.put(frame) else: - await self.push_frame(frame, direction) + self._sink_queue.put(frame) # If we are finishing, wait here until we have stopped, otherwise we might # close things too early upstream. We need this event because we don't @@ -122,19 +129,14 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame): await self._stopped_event.wait() - def _frame_managed_by_sink(self, frame: Frame): - return (isinstance(frame, AudioRawFrame) - or isinstance(frame, ImageRawFrame) - or isinstance(frame, SpriteFrame) - or isinstance(frame, TransportMessageFrame) - or isinstance(frame, EndFrame)) - async def _handle_interruptions(self, frame: Frame): if not self._allow_interruptions: return if isinstance(frame, StartInterruptionFrame): self._is_interrupted.set() + self._push_frame_task.cancel() + self._create_push_task() elif isinstance(frame, StopInterruptionFrame): self._is_interrupted.clear() @@ -152,12 +154,6 @@ def _sink_thread_handler(self): try: frame = self._sink_queue.get(timeout=1) - if isinstance(frame, EndFrame): - # Send all remaining audio before stopping (multiple of 10ms of audio). 
- self._send_audio_truncated(buffer, bytes_size_10ms) - future = asyncio.run_coroutine_threadsafe(self.stop(), self.get_event_loop()) - future.result() - if not self._is_interrupted.is_set(): if isinstance(frame, AudioRawFrame): if self._params.audio_out_enabled: @@ -169,17 +165,50 @@ def _sink_thread_handler(self): self._set_camera_images(frame.images) elif isinstance(frame, TransportMessageFrame): self.send_message(frame) + else: + future = asyncio.run_coroutine_threadsafe( + self._internal_push_frame(frame), self.get_event_loop()) + future.result() else: # Send any remaining audio self._send_audio_truncated(buffer, bytes_size_10ms) buffer = bytearray() + if isinstance(frame, EndFrame): + # Send all remaining audio before stopping (multiple of 10ms of audio). + self._send_audio_truncated(buffer, bytes_size_10ms) + future = asyncio.run_coroutine_threadsafe(self.stop(), self.get_event_loop()) + future.result() + self._sink_queue.task_done() except queue.Empty: pass except BaseException as e: logger.error(f"Error processing sink queue: {e}") + # + # Push frames task + # + + def _create_push_task(self): + loop = self.get_event_loop() + self._push_frame_task = loop.create_task(self._push_frame_task_handler()) + self._push_queue = asyncio.Queue() + + async def _internal_push_frame( + self, + frame: Frame | None, + direction: FrameDirection | None = FrameDirection.DOWNSTREAM): + await self._push_queue.put((frame, direction)) + + async def _push_frame_task_handler(self): + while True: + try: + (frame, direction) = await self._push_queue.get() + await self.push_frame(frame, direction) + except asyncio.CancelledError: + break + # # Camera out # From 2b8f1c4cda21d9e405ad78cc60f77e3e875395e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Fri, 17 May 2024 17:47:33 -0700 Subject: [PATCH 20/25] services(openai): send LLMResponseStartFrame for each completion --- CHANGELOG.md | 4 +++- examples/foundational/07-interruptible.py | 3 ++- src/pipecat/processors/frame_processor.py | 3 +-- src/pipecat/services/ai_services.py | 4 +++- src/pipecat/services/openai.py | 6 ++---- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 69c9e9091..a26ebfbb3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Added initial interruptions support. +- Added initial interruptions support. The assitant contexts (or aggregators) + should now be placed after the output transport. This way, only the completed + spoken context is added to the assistant context. - Added `VADParams` so you can control voice confidence level and others. 
diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py index 71f9a22ab..2349c4303 100644 --- a/examples/foundational/07-interruptible.py +++ b/examples/foundational/07-interruptible.py @@ -71,7 +71,8 @@ async def main(room_url: str, token): llm, tts, transport.output(), - tma_out]) + tma_out + ]) task = PipelineTask(pipeline, allow_interruptions=True) diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index a7d330680..a79352f7a 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -47,8 +47,7 @@ async def push_error(self, error: ErrorFrame): async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): if direction == FrameDirection.DOWNSTREAM and self._next: - if not isinstance(frame, AudioRawFrame): - logger.trace(f"Pushing {frame} from {self} to {self._next}") + logger.trace(f"Pushing {frame} from {self} to {self._next}") await self._next.process_frame(frame, direction) elif direction == FrameDirection.UPSTREAM and self._prev: logger.trace(f"Pushing {frame} upstream from {self} to {self._prev}") diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index bb697a84f..52d62b7c8 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -74,10 +74,12 @@ async def _process_text_frame(self, frame: TextFrame): await self._push_tts_frames(text) async def _push_tts_frames(self, text: str): - await self.push_frame(TextFrame(text)) await self.push_frame(TTSStartedFrame()) await self.process_generator(self.run_tts(text)) await self.push_frame(TTSStoppedFrame()) + # We send the original text after the audio. This way, if we are + # interrupted, the text is not added to the assistant context. + await self.push_frame(TextFrame(text)) async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, TextFrame): diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index b29f3aaec..3e2f2ae91 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -100,8 +100,6 @@ async def _process_context(self, context: OpenAILLMContext): function_name = "" arguments = "" - await self.push_frame(LLMResponseStartFrame()) - chunk_stream: AsyncStream[ChatCompletionChunk] = ( await self._stream_chat_completions(context) ) @@ -132,15 +130,15 @@ async def _process_context(self, context: OpenAILLMContext): # completes arguments += tool_call.function.arguments elif chunk.choices[0].delta.content: + await self.push_frame(LLMResponseStartFrame()) await self.push_frame(TextFrame(chunk.choices[0].delta.content)) + await self.push_frame(LLMResponseEndFrame()) # if we got a function name and arguments, yield the frame with all the info so # frame consumers can take action based on the function call. 
# if function_name and arguments: # yield LLMFunctionCallFrame(function_name=function_name, arguments=arguments) - await self.push_frame(LLMResponseEndFrame()) - async def process_frame(self, frame: Frame, direction: FrameDirection): context = None if isinstance(frame, OpenAILLMContextFrame): From 435fffe1b06d4433d7b6a42cb16a3253caa0054c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Sat, 18 May 2024 09:49:38 -0700 Subject: [PATCH 21/25] add LLMFullResponseStartFrame/LLMFullResponseEndFrame --- .../foundational/05-sync-speech-and-image.py | 24 +++++++++---------- examples/foundational/07-interruptible.py | 12 +++++----- src/pipecat/frames/frames.py | 14 +++++++++++ .../processors/aggregators/llm_response.py | 3 ++- src/pipecat/services/openai.py | 6 +++++ 5 files changed, 40 insertions(+), 19 deletions(-) diff --git a/examples/foundational/05-sync-speech-and-image.py b/examples/foundational/05-sync-speech-and-image.py index 747eccb63..60dd50d07 100644 --- a/examples/foundational/05-sync-speech-and-image.py +++ b/examples/foundational/05-sync-speech-and-image.py @@ -13,12 +13,12 @@ from pipecat.frames.frames import ( AppFrame, + EndFrame, Frame, ImageRawFrame, - TextFrame, - EndFrame, + LLMFullResponseStartFrame, LLMMessagesFrame, - LLMResponseStartFrame, + TextFrame ) from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner @@ -64,7 +64,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): elif self.prepend_to_next_text_frame and isinstance(frame, TextFrame): await self.push_frame(TextFrame(f"{self.most_recent_month}: {frame.text}")) self.prepend_to_next_text_frame = False - elif isinstance(frame, LLMResponseStartFrame): + elif isinstance(frame, LLMFullResponseStartFrame): self.prepend_to_next_text_frame = True await self.push_frame(frame) else: @@ -105,7 +105,7 @@ async def main(room_url): gated_aggregator = GatedAggregator( gate_open_fn=lambda frame: isinstance(frame, ImageRawFrame), - gate_close_fn=lambda frame: isinstance(frame, LLMResponseStartFrame), + gate_close_fn=lambda frame: isinstance(frame, LLMFullResponseStartFrame), start_open=False ) @@ -114,14 +114,14 @@ async def main(room_url): llm_full_response_aggregator = LLMFullResponseAggregator() pipeline = Pipeline([ - llm, - sentence_aggregator, - ParallelTask( - [month_prepender, tts], - [llm_full_response_aggregator, imagegen] + llm, # LLM + sentence_aggregator, # Aggregates LLM output into full sentences + ParallelTask( # Run pipelines in parallel aggregating the result + [month_prepender, tts], # Create "Month: sentence" and output audio + [llm_full_response_aggregator, imagegen] # Aggregate full LLM response ), - gated_aggregator, - transport.output() + gated_aggregator, # Queues everything until an image is available + transport.output() # Transport output ]) frames = [] diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py index 2349c4303..ae337f7f6 100644 --- a/examples/foundational/07-interruptible.py +++ b/examples/foundational/07-interruptible.py @@ -66,12 +66,12 @@ async def main(room_url: str, token): tma_out = LLMAssistantResponseAggregator(messages) pipeline = Pipeline([ - transport.input(), - tma_in, - llm, - tts, - transport.output(), - tma_out + transport.input(), # Transport user input + tma_in, # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + tma_out # Assistant spoken responses ]) task = PipelineTask(pipeline, 
allow_interruptions=True) diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 9d5b15b36..8eb32664c 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -260,6 +260,20 @@ class EndFrame(ControlFrame): pass +@dataclass +class LLMFullResponseStartFrame(ControlFrame): + """Used to indicate the beginning of a full LLM response. Following + LLMResponseStartFrame, TextFrame and LLMResponseEndFrame for each sentence + until a LLMFullResponseEndFrame.""" + pass + + +@dataclass +class LLMFullResponseEndFrame(ControlFrame): + """Indicates the end of a full LLM response.""" + pass + + @dataclass class LLMResponseStartFrame(ControlFrame): """Used to indicate the beginning of an LLM response. Following TextFrames diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 5f23a9ed1..3b9c07fe6 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -10,6 +10,7 @@ from pipecat.frames.frames import ( Frame, InterimTranscriptionFrame, + LLMFullResponseEndFrame, LLMMessagesFrame, LLMResponseStartFrame, TextFrame, @@ -182,7 +183,7 @@ def __init__(self): async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, TextFrame): self._aggregation += frame.text - elif isinstance(frame, LLMResponseEndFrame): + elif isinstance(frame, LLMFullResponseEndFrame): await self.push_frame(TextFrame(self._aggregation)) await self.push_frame(frame) self._aggregation = "" diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index 3e2f2ae91..56224e8fe 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -16,6 +16,8 @@ from pipecat.frames.frames import ( ErrorFrame, Frame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, LLMMessagesFrame, LLMResponseEndFrame, LLMResponseStartFrame, @@ -104,6 +106,8 @@ async def _process_context(self, context: OpenAILLMContext): await self._stream_chat_completions(context) ) + await self.push_frame(LLMFullResponseStartFrame()) + async for chunk in chunk_stream: if len(chunk.choices) == 0: continue @@ -134,6 +138,8 @@ async def _process_context(self, context: OpenAILLMContext): await self.push_frame(TextFrame(chunk.choices[0].delta.content)) await self.push_frame(LLMResponseEndFrame()) + await self.push_frame(LLMFullResponseEndFrame()) + # if we got a function name and arguments, yield the frame with all the info so # frame consumers can take action based on the function call. 
# if function_name and arguments: From 36dd4933e9102bc9604de7f7e8840855d7ab3cfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Sat, 18 May 2024 10:01:46 -0700 Subject: [PATCH 22/25] example: add assistant responses to simple chatbot --- examples/simple-chatbot/bot.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/simple-chatbot/bot.py b/examples/simple-chatbot/bot.py index c605414eb..87d1ca37f 100644 --- a/examples/simple-chatbot/bot.py +++ b/examples/simple-chatbot/bot.py @@ -8,7 +8,7 @@ from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineTask -from pipecat.processors.aggregators.llm_response import LLMUserResponseAggregator +from pipecat.processors.aggregators.llm_response import LLMAssistantResponseAggregator, LLMUserResponseAggregator from pipecat.frames.frames import ( AudioRawFrame, ImageRawFrame, @@ -135,6 +135,7 @@ async def main(room_url: str, token): ] user_response = LLMUserResponseAggregator() + assistant_response = LLMAssistantResponseAggregator() ta = TalkingAnimation() @@ -144,7 +145,8 @@ async def main(room_url: str, token): llm, tts, ta, - transport.output() + transport.output(), + assistant_response, ]) task = PipelineTask(pipeline, allow_interruptions=True) From 810dc30d3d513c5b8547e9374677463e62f5e47f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Sun, 19 May 2024 09:39:34 -0700 Subject: [PATCH 23/25] examples: fix examples to use LLMFullResponseEndFrame --- CHANGELOG.md | 2 +- examples/foundational/11-sound-effects.py | 4 ++-- examples/moondream-chatbot/bot.py | 2 +- examples/simple-chatbot/bot.py | 2 +- examples/storytelling-chatbot/src/processors.py | 10 +++++++--- src/pipecat/pipeline/parallel_pipeline.py | 2 +- 6 files changed, 13 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a26ebfbb3..10b964a30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Added initial interruptions support. The assitant contexts (or aggregators) +- Added initial interruptions support. The assistant contexts (or aggregators) should now be placed after the output transport. This way, only the completed spoken context is added to the assistant context. 
diff --git a/examples/foundational/11-sound-effects.py b/examples/foundational/11-sound-effects.py index 292a89052..c8d113c30 100644 --- a/examples/foundational/11-sound-effects.py +++ b/examples/foundational/11-sound-effects.py @@ -13,7 +13,7 @@ from pipecat.frames.frames import ( Frame, AudioRawFrame, - LLMResponseEndFrame, + LLMFullResponseEndFrame, LLMMessagesFrame, ) from pipecat.pipeline.pipeline import Pipeline @@ -59,7 +59,7 @@ class OutboundSoundEffectWrapper(FrameProcessor): async def process_frame(self, frame: Frame, direction: FrameDirection): - if isinstance(frame, LLMResponseEndFrame): + if isinstance(frame, LLMFullResponseEndFrame): await self.push_frame(sounds["ding1.wav"]) # In case anything else downstream needs it await self.push_frame(frame, direction) diff --git a/examples/moondream-chatbot/bot.py b/examples/moondream-chatbot/bot.py index b09f4345d..76948be55 100644 --- a/examples/moondream-chatbot/bot.py +++ b/examples/moondream-chatbot/bot.py @@ -66,7 +66,7 @@ class TalkingAnimation(FrameProcessor): """ This class starts a talking animation when it receives an first AudioFrame, - and then returns to a "quiet" sprite when it sees a LLMResponseEndFrame. + and then returns to a "quiet" sprite when it sees a TTSStoppedFrame. """ def __init__(self): diff --git a/examples/simple-chatbot/bot.py b/examples/simple-chatbot/bot.py index 87d1ca37f..eb15dc05a 100644 --- a/examples/simple-chatbot/bot.py +++ b/examples/simple-chatbot/bot.py @@ -56,7 +56,7 @@ class TalkingAnimation(FrameProcessor): """ This class starts a talking animation when it receives an first AudioFrame, - and then returns to a "quiet" sprite when it sees a LLMResponseEndFrame. + and then returns to a "quiet" sprite when it sees a TTSStoppedFrame. """ def __init__(self): diff --git a/examples/storytelling-chatbot/src/processors.py b/examples/storytelling-chatbot/src/processors.py index 30528af94..18428eb72 100644 --- a/examples/storytelling-chatbot/src/processors.py +++ b/examples/storytelling-chatbot/src/processors.py @@ -2,7 +2,11 @@ from async_timeout import timeout -from pipecat.frames.frames import Frame, LLMResponseEndFrame, TextFrame, UserStoppedSpeakingFrame +from pipecat.frames.frames import ( + Frame, + LLMFullResponseEndFrame, + TextFrame, + UserStoppedSpeakingFrame) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.transports.services.daily import DailyTransportMessageFrame @@ -128,9 +132,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # Clear the buffer self._text = "" - # End of LLM response + # End of a full LLM response # Driven by the prompt, the LLM should have asked the user for input - elif isinstance(frame, LLMResponseEndFrame): + elif isinstance(frame, LLMFullResponseEndFrame): # We use a different frame type, as to avoid image generation ingest await self.push_frame(StoryPromptFrame(self._text)) self._text = "" diff --git a/src/pipecat/pipeline/parallel_pipeline.py b/src/pipecat/pipeline/parallel_pipeline.py index 5dd840281..22ffdfdf2 100644 --- a/src/pipecat/pipeline/parallel_pipeline.py +++ b/src/pipecat/pipeline/parallel_pipeline.py @@ -63,7 +63,7 @@ def __init__(self, *args): if not isinstance(processors, list): raise TypeError(f"ParallelPipeline argument {processors} is not a list") - # We add a source at before the pipeline and a sink after. + # We will add a source before the pipeline and a sink after. 
source = Source(self._up_queue) sink = Sink(self._down_queue) self._sources.append(source) From c0d5054798924576d9317df465ce87b77ede84dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Sun, 19 May 2024 09:41:36 -0700 Subject: [PATCH 24/25] examples: some prompt tweaking --- examples/foundational/06-listen-and-respond.py | 2 +- examples/foundational/06a-image-sync.py | 2 +- examples/foundational/07-interruptible.py | 2 +- examples/moondream-chatbot/bot.py | 2 +- examples/simple-chatbot/bot.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/foundational/06-listen-and-respond.py b/examples/foundational/06-listen-and-respond.py index b7e0d13bc..a0475bd4b 100644 --- a/examples/foundational/06-listen-and-respond.py +++ b/examples/foundational/06-listen-and-respond.py @@ -64,7 +64,7 @@ async def main(room_url: str, token): messages = [ { "role": "system", - "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so never use special characters in your answers. Respond to what the user said in a creative and helpful way.", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", }, ] tma_in = LLMUserResponseAggregator(messages) diff --git a/examples/foundational/06a-image-sync.py b/examples/foundational/06a-image-sync.py index 30f2eea95..4c9925b20 100644 --- a/examples/foundational/06a-image-sync.py +++ b/examples/foundational/06a-image-sync.py @@ -83,7 +83,7 @@ async def main(room_url: str, token): messages = [ { "role": "system", - "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so never use special characters in your answers. Respond to what the user said in a creative and helpful way.", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", }, ] diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py index ae337f7f6..44c785d2c 100644 --- a/examples/foundational/07-interruptible.py +++ b/examples/foundational/07-interruptible.py @@ -58,7 +58,7 @@ async def main(room_url: str, token): messages = [ { "role": "system", - "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so never use special characters. Respond to what the user said in a creative and helpful way.", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. 
Respond to what the user said in a creative and helpful way.", }, ] diff --git a/examples/moondream-chatbot/bot.py b/examples/moondream-chatbot/bot.py index 76948be55..7830cf46a 100644 --- a/examples/moondream-chatbot/bot.py +++ b/examples/moondream-chatbot/bot.py @@ -162,7 +162,7 @@ async def main(room_url: str, token): messages = [ { "role": "system", - "content": f"You are Chatbot, a friendly, helpful robot. Let the user know that you are capable of chatting or describing what you see. Your goal is to demonstrate your capabilities in a succinct way. Reply with only '{user_request_answer}' if the user asks you to describe what you see. Your output will be converted to audio so never use special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself.", + "content": f"You are Chatbot, a friendly, helpful robot. Let the user know that you are capable of chatting or describing what you see. Your goal is to demonstrate your capabilities in a succinct way. Reply with only '{user_request_answer}' if the user asks you to describe what you see. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself.", }, ] diff --git a/examples/simple-chatbot/bot.py b/examples/simple-chatbot/bot.py index eb15dc05a..2c03a70f4 100644 --- a/examples/simple-chatbot/bot.py +++ b/examples/simple-chatbot/bot.py @@ -125,7 +125,7 @@ async def main(room_url: str, token): # # English # - "content": "You are Chatbot, a friendly, helpful robot. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so never use special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself.", + "content": "You are Chatbot, a friendly, helpful robot. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by introducing yourself.", # # Spanish From c3bfcbd562cdc8ce8452b335ee6c8b774664a0f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Sun, 19 May 2024 10:20:17 -0700 Subject: [PATCH 25/25] aggregators: clear accumulated responses if interruption happens --- .../processors/aggregators/llm_response.py | 26 +++++++++++-------- .../processors/aggregators/user_response.py | 26 +++++++++++-------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 3b9c07fe6..853217064 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -13,6 +13,7 @@ LLMFullResponseEndFrame, LLMMessagesFrame, LLMResponseStartFrame, + StartInterruptionFrame, TextFrame, LLMResponseEndFrame, TranscriptionFrame, @@ -40,12 +41,9 @@ def __init__( self._end_frame = end_frame self._accumulator_frame = accumulator_frame self._interim_accumulator_frame = interim_accumulator_frame - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False - self._aggregation = "" - self._aggregating = False + # Reset our accumulator state. 
+ self._reset() # # Frame processor @@ -96,6 +94,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): self._seen_interim_results = False elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): self._seen_interim_results = True + elif isinstance(frame, StartInterruptionFrame): + self._reset() + await self.push_frame(frame, direction) else: await self.push_frame(frame, direction) @@ -108,12 +109,15 @@ async def _push_aggregation(self): frame = LLMMessagesFrame(self._messages) await self.push_frame(frame) - # Reset - self._aggregation = "" - self._aggregating = False - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False + # Reset our accumulator state. + self._reset() + + def _reset(self): + self._aggregation = "" + self._aggregating = False + self._seen_start_frame = False + self._seen_end_frame = False + self._seen_interim_results = False class LLMAssistantResponseAggregator(LLMResponseAggregator): diff --git a/src/pipecat/processors/aggregators/user_response.py b/src/pipecat/processors/aggregators/user_response.py index 5c6520f1f..5b1a8e309 100644 --- a/src/pipecat/processors/aggregators/user_response.py +++ b/src/pipecat/processors/aggregators/user_response.py @@ -8,6 +8,7 @@ from pipecat.frames.frames import ( Frame, InterimTranscriptionFrame, + StartInterruptionFrame, TextFrame, TranscriptionFrame, UserStartedSpeakingFrame, @@ -56,12 +57,9 @@ def __init__( self._end_frame = end_frame self._accumulator_frame = accumulator_frame self._interim_accumulator_frame = interim_accumulator_frame - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False - self._aggregation = "" - self._aggregating = False + # Reset our accumulator state. + self._reset() # # Frame processor @@ -112,6 +110,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): self._seen_interim_results = False elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): self._seen_interim_results = True + elif isinstance(frame, StartInterruptionFrame): + self._reset() + await self.push_frame(frame, direction) else: await self.push_frame(frame, direction) @@ -122,12 +123,15 @@ async def _push_aggregation(self): if len(self._aggregation) > 0: await self.push_frame(TextFrame(self._aggregation.strip())) - # Reset - self._aggregation = "" - self._aggregating = False - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False + # Reset our accumulator state. + self._reset() + + def _reset(self): + self._aggregation = "" + self._aggregating = False + self._seen_start_frame = False + self._seen_end_frame = False + self._seen_interim_results = False class UserResponseAggregator(ResponseAggregator):
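The volume-gated VAD introduced across these patches can be exercised in isolation. The sketch below is illustrative only: exp_smoothing and the RMS computation mirror the helpers added in src/pipecat/utils/utils.py and src/pipecat/vad/vad_analyzer.py above, while the synthetic audio chunk, smoothing factor and min_rms threshold are assumed example values rather than anything taken from the patches themselves.

import array
import math


def exp_smoothing(value: float, prev_value: float, factor: float) -> float:
    # Same helper as the one added to pipecat.utils.utils above.
    return prev_value + factor * (value - prev_value)


def smoothed_volume(audio: bytes, prev_rms: float, factor: float) -> float:
    # 16-bit signed PCM samples, as assumed by VADAnalyzer._get_smoothed_volume().
    samples = array.array("h", audio)
    rms = math.sqrt(sum(s * s for s in samples) / len(samples))
    return exp_smoothing(rms, prev_rms, factor)


# Assumed example values: a constant-amplitude 10 ms chunk at 16 kHz and the
# defaults from VADParams (confidence 0.6, min_rms 1000) with a 0.5 factor.
prev_rms = 0.5
chunk = array.array("h", [2000] * 160).tobytes()
prev_rms = smoothed_volume(chunk, prev_rms, 0.5)
speaking = prev_rms >= 1000  # analyze_audio() also requires voice_confidence >= 0.6
print(f"smoothed rms={prev_rms:.1f} speaking={speaking}")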