diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index ce3e13494..b806efad4 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -27,6 +27,13 @@ jobs: uses: actions/setup-python@v4 with: python-version: "3.10" + - name: Cache virtual environment + uses: actions/cache@v3 + with: + # We are hashing dev-requirements.txt and test-requirements.txt which + # contain all dependencies needed to run the tests. + key: venv-${{ runner.os }}-${{ steps.setup_python.outputs.python-version}}-${{ hashFiles('dev-requirements.txt') }}-${{ hashFiles('test-requirements.txt') }} + path: .venv - name: Install system packages id: install_system_packages run: | diff --git a/CHANGELOG.md b/CHANGELOG.md index c1d571da5..b59ed56c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added Google TTS service and corresponding foundational example `07n-interruptible-google.py`. + +- Added AWS Polly TTS support and `07m-interruptible-aws.py` as an example. + +- Added InputParams to Azure TTS service. + +- All `FrameProcessors` can now register event handlers. + +``` +tts = SomeTTSService(...) + +@tts.event_handler("on_connected") +async def on_connected(processor): + ... +``` + +- Added `AsyncGeneratorProcessor`. This processor can be used together with a + `FrameSerializer` as an async generator. It provides a `generator()` function + that returns an `AsyncGenerator` that yields serialized frames. + +- Added `EndTaskFrame` and `CancelTaskFrame`. These are new frames that are + meant to be pushed upstream to tell the pipeline task to stop nicely or + immediately, respectively. + - Added configurable LLM parameters (e.g., temperature, top_p, max_tokens, seed) for OpenAI, Anthropic, and Together AI services along with corresponding setter functions. @@ -24,15 +48,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 frames. To achieve that, each frame processor should only output frames from a single task. - In this version we introduce synchronous and asynchronous frame - processors. The synchronous processors push output frames from the same task - that they receive input frames, and therefore only pushing frames from one - task. Asynchronous frame processors can have internal tasks to perform things - asynchronously (e.g. receiving data from a websocket) but they also have a - single task where they push frames from. - - By default, frame processors are synchronous. To change a frame processor to - asynchronous you only need to pass `sync=False` to the base class constructor. + In this version all the frame processors have their own task to push + frames. That is, when `push_frame()` is called the given frame will be put + into an internal queue (with the exception of system frames) and a frame + processor task will push it out. - Added pipeline clocks. A pipeline clock is used by the output transport to know when a frame needs to be presented. For that, all frames now have an @@ -44,9 +63,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `SystemClock`). This clock will be passed to each frame processor via the `StartFrame`. -- Added `CartesiaHttpTTSService`. This is a synchronous frame processor - (i.e. given an input text frame it will wait for the whole output before - returning). +- Added `CartesiaHttpTTSService`.
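As a rough illustration of the new task frames (the processor name and trigger below are illustrative, not part of this release), a custom processor can ask the pipeline task to finish gracefully by pushing an `EndTaskFrame` upstream:

```
from pipecat.frames.frames import EndTaskFrame, Frame, TranscriptionFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor


class EndOnGoodbye(FrameProcessor):
    """Illustrative processor: asks the pipeline task to finish nicely."""

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)
        await self.push_frame(frame, direction)
        # EndTaskFrame is a system frame pushed upstream; the pipeline task
        # reacts by pushing an EndFrame downstream, flushing queued frames.
        if isinstance(frame, TranscriptionFrame) and "goodbye" in frame.text.lower():
            await self.push_frame(EndTaskFrame(), FrameDirection.UPSTREAM)
```

`CancelTaskFrame` works the same way but results in a `CancelFrame`, stopping the pipeline immediately instead of draining it.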
- `DailyTransport` now supports setting the audio bitrate to improve audio quality through the `DailyParams.audio_out_bitrate` parameter. The new @@ -69,6 +86,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Updated individual update settings frame classes into a single UpdateSettingsFrame + class for STT, LLM, and TTS. + - We now distinguish between input and output audio and image frames. We introduce `InputAudioRawFrame`, `OutputAudioRawFrame`, `InputImageRawFrame` and `OutputImageRawFrame` (and other subclasses of those). The input frames @@ -83,8 +103,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 pipelines to be executed concurrently. The difference between a `SyncParallelPipeline` and a `ParallelPipeline` is that, given an input frame, the `SyncParallelPipeline` will wait for all the internal pipelines to - complete. This is achieved by ensuring all the processors in each of the - internal pipelines are synchronous. + complete. This is achieved by making sure the last processor in each of the + pipelines is synchronous (e.g. an HTTP-based service that waits for the + response). - `StartFrame` is back a system frame so we make sure it's processed immediately by all processors. `EndFrame` stays a control frame since it needs to be diff --git a/README.md b/README.md index faf0137dc..793d1f630 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ pip install "pipecat-ai[option,...]" Your project may or may not need these, so they're made available as optional requirements. Here is a list: -- **AI services**: `anthropic`, `azure`, `deepgram`, `gladia`, `google`, `fal`, `lmnt`, `moondream`, `openai`, `openpipe`, `playht`, `silero`, `whisper`, `xtts` +- **AI services**: `anthropic`, `aws`, `azure`, `deepgram`, `gladia`, `google`, `fal`, `lmnt`, `moondream`, `openai`, `openpipe`, `playht`, `silero`, `whisper`, `xtts` - **Transports**: `local`, `websocket`, `daily` ## Code examples @@ -110,7 +110,6 @@ python app.py Daily provides a prebuilt WebRTC user interface. Whilst the app is running, you can visit at `https://.daily.co/` and listen to the bot say hello! - ## WebRTC for production use WebSockets are fine for server-to-server communication or for initial development. But for production use, you’ll need client-server audio to use a protocol designed for real-time media transport. (For an explanation of the difference between WebSockets and WebRTC, see [this post.](https://www.daily.co/blog/how-to-talk-to-an-llm-with-your-voice/#webrtc)) @@ -131,7 +130,6 @@ pip install pipecat-ai[silero] The first time your run your bot with Silero, startup may take a while whilst it downloads and caches the model in the background. You can check the progress of this in the console. - ## Hacking on the framework itself _Note that you may need to set up a virtual environment before following the instructions below. For instance, you might need to run the following from the root of the repo:_ diff --git a/dot-env.template b/dot-env.template index 085e8b19d..e940b1076 100644 --- a/dot-env.template +++ b/dot-env.template @@ -1,6 +1,11 @@ # Anthropic ANTHROPIC_API_KEY=... +# AWS +AWS_SECRET_ACCESS_KEY=... +AWS_ACCESS_KEY_ID=... +AWS_REGION=... + # Azure AZURE_SPEECH_REGION=... AZURE_SPEECH_API_KEY=... 
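Relatedly, the consolidated settings frames noted in the Changed section above can be queued onto a running task. A minimal sketch (helper name and values are illustrative), assuming a `PipelineTask` is already running:

```
from pipecat.frames.frames import TTSUpdateSettingsFrame
from pipecat.pipeline.task import PipelineTask


async def switch_voice(task: PipelineTask, voice_id: str):
    # One consolidated frame replaces the old per-setting frames
    # (e.g. TTSVoiceUpdateFrame); unset fields simply stay None.
    await task.queue_frame(TTSUpdateSettingsFrame(voice=voice_id, rate="1.05"))
```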
diff --git a/examples/foundational/05-sync-speech-and-image.py b/examples/foundational/05-sync-speech-and-image.py index dae860a92..5477d0691 100644 --- a/examples/foundational/05-sync-speech-and-image.py +++ b/examples/foundational/05-sync-speech-and-image.py @@ -86,13 +86,13 @@ async def main(): ), ) + llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") + tts = CartesiaHttpTTSService( api_key=os.getenv("CARTESIA_API_KEY"), voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady ) - llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") - imagegen = FalImageGenService( params=FalImageGenService.InputParams(image_size="square_hd"), aiohttp_session=session, @@ -107,8 +107,10 @@ async def main(): # that, each pipeline runs concurrently and `SyncParallelPipeline` will # wait for the input frame to be processed. # - # Note that `SyncParallelPipeline` requires all processors in it to be - # synchronous (which is the default for most processors). + # Note that `SyncParallelPipeline` requires the last processor in each + # of the pipelines to be synchronous. In this case, we use + # `CartesiaHttpTTSService` and `FalImageGenService` which make HTTP + # requests and wait for the response. pipeline = Pipeline( [ llm, # LLM diff --git a/examples/foundational/05a-local-sync-speech-and-image.py b/examples/foundational/05a-local-sync-speech-and-image.py index 27c36f6ce..4a561c073 100644 --- a/examples/foundational/05a-local-sync-speech-and-image.py +++ b/examples/foundational/05a-local-sync-speech-and-image.py @@ -82,6 +82,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): self.frame = OutputAudioRawFrame( bytes(self.audio), frame.sample_rate, frame.num_channels ) + await self.push_frame(frame, direction) class ImageGrabber(FrameProcessor): def __init__(self): @@ -93,6 +94,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): if isinstance(frame, URLImageRawFrame): self.frame = frame + await self.push_frame(frame, direction) llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") @@ -121,8 +123,10 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # `SyncParallelPipeline` will wait for the input frame to be # processed. # - # Note that `SyncParallelPipeline` requires all processors in it to - # be synchronous (which is the default for most processors). + # Note that `SyncParallelPipeline` requires the last processor in + # each of the pipelines to be synchronous. In this case, we use + # `CartesiaHttpTTSService` and `FalImageGenService` which make HTTP + # requests and wait for the response. 
pipeline = Pipeline( [ llm, # LLM diff --git a/examples/foundational/07a-interruptible-anthropic.py b/examples/foundational/07a-interruptible-anthropic.py index 2bded2480..288cb1b31 100644 --- a/examples/foundational/07a-interruptible-anthropic.py +++ b/examples/foundational/07a-interruptible-anthropic.py @@ -5,29 +5,24 @@ # import asyncio -import aiohttp import os import sys +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + from pipecat.frames.frames import LLMMessagesFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantResponseAggregator, - LLMUserResponseAggregator, -) -from pipecat.services.cartesia import CartesiaTTSService +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.services.anthropic import AnthropicLLMService +from pipecat.services.cartesia import CartesiaTTSService from pipecat.transports.services.daily import DailyParams, DailyTransport from pipecat.vad.silero import SileroVADAnalyzer -from runner import configure - -from loguru import logger - -from dotenv import load_dotenv - load_dotenv(override=True) logger.remove(0) @@ -69,17 +64,17 @@ async def main(): }, ] - tma_in = LLMUserResponseAggregator(messages) - tma_out = LLMAssistantResponseAggregator(messages) + context = OpenAILLMContext(messages) + context_aggregator = llm.create_context_aggregator(context) pipeline = Pipeline( [ transport.input(), # Transport user input - tma_in, # User responses + context_aggregator.user(), # User responses llm, # LLM tts, # TTS transport.output(), # Transport bot output - tma_out, # Assistant spoken responses + context_aggregator.assistant(), # Assistant spoken responses ] ) diff --git a/examples/foundational/07c-interruptible-deepgram.py b/examples/foundational/07c-interruptible-deepgram.py index 41bef8a47..fc33c246f 100644 --- a/examples/foundational/07c-interruptible-deepgram.py +++ b/examples/foundational/07c-interruptible-deepgram.py @@ -5,10 +5,14 @@ # import asyncio -import aiohttp import os import sys +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + from pipecat.frames.frames import LLMMessagesFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner @@ -22,12 +26,6 @@ from pipecat.transports.services.daily import DailyParams, DailyTransport from pipecat.vad.silero import SileroVADAnalyzer -from runner import configure - -from loguru import logger - -from dotenv import load_dotenv - load_dotenv(override=True) logger.remove(0) @@ -52,9 +50,7 @@ async def main(): stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) - tts = DeepgramTTSService( - aiohttp_session=session, api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en" - ) + tts = DeepgramTTSService(api_key=os.getenv("DEEPGRAM_API_KEY"), voice="aura-helios-en") llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") diff --git a/examples/foundational/07l-interruptible-together.py b/examples/foundational/07l-interruptible-together.py index e2cb55fed..ca3386718 100644 --- a/examples/foundational/07l-interruptible-together.py +++ b/examples/foundational/07l-interruptible-together.py @@ -5,29 +5,24 @@ # import asyncio -import aiohttp import os import sys +import aiohttp +from dotenv import load_dotenv +from 
loguru import logger +from runner import configure + from pipecat.frames.frames import LLMMessagesFrame from pipecat.pipeline.pipeline import Pipeline from pipecat.pipeline.runner import PipelineRunner from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantResponseAggregator, - LLMUserResponseAggregator, -) +from pipecat.services.ai_services import OpenAILLMContext from pipecat.services.cartesia import CartesiaTTSService from pipecat.services.together import TogetherLLMService from pipecat.transports.services.daily import DailyParams, DailyTransport from pipecat.vad.silero import SileroVADAnalyzer -from runner import configure - -from loguru import logger - -from dotenv import load_dotenv - load_dotenv(override=True) logger.remove(0) @@ -76,17 +71,19 @@ async def main(): }, ] - tma_in = LLMUserResponseAggregator(messages) - tma_out = LLMAssistantResponseAggregator(messages) + context = OpenAILLMContext(messages) + context_aggregator = llm.create_context_aggregator(context) + user_aggregator = context_aggregator.user() + assistant_aggregator = context_aggregator.assistant() pipeline = Pipeline( [ transport.input(), # Transport user input - tma_in, # User responses + user_aggregator, # User responses llm, # LLM tts, # TTS transport.output(), # Transport bot output - tma_out, # Assistant spoken responses + assistant_aggregator, # Assistant spoken responses ] ) diff --git a/examples/foundational/07m-interruptible-aws.py b/examples/foundational/07m-interruptible-aws.py new file mode 100644 index 000000000..69d4b84c1 --- /dev/null +++ b/examples/foundational/07m-interruptible-aws.py @@ -0,0 +1,102 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + +from pipecat.frames.frames import LLMMessagesFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantResponseAggregator, + LLMUserResponseAggregator, +) +from pipecat.services.aws import AWSTTSService +from pipecat.services.deepgram import DeepgramSTTService +from pipecat.services.openai import OpenAILLMService +from pipecat.transports.services.daily import DailyParams, DailyTransport +from pipecat.vad.silero import SileroVADAnalyzer + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, token) = await configure(session) + + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + audio_out_sample_rate=16000, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + vad_audio_passthrough=True, + ), + ) + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = AWSTTSService( + api_key=os.getenv("AWS_SECRET_ACCESS_KEY"), + aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), + region=os.getenv("AWS_REGION"), + voice_id="Amy", + params=AWSTTSService.InputParams(engine="neural", language="en-GB", rate="1.05"), + ) + + llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. 
Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", + }, + ] + + tma_in = LLMUserResponseAggregator(messages) + tma_out = LLMAssistantResponseAggregator(messages) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, # STT + tma_in, # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + tma_out, # Assistant spoken responses + ] + ) + + task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + transport.capture_participant_transcription(participant["id"]) + # Kick off the conversation. + messages.append({"role": "system", "content": "Please introduce yourself to the user."}) + await task.queue_frames([LLMMessagesFrame(messages)]) + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/foundational/07n-interruptible-google.py b/examples/foundational/07n-interruptible-google.py new file mode 100644 index 000000000..713b3dce3 --- /dev/null +++ b/examples/foundational/07n-interruptible-google.py @@ -0,0 +1,100 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + +from pipecat.frames.frames import LLMMessagesFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantResponseAggregator, + LLMUserResponseAggregator, +) +from pipecat.services.deepgram import DeepgramSTTService +from pipecat.services.google import GoogleTTSService +from pipecat.services.openai import OpenAILLMService +from pipecat.transports.services.daily import DailyParams, DailyTransport +from pipecat.vad.silero import SileroVADAnalyzer + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, token) = await configure(session) + + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + audio_out_sample_rate=24000, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + vad_audio_passthrough=True, + ), + ) + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = GoogleTTSService( + credentials=os.getenv("GOOGLE_CREDENTIALS"), + voice_id="en-US-Neural2-J", + params=GoogleTTSService.InputParams(language="en-US", rate="1.05"), + ) + + llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. 
Respond to what the user said in a creative and helpful way.", + }, + ] + + tma_in = LLMUserResponseAggregator(messages) + tma_out = LLMAssistantResponseAggregator(messages) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, # STT + tma_in, # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + tma_out, # Assistant spoken responses + ] + ) + + task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + transport.capture_participant_transcription(participant["id"]) + # Kick off the conversation. + messages.append({"role": "system", "content": "Please introduce yourself to the user."}) + await task.queue_frames([LLMMessagesFrame(messages)]) + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/foundational/12c-describe-video-anthropic.py b/examples/foundational/12c-describe-video-anthropic.py index 7458adf69..c7267467a 100644 --- a/examples/foundational/12c-describe-video-anthropic.py +++ b/examples/foundational/12c-describe-video-anthropic.py @@ -78,7 +78,9 @@ async def main(): tts = CartesiaTTSService( api_key=os.getenv("CARTESIA_API_KEY"), voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady - sample_rate=16000, + params=CartesiaTTSService.InputParams( + sample_rate=16000, + ), ) @transport.event_handler("on_first_participant_joined") diff --git a/examples/foundational/14-function-calling.py b/examples/foundational/14-function-calling.py index b5aba449c..9141029ca 100644 --- a/examples/foundational/14-function-calling.py +++ b/examples/foundational/14-function-calling.py @@ -34,7 +34,12 @@ async def start_fetch_weather(function_name, llm, context): - await llm.push_frame(TextFrame("Let me check on that.")) + # note: we can't push a frame to the LLM here. the bot + # can interrupt itself and/or cause audio overlapping glitches. + # possible question for Aleix and Chad about what the right way + # to trigger speech is, now, with the new queues/async/sync refactors. 
+ # await llm.push_frame(TextFrame("Let me check on that.")) + logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}") async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback): @@ -106,11 +111,11 @@ async def main(): pipeline = Pipeline( [ - fl_in, + # fl_in, transport.input(), context_aggregator.user(), llm, - fl_out, + # fl_out, tts, transport.output(), context_aggregator.assistant(), diff --git a/examples/studypal/studypal.py b/examples/studypal/studypal.py index 2364c65cf..58d5eb2f5 100644 --- a/examples/studypal/studypal.py +++ b/examples/studypal/studypal.py @@ -131,7 +131,9 @@ async def main(): api_key=os.getenv("CARTESIA_API_KEY"), voice_id=os.getenv("CARTESIA_VOICE_ID", "4d2fd738-3b3d-4368-957a-bb4805275bd9"), # British Narration Lady: 4d2fd738-3b3d-4368-957a-bb4805275bd9 - sample_rate=44100, + params=CartesiaTTSService.InputParams( + sample_rate=44100, + ), ) llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o-mini") diff --git a/pyproject.toml b/pyproject.toml index 943153d2d..a29697bdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,15 +35,16 @@ Website = "https://pipecat.ai" [project.optional-dependencies] anthropic = [ "anthropic~=0.34.0" ] +aws = [ "boto3~=1.35.27" ] azure = [ "azure-cognitiveservices-speech~=1.40.0" ] cartesia = [ "cartesia~=1.0.13", "websockets~=12.0" ] -daily = [ "daily-python~=0.10.1" ] +daily = [ "daily-python~=0.11.0" ] deepgram = [ "deepgram-sdk~=3.5.0" ] elevenlabs = [ "websockets~=12.0" ] examples = [ "python-dotenv~=1.0.1", "flask~=3.0.3", "flask_cors~=4.0.1" ] fal = [ "fal-client~=0.4.1" ] gladia = [ "websockets~=12.0" ] -google = [ "google-generativeai~=0.7.2" ] +google = [ "google-generativeai~=0.7.2", "google-cloud-texttospeech~=2.17.2" ] gstreamer = [ "pygobject~=3.48.2" ] fireworks = [ "openai~=1.37.2" ] langchain = [ "langchain~=0.2.14", "langchain-community~=0.2.12", "langchain-openai~=0.1.20" ] @@ -56,7 +57,7 @@ openpipe = [ "openpipe~=4.24.0" ] playht = [ "pyht~=0.0.28" ] silero = [ "onnxruntime>=1.16.1" ] together = [ "together~=1.2.7" ] -websocket = [ "websockets~=12.0", "fastapi~=0.112.1" ] +websocket = [ "websockets~=12.0", "fastapi~=0.115.0" ] whisper = [ "faster-whisper~=1.0.3" ] xtts = [ "resampy~=0.4.3" ] diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index e4495098b..f7faa8ef0 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -4,9 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from typing import Any, List, Optional, Tuple - from dataclasses import dataclass, field +from typing import Any, List, Optional, Tuple, Union from pipecat.clocks.base_clock import BaseClock from pipecat.metrics.metrics import MetricsData @@ -132,9 +131,7 @@ class VisionImageRawFrame(InputImageRawFrame): def __str__(self): pts = format_pts(self.pts) - return ( - f"{self.name}(pts: {pts}, text: {self.text}, size: {self.size}, format: {self.format})" - ) + return f"{self.name}(pts: {pts}, text: [{self.text}], size: {self.size}, format: {self.format})" @dataclass @@ -177,7 +174,7 @@ class TextFrame(DataFrame): def __str__(self): pts = format_pts(self.pts) - return f"{self.name}(pts: {pts}, text: {self.text})" + return f"{self.name}(pts: {pts}, text: [{self.text}])" @dataclass @@ -192,7 +189,7 @@ class TranscriptionFrame(TextFrame): language: Language | None = None def __str__(self): - return f"{self.name}(user: {self.user_id}, text: {self.text}, language: {self.language}, timestamp: 
{self.timestamp})" + return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})" @dataclass @@ -205,7 +202,7 @@ class InterimTranscriptionFrame(TextFrame): language: Language | None = None def __str__(self): - return f"{self.name}(user: {self.user_id}, text: {self.text}, language: {self.language}, timestamp: {self.timestamp})" + return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})" @dataclass @@ -341,6 +338,27 @@ class FatalErrorFrame(ErrorFrame): fatal: bool = field(default=True, init=False) +@dataclass +class EndTaskFrame(SystemFrame): + """This is used to notify the pipeline task that the pipeline should be + closed nicely (flushing all the queued frames) by pushing an EndFrame + downstream. + + """ + + pass + + +@dataclass +class CancelTaskFrame(SystemFrame): + """This is used to notify the pipeline task that the pipeline should be + stopped immediately by pushing a CancelFrame downstream. + + """ + + pass + + @dataclass class StopTaskFrame(SystemFrame): """Indicates that a pipeline task should be stopped but that the pipeline @@ -509,113 +527,45 @@ def __str__(self): @dataclass -class LLMModelUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM model.""" - - model: str - - -@dataclass -class LLMTemperatureUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM temperature.""" +class LLMUpdateSettingsFrame(ControlFrame): + """A control frame containing a request to update LLM settings.""" - temperature: float + model: Optional[str] = None + temperature: Optional[float] = None + top_k: Optional[int] = None + top_p: Optional[float] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + max_tokens: Optional[int] = None + seed: Optional[int] = None + extra: dict = field(default_factory=dict) @dataclass -class LLMTopKUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM top_k.""" - - top_k: int - - -@dataclass -class LLMTopPUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM top_p.""" - - top_p: float - - -@dataclass -class LLMFrequencyPenaltyUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM frequency - penalty. - - """ - - frequency_penalty: float - - -@dataclass -class LLMPresencePenaltyUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM presence - penalty. - - """ - - presence_penalty: float - - -@dataclass -class LLMMaxTokensUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM max tokens.""" - - max_tokens: int - - -@dataclass -class LLMSeedUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM seed.""" - - seed: int - - -@dataclass -class LLMExtraUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new LLM extra params.""" - - extra: dict - - -@dataclass -class TTSModelUpdateFrame(ControlFrame): - """A control frame containing a request to update the TTS model.""" - - model: str - - -@dataclass -class TTSVoiceUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new TTS voice.""" - - voice: str - - -@dataclass -class TTSLanguageUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new TTS language and - optional voice. 
- - """ - - language: Language - - -@dataclass -class STTModelUpdateFrame(ControlFrame): - """A control frame containing a request to update the STT model and optional - language. - - """ +class TTSUpdateSettingsFrame(ControlFrame): + """A control frame containing a request to update TTS settings.""" - model: str + model: Optional[str] = None + voice: Optional[str] = None + language: Optional[Language] = None + speed: Optional[Union[str, float]] = None + emotion: Optional[List[str]] = None + engine: Optional[str] = None + pitch: Optional[str] = None + rate: Optional[str] = None + volume: Optional[str] = None + emphasis: Optional[str] = None + style: Optional[str] = None + style_degree: Optional[str] = None + role: Optional[str] = None @dataclass -class STTLanguageUpdateFrame(ControlFrame): - """A control frame containing a request to update to STT language.""" +class STTUpdateSettingsFrame(ControlFrame): + """A control frame containing a request to update STT settings.""" - language: Language + model: Optional[str] = None + language: Optional[Language] = None @dataclass @@ -635,6 +585,7 @@ class FunctionCallResultFrame(DataFrame): tool_call_id: str arguments: str result: Any + run_llm: bool = True @dataclass diff --git a/src/pipecat/pipeline/sync_parallel_pipeline.py b/src/pipecat/pipeline/sync_parallel_pipeline.py index 854cea89d..20f4275e4 100644 --- a/src/pipecat/pipeline/sync_parallel_pipeline.py +++ b/src/pipecat/pipeline/sync_parallel_pipeline.py @@ -6,17 +6,25 @@ import asyncio +from dataclasses import dataclass from itertools import chain from typing import List +from pipecat.frames.frames import ControlFrame, EndFrame, Frame, SystemFrame from pipecat.pipeline.base_pipeline import BasePipeline from pipecat.pipeline.pipeline import Pipeline from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.frames.frames import Frame from loguru import logger +@dataclass +class SyncFrame(ControlFrame): + """This frame is used to know when the internal pipelines have finished.""" + + pass + + class Source(FrameProcessor): def __init__(self, upstream_queue: asyncio.Queue): super().__init__() @@ -67,13 +75,16 @@ def __init__(self, *args): raise TypeError(f"SyncParallelPipeline argument {processors} is not a list") # We add a source at the beginning of the pipeline and a sink at the end. - source = Source(self._up_queue) - sink = Sink(self._down_queue) + up_queue = asyncio.Queue() + down_queue = asyncio.Queue() + source = Source(up_queue) + sink = Sink(down_queue) processors: List[FrameProcessor] = [source] + processors + [sink] - # Keep track of sources and sinks. - self._sources.append(source) - self._sinks.append(sink) + # Keep track of sources and sinks. We also keep the output queue of + # the source and the sinks so we can use it later. + self._sources.append({"processor": source, "queue": down_queue}) + self._sinks.append({"processor": sink, "queue": up_queue}) # Create pipeline pipeline = Pipeline(processors) @@ -94,17 +105,52 @@ def processors_with_metrics(self) -> List[FrameProcessor]: async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) + # The last processor of each pipeline needs to be synchronous otherwise + # this element won't work. Since, we know it should be synchronous we + # push a SyncFrame. 
Since frames are ordered we know this frame will be + # pushed after the synchronous processor has pushed its data allowing us + # to synchrnonize all the internal pipelines by waiting for the + # SyncFrame in all of them. + async def wait_for_sync( + obj, main_queue: asyncio.Queue, frame: Frame, direction: FrameDirection + ): + processor = obj["processor"] + queue = obj["queue"] + + await processor.process_frame(frame, direction) + + if isinstance(frame, (SystemFrame, EndFrame)): + new_frame = await queue.get() + if isinstance(new_frame, (SystemFrame, EndFrame)): + await main_queue.put(new_frame) + else: + while not isinstance(new_frame, (SystemFrame, EndFrame)): + await main_queue.put(new_frame) + queue.task_done() + new_frame = await queue.get() + else: + await processor.process_frame(SyncFrame(), direction) + new_frame = await queue.get() + while not isinstance(new_frame, SyncFrame): + await main_queue.put(new_frame) + queue.task_done() + new_frame = await queue.get() + if direction == FrameDirection.UPSTREAM: # If we get an upstream frame we process it in each sink. - await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sinks]) + await asyncio.gather( + *[wait_for_sync(s, self._up_queue, frame, direction) for s in self._sinks] + ) elif direction == FrameDirection.DOWNSTREAM: # If we get a downstream frame we process it in each source. - await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sources]) + await asyncio.gather( + *[wait_for_sync(s, self._down_queue, frame, direction) for s in self._sources] + ) seen_ids = set() while not self._up_queue.empty(): frame = await self._up_queue.get() - if frame and frame.id not in seen_ids: + if frame.id not in seen_ids: await self.push_frame(frame, FrameDirection.UPSTREAM) seen_ids.add(frame.id) self._up_queue.task_done() @@ -112,7 +158,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): seen_ids = set() while not self._down_queue.empty(): frame = await self._down_queue.get() - if frame and frame.id not in seen_ids: + if frame.id not in seen_ids: await self.push_frame(frame, FrameDirection.DOWNSTREAM) seen_ids.add(frame.id) self._down_queue.task_done() diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 2b46c47c2..96845430d 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -14,7 +14,9 @@ from pipecat.clocks.system_clock import SystemClock from pipecat.frames.frames import ( CancelFrame, + CancelTaskFrame, EndFrame, + EndTaskFrame, ErrorFrame, Frame, MetricsFrame, @@ -52,7 +54,13 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self.push_frame(frame, direction) async def _handle_upstream_frame(self, frame: Frame): - if isinstance(frame, ErrorFrame): + if isinstance(frame, EndTaskFrame): + # Tell the task we should end nicely. + await self._up_queue.put(EndTaskFrame()) + elif isinstance(frame, CancelTaskFrame): + # Tell the task we should end right away. + await self._up_queue.put(CancelTaskFrame()) + elif isinstance(frame, ErrorFrame): logger.error(f"Error running app: {frame}") if frame.fatal: # Cancel all tasks downstream. 
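To make the "last processor must be synchronous" requirement concrete, here is a minimal sketch of wiring a `SyncParallelPipeline` (the processor variables are assumed to be constructed as in `05-sync-speech-and-image.py`; the branch contents are illustrative):

```
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.sync_parallel_pipeline import SyncParallelPipeline


def build_pipeline(transport, llm, tts, imagegen):
    # Each positional argument is one branch (a list of processors). The last
    # processor of each branch should be synchronous (e.g. an HTTP-based
    # service that waits for its response) so the trailing SyncFrame marks
    # the end of that branch's output for a given input frame.
    return Pipeline(
        [
            transport.input(),
            llm,
            SyncParallelPipeline([tts], [imagegen]),
            transport.output(),
        ]
    )
```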
@@ -61,6 +69,19 @@ async def _handle_upstream_frame(self, frame: Frame): await self._up_queue.put(StopTaskFrame()) +class Sink(FrameProcessor): + def __init__(self, down_queue: asyncio.Queue): + super().__init__() + self._down_queue = down_queue + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + # We really just want to know when the EndFrame reached the sink. + if isinstance(frame, EndFrame): + await self._down_queue.put(frame) + + class PipelineTask: def __init__( self, @@ -76,12 +97,16 @@ def __init__( self._params = params self._finished = False - self._down_queue = asyncio.Queue() self._up_queue = asyncio.Queue() + self._down_queue = asyncio.Queue() + self._push_queue = asyncio.Queue() self._source = Source(self._up_queue) self._source.link(pipeline) + self._sink = Sink(self._down_queue) + pipeline.link(self._sink) + def has_finished(self): return self._finished @@ -95,19 +120,19 @@ async def cancel(self): # out-of-band from the main streaming task which is what we want since # we want to cancel right away. await self._source.push_frame(CancelFrame()) - self._process_down_task.cancel() + self._process_push_task.cancel() self._process_up_task.cancel() - await self._process_down_task + await self._process_push_task await self._process_up_task async def run(self): self._process_up_task = asyncio.create_task(self._process_up_queue()) - self._process_down_task = asyncio.create_task(self._process_down_queue()) - await asyncio.gather(self._process_up_task, self._process_down_task) + self._process_push_task = asyncio.create_task(self._process_push_queue()) + await asyncio.gather(self._process_up_task, self._process_push_task) self._finished = True async def queue_frame(self, frame: Frame): - await self._down_queue.put(frame) + await self._push_queue.put(frame) async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]): if isinstance(frames, AsyncIterable): @@ -125,7 +150,7 @@ def _initial_metrics_frame(self) -> MetricsFrame: data.append(ProcessingMetricsData(processor=p.name, value=0.0)) return MetricsFrame(data=data) - async def _process_down_queue(self): + async def _process_push_queue(self): self._clock.start() start_frame = StartFrame( @@ -146,11 +171,13 @@ async def _process_down_queue(self): should_cleanup = True while running: try: - frame = await self._down_queue.get() + frame = await self._push_queue.get() await self._source.process_frame(frame, FrameDirection.DOWNSTREAM) + if isinstance(frame, EndFrame): + await self._wait_for_endframe() running = not (isinstance(frame, StopTaskFrame) or isinstance(frame, EndFrame)) should_cleanup = not isinstance(frame, StopTaskFrame) - self._down_queue.task_done() + self._push_queue.task_done() except asyncio.CancelledError: break # Cleanup only if we need to. @@ -161,11 +188,21 @@ async def _process_down_queue(self): self._process_up_task.cancel() await self._process_up_task + async def _wait_for_endframe(self): + # NOTE(aleix): the Sink element just pushes EndFrames to the down queue, + # so just wait for it. In the future we might do something else here, + # but for now this is fine. 
+ await self._down_queue.get() + async def _process_up_queue(self): while True: try: frame = await self._up_queue.get() - if isinstance(frame, StopTaskFrame): + if isinstance(frame, EndTaskFrame): + await self.queue_frame(EndFrame()) + elif isinstance(frame, CancelTaskFrame): + await self.queue_frame(CancelFrame()) + elif isinstance(frame, StopTaskFrame): await self.queue_frame(StopTaskFrame()) self._up_queue.task_done() except asyncio.CancelledError: diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 036f5fe47..479746471 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -6,12 +6,6 @@ from typing import List, Type -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContextFrame, - OpenAILLMContext, -) - -from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.frames.frames import ( Frame, InterimTranscriptionFrame, @@ -22,11 +16,16 @@ LLMMessagesUpdateFrame, LLMSetToolsFrame, StartInterruptionFrame, - TranscriptionFrame, TextFrame, + TranscriptionFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame, ) +from pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, + OpenAILLMContextFrame, +) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor class LLMResponseAggregator(FrameProcessor): @@ -40,6 +39,7 @@ def __init__( accumulator_frame: Type[TextFrame], interim_accumulator_frame: Type[TextFrame] | None = None, handle_interruptions: bool = False, + expect_stripped_words: bool = True, # if True, need to add spaces between words ): super().__init__() @@ -50,6 +50,7 @@ def __init__( self._accumulator_frame = accumulator_frame self._interim_accumulator_frame = interim_accumulator_frame self._handle_interruptions = handle_interruptions + self._expect_stripped_words = expect_stripped_words # Reset our accumulator state. self._reset() @@ -111,7 +112,10 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self.push_frame(frame, direction) elif isinstance(frame, self._accumulator_frame): if self._aggregating: - self._aggregation += f" {frame.text}" if self._aggregation else frame.text + if self._expect_stripped_words: + self._aggregation += f" {frame.text}" if self._aggregation else frame.text + else: + self._aggregation += frame.text # We have recevied a complete sentence, so if we have seen the # end frame and we were still aggregating, it means we should # send the aggregation. 
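Relatedly, a small sketch of the new `expect_stripped_words` flag introduced above (constructing the aggregator directly here is only for illustration; the examples obtain it via `create_context_aggregator()`):

```
from pipecat.processors.aggregators.llm_response import LLMAssistantContextAggregator
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext

context = OpenAILLMContext(messages=[])

# Default (True): chunks arrive stripped ("Hello", "world") and are joined
# with spaces. Use False for services that emit text with its own spacing,
# so chunks are concatenated as-is.
assistant_aggregator = LLMAssistantContextAggregator(context, expect_stripped_words=False)
```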
@@ -290,7 +294,7 @@ async def _push_aggregation(self): class LLMAssistantContextAggregator(LLMContextAggregator): - def __init__(self, context: OpenAILLMContext): + def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True): super().__init__( messages=[], context=context, @@ -299,6 +303,7 @@ def __init__(self, context: OpenAILLMContext): end_frame=LLMFullResponseEndFrame, accumulator_frame=TextFrame, handle_interruptions=True, + expect_stripped_words=expect_stripped_words, ) diff --git a/src/pipecat/processors/aggregators/openai_llm_context.py b/src/pipecat/processors/aggregators/openai_llm_context.py index 83ec3e57f..4bf3f042c 100644 --- a/src/pipecat/processors/aggregators/openai_llm_context.py +++ b/src/pipecat/processors/aggregators/openai_llm_context.py @@ -133,6 +133,7 @@ async def call_function( tool_call_id: str, arguments: str, llm: FrameProcessor, + run_llm: bool = True, ) -> None: # Push a SystemFrame downstream. This frame will let our assistant context aggregator # know that we are in the middle of a function call. Some contexts/aggregators may @@ -153,6 +154,7 @@ async def function_call_result_callback(result): tool_call_id=tool_call_id, arguments=arguments, result=result, + run_llm=run_llm, ) ) diff --git a/src/pipecat/processors/async_generator.py b/src/pipecat/processors/async_generator.py new file mode 100644 index 000000000..4f9bc85d0 --- /dev/null +++ b/src/pipecat/processors/async_generator.py @@ -0,0 +1,44 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio + +from typing import Any, AsyncGenerator + +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + Frame, +) +from pipecat.processors.frame_processor import FrameProcessor, FrameDirection +from pipecat.serializers.base_serializer import FrameSerializer + + +class AsyncGeneratorProcessor(FrameProcessor): + def __init__(self, *, serializer: FrameSerializer, **kwargs): + super().__init__(**kwargs) + self._serializer = serializer + self._data_queue = asyncio.Queue() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, (CancelFrame, EndFrame)): + await self._data_queue.put(None) + else: + data = self._serializer.serialize(frame) + if data: + await self._data_queue.put(data) + + async def generator(self) -> AsyncGenerator[Any, None]: + running = True + while running: + data = await self._data_queue.get() + running = data is not None + if data: + yield data diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 1bf42311d..f458f43ff 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -5,7 +5,7 @@ # import asyncio -import time +import inspect from enum import Enum @@ -37,7 +37,6 @@ def __init__( *, name: str | None = None, metrics: FrameProcessorMetrics | None = None, - sync: bool = True, loop: asyncio.AbstractEventLoop | None = None, **kwargs, ): @@ -47,7 +46,8 @@ def __init__( self._prev: "FrameProcessor" | None = None self._next: "FrameProcessor" | None = None self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_running_loop() - self._sync = sync + + self._event_handlers: dict = {} # Clock self._clock: BaseClock | None = None @@ -64,11 +64,8 @@ def __init__( # Every processor in Pipecat should only output frames from a single # task. This avoid problems like audio overlapping. 
System frames are - # the exception to this rule. - # - # This create this task. - if not self._sync: - self.__create_push_task() + # the exception to this rule. This create this task. + self.__create_push_task() @property def interruptions_allowed(self): @@ -165,23 +162,39 @@ async def push_error(self, error: ErrorFrame): await self.push_frame(error, FrameDirection.UPSTREAM) async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): - if self._sync or isinstance(frame, SystemFrame): + if isinstance(frame, SystemFrame): await self.__internal_push_frame(frame, direction) else: await self.__push_queue.put((frame, direction)) + def event_handler(self, event_name: str): + def decorator(handler): + self.add_event_handler(event_name, handler) + return handler + + return decorator + + def add_event_handler(self, event_name: str, handler): + if event_name not in self._event_handlers: + raise Exception(f"Event handler {event_name} not registered") + self._event_handlers[event_name].append(handler) + + def _register_event_handler(self, event_name: str): + if event_name in self._event_handlers: + raise Exception(f"Event handler {event_name} already registered") + self._event_handlers[event_name] = [] + # # Handle interruptions # async def _start_interruption(self): - if not self._sync: - # Cancel the task. This will stop pushing frames downstream. - self.__push_frame_task.cancel() - await self.__push_frame_task + # Cancel the task. This will stop pushing frames downstream. + self.__push_frame_task.cancel() + await self.__push_frame_task - # Create a new queue and task. - self.__create_push_task() + # Create a new queue and task. + self.__create_push_task() async def _stop_interruption(self): # Nothing to do right now. @@ -213,5 +226,15 @@ async def __push_frame_task_handler(self): except asyncio.CancelledError: break + async def _call_event_handler(self, event_name: str, *args, **kwargs): + try: + for handler in self._event_handlers[event_name]: + if inspect.iscoroutinefunction(handler): + await handler(self, *args, **kwargs) + else: + handler(self, *args, **kwargs) + except Exception as e: + logger.exception(f"Exception in event handler {event_name}: {e}") + def __str__(self): return self.name diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index 820ea716c..7a6054c3c 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -5,6 +5,7 @@ # import asyncio +import base64 from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field, PrivateAttr, ValidationError @@ -20,8 +21,14 @@ ErrorFrame, Frame, InterimTranscriptionFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + OutputAudioRawFrame, StartFrame, SystemFrame, + TTSStartedFrame, + TTSStoppedFrame, + TextFrame, TranscriptionFrame, TransportMessageFrame, UserStartedSpeakingFrame, @@ -34,7 +41,7 @@ from loguru import logger -RTVI_PROTOCOL_VERSION = "0.1" +RTVI_PROTOCOL_VERSION = "0.2" ActionResult = Union[bool, int, float, str, list, dict] @@ -242,17 +249,75 @@ class RTVILLMFunctionCallResultData(BaseModel): result: dict | str -class RTVITranscriptionMessageData(BaseModel): +class RTVIBotLLMStartedMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-llm-started"] = "bot-llm-started" + + +class RTVIBotLLMStoppedMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-llm-stopped"] = 
"bot-llm-stopped" + + +class RTVIBotTTSStartedMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-tts-started"] = "bot-tts-started" + + +class RTVIBotTTSStoppedMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-tts-stopped"] = "bot-tts-stopped" + + +class RTVITextMessageData(BaseModel): + text: str + + +class RTVIBotLLMTextMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-llm-text"] = "bot-llm-text" + data: RTVITextMessageData + + +class RTVIBotTTSTextMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-tts-text"] = "bot-tts-text" + data: RTVITextMessageData + + +class RTVIAudioMessageData(BaseModel): + audio: str + sample_rate: int + num_channels: int + + +class RTVIBotAudioMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-audio"] = "bot-audio" + data: RTVIAudioMessageData + + +class RTVIBotTranscriptionMessageData(BaseModel): + text: str + + +class RTVIBotTranscriptionMessage(BaseModel): + label: Literal["rtvi-ai"] = "rtvi-ai" + type: Literal["bot-transcription"] = "bot-transcription" + data: RTVIBotTranscriptionMessageData + + +class RTVIUserTranscriptionMessageData(BaseModel): text: str user_id: str timestamp: str final: bool -class RTVITranscriptionMessage(BaseModel): +class RTVIUserTranscriptionMessage(BaseModel): label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["user-transcription"] = "user-transcription" - data: RTVITranscriptionMessageData + data: RTVIUserTranscriptionMessageData class RTVIUserStartedSpeakingMessage(BaseModel): @@ -279,6 +344,170 @@ class RTVIProcessorParams(BaseModel): send_bot_ready: bool = True +class RTVIFrameProcessor(FrameProcessor): + def __init__(self, direction: FrameDirection = FrameDirection.DOWNSTREAM, **kwargs): + super().__init__(**kwargs) + self._direction = direction + + async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True): + frame = TransportMessageFrame( + message=model.model_dump(exclude_none=exclude_none), urgent=True + ) + await self.push_frame(frame, self._direction) + + +class RTVISpeakingProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)): + await self._handle_interruptions(frame) + elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)): + await self._handle_bot_speaking(frame) + + async def _handle_interruptions(self, frame: Frame): + message = None + if isinstance(frame, UserStartedSpeakingFrame): + message = RTVIUserStartedSpeakingMessage() + elif isinstance(frame, UserStoppedSpeakingFrame): + message = RTVIUserStoppedSpeakingMessage() + + if message: + await self._push_transport_message(message) + + async def _handle_bot_speaking(self, frame: Frame): + message = None + if isinstance(frame, BotStartedSpeakingFrame): + message = RTVIBotStartedSpeakingMessage() + elif isinstance(frame, BotStoppedSpeakingFrame): + message = RTVIBotStoppedSpeakingMessage() + + if message: + await self._push_transport_message(message) + + +class RTVIUserTranscriptionProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) 
+ + await self.push_frame(frame, direction) + + if isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)): + await self._handle_user_transcriptions(frame) + + async def _handle_user_transcriptions(self, frame: Frame): + message = None + if isinstance(frame, TranscriptionFrame): + message = RTVIUserTranscriptionMessage( + data=RTVIUserTranscriptionMessageData( + text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=True + ) + ) + elif isinstance(frame, InterimTranscriptionFrame): + message = RTVIUserTranscriptionMessage( + data=RTVIUserTranscriptionMessageData( + text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=False + ) + ) + + if message: + await self._push_transport_message(message) + + +class RTVIBotLLMProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, LLMFullResponseStartFrame): + await self._push_transport_message(RTVIBotLLMStartedMessage()) + elif isinstance(frame, LLMFullResponseEndFrame): + await self._push_transport_message(RTVIBotLLMStoppedMessage()) + + +class RTVIBotTTSProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, TTSStartedFrame): + await self._push_transport_message(RTVIBotTTSStartedMessage()) + elif isinstance(frame, TTSStoppedFrame): + await self._push_transport_message(RTVIBotTTSStoppedMessage()) + + +class RTVIBotLLMTextProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, TextFrame): + await self._handle_text(frame) + + async def _handle_text(self, frame: TextFrame): + message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text)) + await self._push_transport_message(message) + + +class RTVIBotTTSTextProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, TextFrame): + await self._handle_text(frame) + + async def _handle_text(self, frame: TextFrame): + message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text)) + await self._push_transport_message(message) + + +class RTVIBotAudioProcessor(RTVIFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + await self.push_frame(frame, direction) + + if isinstance(frame, OutputAudioRawFrame): + await self._handle_audio(frame) + + async def _handle_audio(self, frame: OutputAudioRawFrame): + encoded = base64.b64encode(frame.audio).decode("utf-8") + message = RTVIBotAudioMessage( + data=RTVIAudioMessageData( + audio=encoded, sample_rate=frame.sample_rate, num_channels=frame.num_channels + ) + ) + await self._push_transport_message(message) + + class RTVIProcessor(FrameProcessor): def 
__init__( self, @@ -287,7 +516,7 @@ def __init__( params: RTVIProcessorParams = RTVIProcessorParams(), **kwargs, ): - super().__init__(sync=False, **kwargs) + super().__init__(**kwargs) self._config = config self._params = params @@ -300,9 +529,16 @@ def __init__( self._registered_actions: Dict[str, RTVIAction] = {} self._registered_services: Dict[str, RTVIService] = {} + # A task to process incoming action frames. + self._action_task = self.get_event_loop().create_task(self._action_task_handler()) + self._action_queue = asyncio.Queue() + + # A task to process incoming transport messages. self._message_task = self.get_event_loop().create_task(self._message_task_handler()) self._message_queue = asyncio.Queue() + self._register_event_handler("on_bot_ready") + def register_action(self, action: RTVIAction): id = self._action_id(action.service, action.action) self._registered_actions[id] = action @@ -322,6 +558,9 @@ async def set_client_ready(self): self._client_ready = True await self._maybe_send_bot_ready() + async def handle_message(self, message: RTVIMessage): + await self._message_queue.put(message) + async def handle_function_call( self, function_name: str, @@ -368,24 +607,11 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # finish and the task finishes when EndFrame is processed. await self.push_frame(frame, direction) await self._stop(frame) - elif isinstance(frame, UserStartedSpeakingFrame) or isinstance( - frame, UserStoppedSpeakingFrame - ): - await self._handle_interruptions(frame) - await self.push_frame(frame, direction) - elif isinstance(frame, BotStartedSpeakingFrame) or isinstance( - frame, BotStoppedSpeakingFrame - ): - await self._handle_bot_speaking(frame) - await self.push_frame(frame, direction) # Data frames - elif isinstance(frame, TranscriptionFrame) or isinstance(frame, InterimTranscriptionFrame): - await self._handle_transcriptions(frame) - await self.push_frame(frame, direction) elif isinstance(frame, TransportMessageFrame): - await self._message_queue.put(frame) + await self._handle_transport_message(frame) elif isinstance(frame, RTVIActionFrame): - await self._handle_action(frame.message_id, frame.rtvi_action_run) + await self._action_queue.put(frame) # Other frames else: await self.push_frame(frame, direction) @@ -399,14 +625,26 @@ async def _start(self, frame: StartFrame): await self._maybe_send_bot_ready() async def _stop(self, frame: EndFrame): - # We need to cancel the message task handler because that one is not - # processing EndFrames. 
- self._message_task.cancel() - await self._message_task + if self._action_task: + self._action_task.cancel() + await self._action_task + self._action_task = None + + if self._message_task: + self._message_task.cancel() + await self._message_task + self._message_task = None async def _cancel(self, frame: CancelFrame): - self._message_task.cancel() - await self._message_task + if self._action_task: + self._action_task.cancel() + await self._action_task + self._action_task = None + + if self._message_task: + self._message_task.cancel() + await self._message_task + self._message_task = None async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True): frame = TransportMessageFrame( @@ -414,64 +652,33 @@ async def _push_transport_message(self, model: BaseModel, exclude_none: bool = T ) await self.push_frame(frame) - async def _handle_transcriptions(self, frame: Frame): - # TODO(aleix): Once we add support for using custom pipelines, the STTs will - # be in the pipeline after this processor. - - message = None - if isinstance(frame, TranscriptionFrame): - message = RTVITranscriptionMessage( - data=RTVITranscriptionMessageData( - text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=True - ) - ) - elif isinstance(frame, InterimTranscriptionFrame): - message = RTVITranscriptionMessage( - data=RTVITranscriptionMessageData( - text=frame.text, user_id=frame.user_id, timestamp=frame.timestamp, final=False - ) - ) - - if message: - await self._push_transport_message(message) - - async def _handle_interruptions(self, frame: Frame): - message = None - if isinstance(frame, UserStartedSpeakingFrame): - message = RTVIUserStartedSpeakingMessage() - elif isinstance(frame, UserStoppedSpeakingFrame): - message = RTVIUserStoppedSpeakingMessage() - - if message: - await self._push_transport_message(message) - - async def _handle_bot_speaking(self, frame: Frame): - message = None - if isinstance(frame, BotStartedSpeakingFrame): - message = RTVIBotStartedSpeakingMessage() - elif isinstance(frame, BotStoppedSpeakingFrame): - message = RTVIBotStoppedSpeakingMessage() - - if message: - await self._push_transport_message(message) + async def _action_task_handler(self): + while True: + try: + frame = await self._action_queue.get() + await self._handle_action(frame.message_id, frame.rtvi_action_run) + self._action_queue.task_done() + except asyncio.CancelledError: + break async def _message_task_handler(self): while True: try: - frame = await self._message_queue.get() - await self._handle_message(frame) + message = await self._message_queue.get() + await self._handle_message(message) self._message_queue.task_done() except asyncio.CancelledError: break - async def _handle_message(self, frame: TransportMessageFrame): + async def _handle_transport_message(self, frame: TransportMessageFrame): try: message = RTVIMessage.model_validate(frame.message) + await self._message_queue.put(message) except ValidationError as e: - await self.send_error(f"Invalid incoming message: {e}") - logger.warning(f"Invalid incoming message: {e}") - return + await self.send_error(f"Invalid RTVI transport message: {e}") + logger.warning(f"Invalid RTVI transport message: {e}") + async def _handle_message(self, message: RTVIMessage): try: match message.type: case "client-ready": @@ -487,7 +694,8 @@ async def _handle_message(self, frame: TransportMessageFrame): await self._handle_update_config(message.id, update_config) case "action": action = RTVIActionRun.model_validate(message.data) - await 
self._handle_action(message.id, action) + action_frame = RTVIActionFrame(message_id=message.id, rtvi_action_run=action) + await self._action_queue.put(action_frame) case "llm-function-call-result": data = RTVILLMFunctionCallResultData.model_validate(message.data) await self._handle_function_call_result(data) @@ -496,8 +704,8 @@ async def _handle_message(self, frame: TransportMessageFrame): await self._send_error_response(message.id, f"Unsupported type {message.type}") except ValidationError as e: - await self._send_error_response(message.id, f"Invalid incoming message: {e}") - logger.warning(f"Invalid incoming message: {e}") + await self._send_error_response(message.id, f"Invalid message: {e}") + logger.warning(f"Invalid message: {e}") except Exception as e: await self._send_error_response(message.id, f"Exception processing message: {e}") logger.warning(f"Exception processing message: {e}") @@ -577,8 +785,9 @@ async def _handle_action(self, request_id: str | None, data: RTVIActionRun): async def _maybe_send_bot_ready(self): if self._pipeline_started and self._client_ready: - await self._send_bot_ready() await self._update_config(self._config, False) + await self._send_bot_ready() + await self._call_event_handler("on_bot_ready") async def _send_bot_ready(self): if not self._params.send_bot_ready: diff --git a/src/pipecat/processors/gstreamer/pipeline_source.py b/src/pipecat/processors/gstreamer/pipeline_source.py index 9f8471153..426eab50a 100644 --- a/src/pipecat/processors/gstreamer/pipeline_source.py +++ b/src/pipecat/processors/gstreamer/pipeline_source.py @@ -44,7 +44,7 @@ class OutputParams(BaseModel): clock_sync: bool = True def __init__(self, *, pipeline: str, out_params: OutputParams = OutputParams(), **kwargs): - super().__init__(sync=False, **kwargs) + super().__init__(**kwargs) self._out_params = out_params diff --git a/src/pipecat/processors/idle_frame_processor.py b/src/pipecat/processors/idle_frame_processor.py index 576cb9087..e674b6b84 100644 --- a/src/pipecat/processors/idle_frame_processor.py +++ b/src/pipecat/processors/idle_frame_processor.py @@ -26,7 +26,7 @@ def __init__( types: List[type] = [], **kwargs, ): - super().__init__(sync=False, **kwargs) + super().__init__(**kwargs) self._callback = callback self._timeout = timeout diff --git a/src/pipecat/processors/user_idle_processor.py b/src/pipecat/processors/user_idle_processor.py index 31d49cf5a..507dcb495 100644 --- a/src/pipecat/processors/user_idle_processor.py +++ b/src/pipecat/processors/user_idle_processor.py @@ -31,7 +31,7 @@ def __init__( timeout: float, **kwargs, ): - super().__init__(sync=False, **kwargs) + super().__init__(**kwargs) self._callback = callback self._timeout = timeout diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index cdad3de52..0089a152e 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -7,9 +7,10 @@ import asyncio import io import wave - from abc import abstractmethod -from typing import AsyncGenerator, List, Optional, Tuple +from typing import AsyncGenerator, List, Optional, Tuple, Union + +from loguru import logger from pipecat.frames.frames import ( AudioRawFrame, @@ -18,31 +19,26 @@ ErrorFrame, Frame, LLMFullResponseEndFrame, - STTLanguageUpdateFrame, - STTModelUpdateFrame, StartFrame, StartInterruptionFrame, + STTUpdateSettingsFrame, + TextFrame, TTSAudioRawFrame, - TTSLanguageUpdateFrame, - TTSModelUpdateFrame, TTSSpeakFrame, TTSStartedFrame, TTSStoppedFrame, - TTSVoiceUpdateFrame, - TextFrame, + 
TTSUpdateSettingsFrame, UserImageRequestFrame, VisionImageRawFrame, ) from pipecat.metrics.metrics import MetricsData +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.transcriptions.language import Language from pipecat.utils.audio import calculate_audio_volume from pipecat.utils.string import match_endofsentence from pipecat.utils.time import seconds_to_nanoseconds from pipecat.utils.utils import exp_smoothing -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext - -from loguru import logger class AIService(FrameProcessor): @@ -114,7 +110,13 @@ def has_function(self, function_name: str): return function_name in self._callbacks.keys() async def call_function( - self, *, context: OpenAILLMContext, tool_call_id: str, function_name: str, arguments: str + self, + *, + context: OpenAILLMContext, + tool_call_id: str, + function_name: str, + arguments: str, + run_llm: bool, ) -> None: f = None if function_name in self._callbacks.keys(): @@ -124,7 +126,12 @@ async def call_function( else: return None await context.call_function( - f, function_name=function_name, tool_call_id=tool_call_id, arguments=arguments, llm=self + f, + function_name=function_name, + tool_call_id=tool_call_id, + arguments=arguments, + llm=self, + run_llm=run_llm, ) # QUESTION FOR CB: maybe this isn't needed anymore? @@ -148,6 +155,10 @@ def __init__( # if True, TTSService will push TextFrames and LLMFullResponseEndFrames, # otherwise subclass must do it push_text_frames: bool = True, + # if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it + push_stop_frames: bool = False, + # if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame + stop_frame_timeout_s: float = 1.0, # TTS output sample rate sample_rate: int = 16000, **kwargs, @@ -155,9 +166,15 @@ def __init__( super().__init__(**kwargs) self._aggregate_sentences: bool = aggregate_sentences self._push_text_frames: bool = push_text_frames - self._current_sentence: str = "" + self._push_stop_frames: bool = push_stop_frames + self._stop_frame_timeout_s: float = stop_frame_timeout_s self._sample_rate: int = sample_rate + self._stop_frame_task: Optional[asyncio.Task] = None + self._stop_frame_queue: asyncio.Queue = asyncio.Queue() + + self._current_sentence: str = "" + @property def sample_rate(self) -> int: return self._sample_rate @@ -174,91 +191,55 @@ async def set_voice(self, voice: str): async def set_language(self, language: Language): pass - # Converts the text to audio. 
@abstractmethod - async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + async def set_speed(self, speed: Union[str, float]): pass - async def say(self, text: str): - await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM) - - async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection): - self._current_sentence = "" - await self.push_frame(frame, direction) + @abstractmethod + async def set_emotion(self, emotion: List[str]): + pass - async def _process_text_frame(self, frame: TextFrame): - text: str | None = None - if not self._aggregate_sentences: - text = frame.text - else: - self._current_sentence += frame.text - if match_endofsentence(self._current_sentence): - text = self._current_sentence - self._current_sentence = "" + @abstractmethod + async def set_engine(self, engine: str): + pass - if text: - await self._push_tts_frames(text) + @abstractmethod + async def set_pitch(self, pitch: str): + pass - async def _push_tts_frames(self, text: str): - text = text.strip() - if not text: - return + @abstractmethod + async def set_rate(self, rate: str): + pass - await self.start_processing_metrics() - await self.process_generator(self.run_tts(text)) - await self.stop_processing_metrics() - if self._push_text_frames: - # We send the original text after the audio. This way, if we are - # interrupted, the text is not added to the assistant context. - await self.push_frame(TextFrame(text)) + @abstractmethod + async def set_volume(self, volume: str): + pass - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) + @abstractmethod + async def set_emphasis(self, emphasis: str): + pass - if isinstance(frame, TextFrame): - await self._process_text_frame(frame) - elif isinstance(frame, StartInterruptionFrame): - await self._handle_interruption(frame, direction) - elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame): - sentence = self._current_sentence - self._current_sentence = "" - await self._push_tts_frames(sentence) - if isinstance(frame, LLMFullResponseEndFrame): - if self._push_text_frames: - await self.push_frame(frame, direction) - else: - await self.push_frame(frame, direction) - elif isinstance(frame, TTSSpeakFrame): - await self._push_tts_frames(frame.text) - elif isinstance(frame, TTSModelUpdateFrame): - await self.set_model(frame.model) - elif isinstance(frame, TTSVoiceUpdateFrame): - await self.set_voice(frame.voice) - elif isinstance(frame, TTSLanguageUpdateFrame): - await self.set_language(frame.language) - else: - await self.push_frame(frame, direction) + @abstractmethod + async def set_style(self, style: str): + pass + @abstractmethod + async def set_style_degree(self, style_degree: str): + pass -class AsyncTTSService(TTSService): - def __init__( - self, - # if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it - push_stop_frames: bool = False, - # if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame - stop_frame_timeout_s: float = 1.0, - **kwargs, - ): - super().__init__(sync=False, **kwargs) - self._push_stop_frames: bool = push_stop_frames - self._stop_frame_timeout_s: float = stop_frame_timeout_s - self._stop_frame_task: Optional[asyncio.Task] = None - self._stop_frame_queue: asyncio.Queue = asyncio.Queue() + @abstractmethod + async def set_role(self, role: str): + pass @abstractmethod async def flush_audio(self): pass + # Converts the text to audio. 
+ @abstractmethod + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + pass + async def start(self, frame: StartFrame): await super().start(frame) if self._push_stop_frames: @@ -278,6 +259,37 @@ async def cancel(self, frame: CancelFrame): await self._stop_frame_task self._stop_frame_task = None + async def say(self, text: str): + aggregate_sentences = self._aggregate_sentences + self._aggregate_sentences = False + await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM) + self._aggregate_sentences = aggregate_sentences + await self.flush_audio() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, TextFrame): + await self._process_text_frame(frame) + elif isinstance(frame, StartInterruptionFrame): + await self._handle_interruption(frame, direction) + elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame): + sentence = self._current_sentence + self._current_sentence = "" + await self._push_tts_frames(sentence) + if isinstance(frame, LLMFullResponseEndFrame): + if self._push_text_frames: + await self.push_frame(frame, direction) + else: + await self.push_frame(frame, direction) + elif isinstance(frame, TTSSpeakFrame): + await self._push_tts_frames(frame.text) + await self.flush_audio() + elif isinstance(frame, TTSUpdateSettingsFrame): + await self._update_tts_settings(frame) + else: + await self.push_frame(frame, direction) + async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): await super().push_frame(frame, direction) @@ -289,6 +301,66 @@ async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirect ): await self._stop_frame_queue.put(frame) + async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection): + self._current_sentence = "" + await self.push_frame(frame, direction) + + async def _process_text_frame(self, frame: TextFrame): + text: str | None = None + if not self._aggregate_sentences: + text = frame.text + else: + self._current_sentence += frame.text + eos_end_marker = match_endofsentence(self._current_sentence) + if eos_end_marker: + text = self._current_sentence[:eos_end_marker] + self._current_sentence = self._current_sentence[eos_end_marker:] + + if text: + await self._push_tts_frames(text) + + async def _push_tts_frames(self, text: str): + # Don't send only whitespace. This causes problems for some TTS models. But also don't + # strip all whitespace, as whitespace can influence prosody. + if not text.strip(): + return + + await self.start_processing_metrics() + await self.process_generator(self.run_tts(text)) + await self.stop_processing_metrics() + if self._push_text_frames: + # We send the original text after the audio. This way, if we are + # interrupted, the text is not added to the assistant context. 
+ await self.push_frame(TextFrame(text)) + + async def _update_tts_settings(self, frame: TTSUpdateSettingsFrame): + if frame.model is not None: + await self.set_model(frame.model) + if frame.voice is not None: + await self.set_voice(frame.voice) + if frame.language is not None: + await self.set_language(frame.language) + if frame.speed is not None: + await self.set_speed(frame.speed) + if frame.emotion is not None: + await self.set_emotion(frame.emotion) + if frame.engine is not None: + await self.set_engine(frame.engine) + if frame.pitch is not None: + await self.set_pitch(frame.pitch) + if frame.rate is not None: + await self.set_rate(frame.rate) + if frame.volume is not None: + await self.set_volume(frame.volume) + if frame.emphasis is not None: + await self.set_emphasis(frame.emphasis) + if frame.style is not None: + await self.set_style(frame.style) + if frame.style_degree is not None: + await self.set_style_degree(frame.style_degree) + if frame.role is not None: + await self.set_role(frame.role) + async def _stop_frame_handler(self): try: has_started = False @@ -309,7 +381,7 @@ async def _stop_frame_handler(self): pass -class AsyncWordTTSService(AsyncTTSService): +class WordTTSService(TTSService): def __init__(self, **kwargs): super().__init__(**kwargs) self._initial_word_timestamp = -1 @@ -350,6 +422,7 @@ async def _stop_words_task(self): if self._words_task: self._words_task.cancel() await self._words_task + self._words_task = None async def _words_task_handler(self): while True: @@ -387,6 +460,12 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: """Returns transcript as a string""" pass + async def _update_stt_settings(self, frame: STTUpdateSettingsFrame): + if frame.model is not None: + await self.set_model(frame.model) + if frame.language is not None: + await self.set_language(frame.language) + async def process_audio_frame(self, frame: AudioRawFrame): await self.process_generator(self.run_stt(frame.audio)) @@ -398,10 +477,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # In this service we accumulate audio internally and at the end we # push a TextFrame. We don't really want to push audio frames down. 
await self.process_audio_frame(frame) - elif isinstance(frame, STTModelUpdateFrame): - await self.set_model(frame.model) - elif isinstance(frame, STTLanguageUpdateFrame): - await self.set_language(frame.language) + elif isinstance(frame, STTUpdateSettingsFrame): + await self._update_stt_settings(frame) else: await self.push_frame(frame, direction) diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index 8b8e187ea..86e1e3726 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -5,47 +5,47 @@ # import base64 -import json -import io import copy -from typing import Any, Dict, List, Optional +import io +import json +import re +from asyncio import CancelledError from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from loguru import logger from PIL import Image -from asyncio import CancelledError -import re from pydantic import BaseModel, Field from pipecat.frames.frames import ( Frame, + FunctionCallInProgressFrame, + FunctionCallResultFrame, LLMEnablePromptCachingFrame, - LLMModelUpdateFrame, - TextFrame, - VisionImageRawFrame, - UserImageRequestFrame, - UserImageRawFrame, - LLMMessagesFrame, - LLMFullResponseStartFrame, LLMFullResponseEndFrame, - FunctionCallResultFrame, - FunctionCallInProgressFrame, + LLMFullResponseStartFrame, + LLMMessagesFrame, + LLMUpdateSettingsFrame, StartInterruptionFrame, + TextFrame, + UserImageRawFrame, + UserImageRequestFrame, + VisionImageRawFrame, ) from pipecat.metrics.metrics import LLMTokenUsage -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services import LLMService +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantContextAggregator, + LLMUserContextAggregator, +) from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) -from pipecat.processors.aggregators.llm_response import ( - LLMUserContextAggregator, - LLMAssistantContextAggregator, -) - -from loguru import logger +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import LLMService try: - from anthropic import AsyncAnthropic, NOT_GIVEN, NotGiven + from anthropic import NOT_GIVEN, AsyncAnthropic, NotGiven except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error( @@ -110,9 +110,13 @@ def enable_prompt_caching_beta(self) -> bool: return self._enable_prompt_caching_beta @staticmethod - def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair: + def create_context_aggregator( + context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True + ) -> AnthropicContextAggregatorPair: user = AnthropicUserContextAggregator(context) - assistant = AnthropicAssistantContextAggregator(user) + assistant = AnthropicAssistantContextAggregator( + user, expect_stripped_words=assistant_expect_stripped_words + ) return AnthropicContextAggregatorPair(_user=user, _assistant=assistant) async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool): @@ -279,6 +283,21 @@ async def _process_context(self, context: OpenAILLMContext): cache_read_input_tokens=cache_read_input_tokens, ) + async def _update_settings(self, frame: LLMUpdateSettingsFrame): + if frame.model is not None: + logger.debug(f"Switching LLM model to: [{frame.model}]") + self.set_model_name(frame.model) + if frame.max_tokens is not None: + await self.set_max_tokens(frame.max_tokens) + if frame.temperature is not None: + await 
self.set_temperature(frame.temperature)
+        if frame.top_k is not None:
+            await self.set_top_k(frame.top_k)
+        if frame.top_p is not None:
+            await self.set_top_p(frame.top_p)
+        if frame.extra:
+            await self.set_extra(frame.extra)
+
     async def process_frame(self, frame: Frame, direction: FrameDirection):
         await super().process_frame(frame, direction)
 
@@ -293,9 +312,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
                 # UserImageRawFrames coming through the pipeline and add them
                 # to the context.
                 context = AnthropicLLMContext.from_image_frame(frame)
-        elif isinstance(frame, LLMModelUpdateFrame):
-            logger.debug(f"Switching LLM model to: [{frame.model}]")
-            self.set_model_name(frame.model)
+        elif isinstance(frame, LLMUpdateSettingsFrame):
+            await self._update_settings(frame)
         elif isinstance(frame, LLMEnablePromptCachingFrame):
             logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
             self._enable_prompt_caching_beta = frame.enable
@@ -527,8 +545,8 @@ async def process_frame(self, frame, direction):
 
 
 class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
-    def __init__(self, user_context_aggregator: AnthropicUserContextAggregator):
-        super().__init__(context=user_context_aggregator._context)
+    def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs):
+        super().__init__(context=user_context_aggregator._context, **kwargs)
         self._user_context_aggregator = user_context_aggregator
         self._function_call_in_progress = None
         self._function_call_result = None
@@ -565,7 +583,7 @@ async def _push_aggregation(self):
         run_llm = False
 
         aggregation = self._aggregation
-        self._aggregation = ""
+        self._reset()
 
         try:
             if self._function_call_result:
@@ -616,5 +634,8 @@ async def _push_aggregation(self):
                 if run_llm:
                     await self._user_context_aggregator.push_context_frame()
+            frame = OpenAILLMContextFrame(self._context)
+            await self.push_frame(frame)
+
 
         except Exception as e:
             logger.error(f"Error processing frame: {e}")
diff --git a/src/pipecat/services/aws.py b/src/pipecat/services/aws.py
new file mode 100644
index 000000000..80240985f
--- /dev/null
+++ b/src/pipecat/services/aws.py
@@ -0,0 +1,174 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from typing import AsyncGenerator, Optional
+
+from pydantic import BaseModel
+
+from pipecat.frames.frames import (
+    ErrorFrame,
+    Frame,
+    TTSAudioRawFrame,
+    TTSStartedFrame,
+    TTSStoppedFrame,
+)
+from pipecat.services.ai_services import TTSService
+
+from loguru import logger
+
+try:
+    import boto3
+    from botocore.exceptions import BotoCoreError, ClientError
+except ModuleNotFoundError as e:
+    logger.error(f"Exception: {e}")
+    logger.error(
+        "In order to use AWS services, you need to `pip install pipecat-ai[aws]`. Also, set the `AWS_SECRET_ACCESS_KEY`, `AWS_ACCESS_KEY_ID`, and `AWS_REGION` environment variables."
+    )
+    raise Exception(f"Missing module: {e}")
+
+
+class AWSTTSService(TTSService):
+    class InputParams(BaseModel):
+        engine: Optional[str] = None
+        language: Optional[str] = None
+        pitch: Optional[str] = None
+        rate: Optional[str] = None
+        volume: Optional[str] = None
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        aws_access_key_id: str,
+        region: str,
+        voice_id: str = "Joanna",
+        sample_rate: int = 16000,
+        params: InputParams = InputParams(),
+        **kwargs,
+    ):
+        super().__init__(sample_rate=sample_rate, **kwargs)
+
+        self._polly_client = boto3.client(
+            "polly",
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=api_key,
+            region_name=region,
+        )
+        self._voice_id = voice_id
+        self._sample_rate = sample_rate
+        self._params = params
+
+    def can_generate_metrics(self) -> bool:
+        return True
+
+    def _construct_ssml(self, text: str) -> str:
+        ssml = "<speak>"
+
+        if self._params.language:
+            ssml += f"<lang xml:lang='{self._params.language}'>"
+
+        prosody_attrs = []
+        # Prosody tags are only supported for standard and neural engines
+        if self._params.engine != "generative":
+            if self._params.rate:
+                prosody_attrs.append(f"rate='{self._params.rate}'")
+            if self._params.pitch:
+                prosody_attrs.append(f"pitch='{self._params.pitch}'")
+            if self._params.volume:
+                prosody_attrs.append(f"volume='{self._params.volume}'")
+
+            if prosody_attrs:
+                ssml += f"<prosody {' '.join(prosody_attrs)}>"
+        else:
+            logger.warning("Prosody tags are not supported for generative engine. Ignoring.")
+
+        ssml += text
+
+        if prosody_attrs:
+            ssml += "</prosody>"
+
+        if self._params.language:
+            ssml += "</lang>"
+
+        ssml += "</speak>"
+
+        return ssml
+
+    async def set_voice(self, voice: str):
+        logger.debug(f"Switching TTS voice to: [{voice}]")
+        self._voice_id = voice
+
+    async def set_engine(self, engine: str):
+        logger.debug(f"Switching TTS engine to: [{engine}]")
+        self._params.engine = engine
+
+    async def set_language(self, language: str):
+        logger.debug(f"Switching TTS language to: [{language}]")
+        self._params.language = language
+
+    async def set_pitch(self, pitch: str):
+        logger.debug(f"Switching TTS pitch to: [{pitch}]")
+        self._params.pitch = pitch
+
+    async def set_rate(self, rate: str):
+        logger.debug(f"Switching TTS rate to: [{rate}]")
+        self._params.rate = rate
+
+    async def set_volume(self, volume: str):
+        logger.debug(f"Switching TTS volume to: [{volume}]")
+        self._params.volume = volume
+
+    async def set_params(self, params: InputParams):
+        logger.debug(f"Switching TTS params to: [{params}]")
+        self._params = params
+
+    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+        logger.debug(f"Generating TTS: [{text}]")
+
+        try:
+            await self.start_ttfb_metrics()
+
+            # Construct the parameters dictionary
+            ssml = self._construct_ssml(text)
+
+            params = {
+                "Text": ssml,
+                "TextType": "ssml",
+                "OutputFormat": "pcm",
+                "VoiceId": self._voice_id,
+                "Engine": self._params.engine,
+                "SampleRate": str(self._sample_rate),
+            }
+
+            # Filter out None values
+            filtered_params = {k: v for k, v in params.items() if v is not None}
+
+            response = self._polly_client.synthesize_speech(**filtered_params)
+
+            await self.start_tts_usage_metrics(text)
+
+            await self.push_frame(TTSStartedFrame())
+
+            if "AudioStream" in response:
+                with response["AudioStream"] as stream:
+                    audio_data = stream.read()
+                    chunk_size = 8192
+                    for i in range(0, len(audio_data), chunk_size):
+                        chunk = audio_data[i : i + chunk_size]
+                        if len(chunk) > 0:
+                            await self.stop_ttfb_metrics()
+                            frame = TTSAudioRawFrame(chunk, self._sample_rate, 1)
+                            yield frame
+
+            await self.push_frame(TTSStoppedFrame())
+
+        except (BotoCoreError,
ClientError) as error: + logger.exception(f"{self} error generating TTS: {error}") + error_message = f"AWS Polly TTS error: {str(error)}" + yield ErrorFrame(error=error_message) + + finally: + await self.push_frame(TTSStoppedFrame()) diff --git a/src/pipecat/services/azure.py b/src/pipecat/services/azure.py index 24e73cd2a..a1349cefe 100644 --- a/src/pipecat/services/azure.py +++ b/src/pipecat/services/azure.py @@ -8,8 +8,9 @@ import asyncio import io -from PIL import Image -from typing import AsyncGenerator +from typing import AsyncGenerator, Optional + +from pydantic import BaseModel from pipecat.frames.frames import ( CancelFrame, @@ -17,32 +18,35 @@ ErrorFrame, Frame, StartFrame, + TranscriptionFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, - TranscriptionFrame, URLImageRawFrame, ) -from pipecat.metrics.metrics import TTSUsageMetricsData -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services import STTService, TTSService, ImageGenService +from pipecat.services.ai_services import ImageGenService, STTService, TTSService from pipecat.services.openai import BaseOpenAILLMService from pipecat.utils.time import time_now_iso8601 +from PIL import Image + from loguru import logger # See .env.example for Azure configuration needed try: - from openai import AsyncAzureOpenAI from azure.cognitiveservices.speech import ( + CancellationReason, + ResultReason, SpeechConfig, SpeechRecognizer, SpeechSynthesizer, - ResultReason, - CancellationReason, ) - from azure.cognitiveservices.speech.audio import AudioStreamFormat, PushAudioInputStream + from azure.cognitiveservices.speech.audio import ( + AudioStreamFormat, + PushAudioInputStream, + ) from azure.cognitiveservices.speech.dialog import AudioConfig + from openai import AsyncAzureOpenAI except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error( @@ -70,6 +74,16 @@ def create_client(self, api_key=None, base_url=None, **kwargs): class AzureTTSService(TTSService): + class InputParams(BaseModel): + emphasis: Optional[str] = None + language: Optional[str] = "en-US" + pitch: Optional[str] = None + rate: Optional[str] = "1.05" + role: Optional[str] = None + style: Optional[str] = None + style_degree: Optional[str] = None + volume: Optional[str] = None + def __init__( self, *, @@ -77,6 +91,7 @@ def __init__( region: str, voice="en-US-SaraNeural", sample_rate: int = 16000, + params: InputParams = InputParams(), **kwargs, ): super().__init__(sample_rate=sample_rate, **kwargs) @@ -86,29 +101,118 @@ def __init__( self._voice = voice self._sample_rate = sample_rate + self._params = params def can_generate_metrics(self) -> bool: return True + def _construct_ssml(self, text: str) -> str: + ssml = ( + f"" + f"" + "" + ) + + if self._params.style: + ssml += f"" + + if self._params.emphasis: + ssml += f"" + + ssml += text + + if self._params.emphasis: + ssml += "" + + ssml += "" + + if self._params.style: + ssml += "" + + ssml += "" + + return ssml + async def set_voice(self, voice: str): logger.debug(f"Switching TTS voice to: [{voice}]") self._voice = voice + async def set_emphasis(self, emphasis: str): + logger.debug(f"Setting TTS emphasis to: [{emphasis}]") + self._params.emphasis = emphasis + + async def set_language(self, language: str): + logger.debug(f"Setting TTS language code to: [{language}]") + self._params.language = language + + async def set_pitch(self, pitch: str): + logger.debug(f"Setting TTS pitch to: [{pitch}]") + self._params.pitch = pitch + + async def set_rate(self, rate: 
str): + logger.debug(f"Setting TTS rate to: [{rate}]") + self._params.rate = rate + + async def set_role(self, role: str): + logger.debug(f"Setting TTS role to: [{role}]") + self._params.role = role + + async def set_style(self, style: str): + logger.debug(f"Setting TTS style to: [{style}]") + self._params.style = style + + async def set_style_degree(self, style_degree: str): + logger.debug(f"Setting TTS style degree to: [{style_degree}]") + self._params.style_degree = style_degree + + async def set_volume(self, volume: str): + logger.debug(f"Setting TTS volume to: [{volume}]") + self._params.volume = volume + + async def set_params(self, **kwargs): + valid_params = { + "voice": self.set_voice, + "emphasis": self.set_emphasis, + "language_code": self.set_language, + "pitch": self.set_pitch, + "rate": self.set_rate, + "role": self.set_role, + "style": self.set_style, + "style_degree": self.set_style_degree, + "volume": self.set_volume, + } + + for param, value in kwargs.items(): + if param in valid_params: + await valid_params[param](value) + else: + logger.warning(f"Ignoring unknown parameter: {param}") + + logger.debug(f"Updated TTS parameters: {', '.join(kwargs.keys())}") + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating TTS: [{text}]") await self.start_ttfb_metrics() - ssml = ( - "" - f"" - "" - "" - "" - f"{text}" - " " - ) + ssml = self._construct_ssml(text) result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, (ssml)) diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index 90deeda15..5f798b1e5 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -9,7 +9,8 @@ import base64 import asyncio -from typing import AsyncGenerator +from typing import AsyncGenerator, Optional, Union, List +from pydantic.main import BaseModel from pipecat.frames.frames import ( CancelFrame, @@ -25,7 +26,7 @@ ) from pipecat.processors.frame_processor import FrameDirection from pipecat.transcriptions.language import Language -from pipecat.services.ai_services import AsyncWordTTSService, TTSService +from pipecat.services.ai_services import WordTTSService, TTSService from loguru import logger @@ -60,7 +61,15 @@ def language_to_cartesia_language(language: Language) -> str | None: return None -class CartesiaTTSService(AsyncWordTTSService): +class CartesiaTTSService(WordTTSService): + class InputParams(BaseModel): + encoding: Optional[str] = "pcm_s16le" + sample_rate: Optional[int] = 16000 + container: Optional[str] = "raw" + language: Optional[str] = "en" + speed: Optional[Union[str, float]] = "" + emotion: Optional[List[str]] = [] + def __init__( self, *, @@ -69,9 +78,7 @@ def __init__( cartesia_version: str = "2024-06-10", url: str = "wss://api.cartesia.ai/tts/websocket", model_id: str = "sonic-english", - encoding: str = "pcm_s16le", - sample_rate: int = 16000, - language: str = "en", + params: InputParams = InputParams(), **kwargs, ): # Aggregating sentences still gives cleaner-sounding results and fewer @@ -85,20 +92,26 @@ def __init__( # can use those to generate text frames ourselves aligned with the # playout timing of the audio! 
super().__init__( - aggregate_sentences=True, push_text_frames=False, sample_rate=sample_rate, **kwargs + aggregate_sentences=True, + push_text_frames=False, + sample_rate=params.sample_rate, + **kwargs, ) self._api_key = api_key self._cartesia_version = cartesia_version self._url = url self._voice_id = voice_id + self._model_id = model_id self.set_model_name(model_id) self._output_format = { - "container": "raw", - "encoding": encoding, - "sample_rate": sample_rate, + "container": params.container, + "encoding": params.encoding, + "sample_rate": params.sample_rate, } - self._language = language + self._language = params.language + self._speed = params.speed + self._emotion = params.emotion self._websocket = None self._context_id = None @@ -108,6 +121,7 @@ def can_generate_metrics(self) -> bool: return True async def set_model(self, model: str): + self._model_id = model await super().set_model(model) logger.debug(f"Switching TTS model to: [{model}]") @@ -115,10 +129,42 @@ async def set_voice(self, voice: str): logger.debug(f"Switching TTS voice to: [{voice}]") self._voice_id = voice + async def set_speed(self, speed: str): + logger.debug(f"Switching TTS speed to: [{speed}]") + self._speed = speed + + async def set_emotion(self, emotion: list[str]): + logger.debug(f"Switching TTS emotion to: [{emotion}]") + self._emotion = emotion + async def set_language(self, language: Language): logger.debug(f"Switching TTS language to: [{language}]") self._language = language_to_cartesia_language(language) + def _build_msg( + self, text: str = "", continue_transcript: bool = True, add_timestamps: bool = True + ): + voice_config = {"mode": "id", "id": self._voice_id} + + if self._speed or self._emotion: + voice_config["__experimental_controls"] = {} + if self._speed: + voice_config["__experimental_controls"]["speed"] = self._speed + if self._emotion: + voice_config["__experimental_controls"]["emotion"] = self._emotion + + msg = { + "transcript": text, + "continue": continue_transcript, + "context_id": self._context_id, + "model_id": self._model_name, + "voice": voice_config, + "output_format": self._output_format, + "language": self._language, + "add_timestamps": add_timestamps, + } + return json.dumps(msg) + async def start(self, frame: StartFrame): await super().start(frame) await self._connect() @@ -173,17 +219,8 @@ async def flush_audio(self): if not self._context_id or not self._websocket: return logger.trace("Flushing audio") - msg = { - "transcript": "", - "continue": False, - "context_id": self._context_id, - "model_id": self.model_name, - "voice": {"mode": "id", "id": self._voice_id}, - "output_format": self._output_format, - "language": self._language, - "add_timestamps": True, - } - await self._websocket.send(json.dumps(msg)) + msg = self._build_msg(text="", continue_transcript=False) + await self._websocket.send(msg) async def _receive_task_handler(self): try: @@ -236,18 +273,10 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: await self.start_ttfb_metrics() self._context_id = str(uuid.uuid4()) - msg = { - "transcript": text + " ", - "continue": True, - "context_id": self._context_id, - "model_id": self.model_name, - "voice": {"mode": "id", "id": self._voice_id}, - "output_format": self._output_format, - "language": self._language, - "add_timestamps": True, - } + msg = self._build_msg(text=text) + try: - await self._get_websocket().send(json.dumps(msg)) + await self._get_websocket().send(msg) await self.start_tts_usage_metrics(text) except Exception as e: 
logger.error(f"{self} error sending message: {e}") @@ -261,6 +290,14 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: class CartesiaHttpTTSService(TTSService): + class InputParams(BaseModel): + encoding: Optional[str] = "pcm_s16le" + sample_rate: Optional[int] = 16000 + container: Optional[str] = "raw" + language: Optional[str] = "en" + speed: Optional[Union[str, float]] = "" + emotion: Optional[List[str]] = [] + def __init__( self, *, @@ -268,9 +305,7 @@ def __init__( voice_id: str, model_id: str = "sonic-english", base_url: str = "https://api.cartesia.ai", - encoding: str = "pcm_s16le", - sample_rate: int = 16000, - language: str = "en", + params: InputParams = InputParams(), **kwargs, ): super().__init__(**kwargs) @@ -278,12 +313,15 @@ def __init__( self._api_key = api_key self._voice_id = voice_id self._model_id = model_id + self.set_model_name(model_id) self._output_format = { - "container": "raw", - "encoding": encoding, - "sample_rate": sample_rate, + "container": params.container, + "encoding": params.encoding, + "sample_rate": params.sample_rate, } - self._language = language + self._language = params.language + self._speed = params.speed + self._emotion = params.emotion self._client = AsyncCartesia(api_key=api_key, base_url=base_url) @@ -293,11 +331,20 @@ def can_generate_metrics(self) -> bool: async def set_model(self, model: str): logger.debug(f"Switching TTS model to: [{model}]") self._model_id = model + await super().set_model(model) async def set_voice(self, voice: str): logger.debug(f"Switching TTS voice to: [{voice}]") self._voice_id = voice + async def set_speed(self, speed: str): + logger.debug(f"Switching TTS speed to: [{speed}]") + self._speed = speed + + async def set_emotion(self, emotion: list[str]): + logger.debug(f"Switching TTS emotion to: [{emotion}]") + self._emotion = emotion + async def set_language(self, language: Language): logger.debug(f"Switching TTS language to: [{language}]") self._language = language_to_cartesia_language(language) @@ -317,6 +364,14 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: await self.start_ttfb_metrics() try: + voice_controls = None + if self._speed or self._emotion: + voice_controls = {} + if self._speed: + voice_controls["speed"] = self._speed + if self._emotion: + voice_controls["emotion"] = self._emotion + output = await self._client.tts.sse( model_id=self._model_id, transcript=text, @@ -324,6 +379,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: output_format=self._output_format, language=self._language, stream=False, + _experimental_voice_controls=voice_controls, ) await self.stop_ttfb_metrics() diff --git a/src/pipecat/services/deepgram.py b/src/pipecat/services/deepgram.py index fab12e080..d109cce3c 100644 --- a/src/pipecat/services/deepgram.py +++ b/src/pipecat/services/deepgram.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import aiohttp +import asyncio from typing import AsyncGenerator @@ -15,10 +15,10 @@ Frame, InterimTranscriptionFrame, StartFrame, + TranscriptionFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, - TranscriptionFrame, ) from pipecat.services.ai_services import STTService, TTSService from pipecat.transcriptions.language import Language @@ -26,16 +26,16 @@ from loguru import logger - # See .env.example for Deepgram configuration needed try: from deepgram import ( AsyncListenWebSocketClient, DeepgramClient, DeepgramClientOptions, - LiveTranscriptionEvents, LiveOptions, LiveResultResponse, + 
LiveTranscriptionEvents, + SpeakOptions, ) except ModuleNotFoundError as e: logger.error(f"Exception: {e}") @@ -50,9 +50,7 @@ def __init__( self, *, api_key: str, - aiohttp_session: aiohttp.ClientSession, voice: str = "aura-helios-en", - base_url: str = "https://api.deepgram.com/v1/speak", sample_rate: int = 16000, encoding: str = "linear16", **kwargs, @@ -60,11 +58,9 @@ def __init__( super().__init__(**kwargs) self._voice = voice - self._api_key = api_key - self._base_url = base_url self._sample_rate = sample_rate self._encoding = encoding - self._aiohttp_session = aiohttp_session + self._deepgram_client = DeepgramClient(api_key=api_key) def can_generate_metrics(self) -> bool: return True @@ -76,44 +72,45 @@ async def set_voice(self, voice: str): async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating TTS: [{text}]") - base_url = self._base_url - request_url = f"{base_url}?model={self._voice}&encoding={ - self._encoding}&container=none&sample_rate={self._sample_rate}" - headers = {"authorization": f"token {self._api_key}"} - body = {"text": text} + options = SpeakOptions( + model=self._voice, + encoding=self._encoding, + sample_rate=self._sample_rate, + container="none", + ) try: await self.start_ttfb_metrics() - async with self._aiohttp_session.post(request_url, headers=headers, json=body) as r: - if r.status != 200: - response_text = await r.text() - # If we get a a "Bad Request: Input is unutterable", just print out a debug log. - # All other unsuccesful requests should emit an error frame. If not specifically - # handled by the running PipelineTask, the ErrorFrame will cancel the task. - if "unutterable" in response_text: - logger.debug(f"Unutterable text: [{text}]") - return - - logger.error( - f"{self} error getting audio (status: {r.status}, error: {response_text})" - ) - yield ErrorFrame( - f"Error getting audio (status: {r.status}, error: {response_text})" - ) - return - - await self.start_tts_usage_metrics(text) - - await self.push_frame(TTSStartedFrame()) - async for data in r.content: - await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame( - audio=data, sample_rate=self._sample_rate, num_channels=1 - ) - yield frame - await self.push_frame(TTSStoppedFrame()) + + response = await asyncio.to_thread( + self._deepgram_client.speak.v("1").stream, {"text": text}, options + ) + + await self.start_tts_usage_metrics(text) + await self.push_frame(TTSStartedFrame()) + + # The response.stream_memory is already a BytesIO object + audio_buffer = response.stream_memory + + if audio_buffer is None: + raise ValueError("No audio data received from Deepgram") + + # Read and yield the audio data in chunks + audio_buffer.seek(0) # Ensure we're at the start of the buffer + chunk_size = 8192 # Use a fixed buffer size + while True: + await self.stop_ttfb_metrics() + chunk = audio_buffer.read(chunk_size) + if not chunk: + break + frame = TTSAudioRawFrame(audio=chunk, sample_rate=self._sample_rate, num_channels=1) + yield frame + + await self.push_frame(TTSStoppedFrame()) + except Exception as e: logger.exception(f"{self} exception: {e}") + yield ErrorFrame(f"Error getting audio: {str(e)}") class DeepgramSTTService(STTService): diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index 00a32cbfd..611f2a024 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -23,7 +23,7 @@ TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services 
import AsyncWordTTSService +from pipecat.services.ai_services import WordTTSService # See .env.example for ElevenLabs configuration needed try: @@ -70,8 +70,9 @@ def calculate_word_times( return word_times -class ElevenLabsTTSService(AsyncWordTTSService): +class ElevenLabsTTSService(WordTTSService): class InputParams(BaseModel): + language: Optional[str] = None output_format: Literal["pcm_16000", "pcm_22050", "pcm_24000", "pcm_44100"] = "pcm_16000" optimize_streaming_latency: Optional[str] = None stability: Optional[float] = None @@ -228,6 +229,15 @@ async def _connect(self): if self._params.optimize_streaming_latency: url += f"&optimize_streaming_latency={self._params.optimize_streaming_latency}" + # language can only be used with the 'eleven_turbo_v2_5' model + if self._params.language: + if model == "eleven_turbo_v2_5": + url += f"&language_code={self._params.language}" + else: + logger.debug( + f"Language code [{self._params.language}] not applied. Language codes can only be used with the 'eleven_turbo_v2_5' model." + ) + self._websocket = await websockets.connect(url) self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler()) diff --git a/src/pipecat/services/fal.py b/src/pipecat/services/fal.py index bb7b47dfc..aecdeb709 100644 --- a/src/pipecat/services/fal.py +++ b/src/pipecat/services/fal.py @@ -8,13 +8,14 @@ import io import os -from PIL import Image from pydantic import BaseModel from typing import AsyncGenerator, Optional, Union, Dict from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame from pipecat.services.ai_services import ImageGenService +from PIL import Image + from loguru import logger try: diff --git a/src/pipecat/services/gladia.py b/src/pipecat/services/gladia.py index 12183adde..a590d73cf 100644 --- a/src/pipecat/services/gladia.py +++ b/src/pipecat/services/gladia.py @@ -51,7 +51,7 @@ def __init__( params: InputParams = InputParams(), **kwargs, ): - super().__init__(sync=False, **kwargs) + super().__init__(**kwargs) self._api_key = api_key self._url = url diff --git a/src/pipecat/services/google.py b/src/pipecat/services/google.py index 4de6b77fa..519f47028 100644 --- a/src/pipecat/services/google.py +++ b/src/pipecat/services/google.py @@ -5,30 +5,37 @@ # import asyncio +import json +from typing import AsyncGenerator, List, Literal, Optional -from typing import List +from loguru import logger +from pydantic import BaseModel from pipecat.frames.frames import ( + ErrorFrame, Frame, - LLMModelUpdateFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMMessagesFrame, + LLMUpdateSettingsFrame, TextFrame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, VisionImageRawFrame, - LLMMessagesFrame, - LLMFullResponseStartFrame, - LLMFullResponseEndFrame, ) -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services import LLMService from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, ) - -from loguru import logger +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import LLMService, TTSService try: - import google.generativeai as gai import google.ai.generativelanguage as glm + import google.generativeai as gai + from google.cloud import texttospeech_v1 + from google.oauth2 import service_account except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error( @@ -129,11 
+136,197 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): context = OpenAILLMContext.from_messages(frame.messages) elif isinstance(frame, VisionImageRawFrame): context = OpenAILLMContext.from_image_frame(frame) - elif isinstance(frame, LLMModelUpdateFrame): - logger.debug(f"Switching LLM model to: [{frame.model}]") - self._create_client(frame.model) + elif isinstance(frame, LLMUpdateSettingsFrame): + if frame.model is not None: + logger.debug(f"Switching LLM model to: [{frame.model}]") + self.set_model_name(frame.model) else: await self.push_frame(frame, direction) if context: await self._process_context(context) + + +class GoogleTTSService(TTSService): + class InputParams(BaseModel): + pitch: Optional[str] = None + rate: Optional[str] = None + volume: Optional[str] = None + emphasis: Optional[Literal["strong", "moderate", "reduced", "none"]] = None + language: Optional[str] = "en-US" + gender: Optional[Literal["male", "female", "neutral"]] = None + google_style: Optional[Literal["apologetic", "calm", "empathetic", "firm", "lively"]] = None + + def __init__( + self, + *, + credentials: Optional[str] = None, + credentials_path: Optional[str] = None, + voice_id: str = "en-US-Neural2-A", + sample_rate: int = 24000, + params: InputParams = InputParams(), + **kwargs, + ): + super().__init__(sample_rate=sample_rate, **kwargs) + + self._voice_id: str = voice_id + self._params = params + self._client: texttospeech_v1.TextToSpeechAsyncClient = self._create_client( + credentials, credentials_path + ) + + def _create_client( + self, credentials: Optional[str], credentials_path: Optional[str] + ) -> texttospeech_v1.TextToSpeechAsyncClient: + creds: Optional[service_account.Credentials] = None + + # Create a Google Cloud service account for the Cloud Text-to-Speech API + # Using either the provided credentials JSON string or the path to a service account JSON + # file, create a Google Cloud service account and use it to authenticate with the API. 
+ if credentials: + # Use provided credentials JSON string + json_account_info = json.loads(credentials) + creds = service_account.Credentials.from_service_account_info(json_account_info) + elif credentials_path: + # Use service account JSON file if provided + creds = service_account.Credentials.from_service_account_file(credentials_path) + else: + raise ValueError("Either 'credentials' or 'credentials_path' must be provided.") + + return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds) + + def can_generate_metrics(self) -> bool: + return True + + def _construct_ssml(self, text: str) -> str: + ssml = "" + + # Voice tag + voice_attrs = [f"name='{self._voice_id}'"] + if self._params.language: + voice_attrs.append(f"language='{self._params.language}'") + if self._params.gender: + voice_attrs.append(f"gender='{self._params.gender}'") + ssml += f"" + + # Prosody tag + prosody_attrs = [] + if self._params.pitch: + prosody_attrs.append(f"pitch='{self._params.pitch}'") + if self._params.rate: + prosody_attrs.append(f"rate='{self._params.rate}'") + if self._params.volume: + prosody_attrs.append(f"volume='{self._params.volume}'") + + if prosody_attrs: + ssml += f"" + + # Emphasis tag + if self._params.emphasis: + ssml += f"" + + # Google style tag + if self._params.google_style: + ssml += f"" + + ssml += text + + # Close tags + if self._params.google_style: + ssml += "" + if self._params.emphasis: + ssml += "" + if prosody_attrs: + ssml += "" + ssml += "" + + return ssml + + async def set_voice(self, voice: str) -> None: + logger.debug(f"Switching TTS voice to: [{voice}]") + self._voice_id = voice + + async def set_language(self, language: str) -> None: + logger.debug(f"Switching TTS language to: [{language}]") + self._params.language = language + + async def set_pitch(self, pitch: str) -> None: + logger.debug(f"Switching TTS pitch to: [{pitch}]") + self._params.pitch = pitch + + async def set_rate(self, rate: str) -> None: + logger.debug(f"Switching TTS rate to: [{rate}]") + self._params.rate = rate + + async def set_volume(self, volume: str) -> None: + logger.debug(f"Switching TTS volume to: [{volume}]") + self._params.volume = volume + + async def set_emphasis( + self, emphasis: Literal["strong", "moderate", "reduced", "none"] + ) -> None: + logger.debug(f"Switching TTS emphasis to: [{emphasis}]") + self._params.emphasis = emphasis + + async def set_gender(self, gender: Literal["male", "female", "neutral"]) -> None: + logger.debug(f"Switch TTS gender to [{gender}]") + self._params.gender = gender + + async def google_style( + self, google_style: Literal["apologetic", "calm", "empathetic", "firm", "lively"] + ) -> None: + logger.debug(f"Switching TTS google style to: [{google_style}]") + self._params.google_style = google_style + + async def set_params(self, params: InputParams) -> None: + logger.debug(f"Switching TTS params to: [{params}]") + self._params = params + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"Generating TTS: [{text}]") + + try: + await self.start_ttfb_metrics() + + ssml = self._construct_ssml(text) + synthesis_input = texttospeech_v1.SynthesisInput(ssml=ssml) + voice = texttospeech_v1.VoiceSelectionParams( + language_code=self._params.language, name=self._voice_id + ) + audio_config = texttospeech_v1.AudioConfig( + audio_encoding=texttospeech_v1.AudioEncoding.LINEAR16, + sample_rate_hertz=self.sample_rate, + ) + + request = texttospeech_v1.SynthesizeSpeechRequest( + input=synthesis_input, voice=voice, audio_config=audio_config + 
) + + response = await self._client.synthesize_speech(request=request) + + await self.start_tts_usage_metrics(text) + + await self.push_frame(TTSStartedFrame()) + + # Skip the first 44 bytes to remove the WAV header + audio_content = response.audio_content[44:] + + # Read and yield audio data in chunks + chunk_size = 8192 + for i in range(0, len(audio_content), chunk_size): + chunk = audio_content[i : i + chunk_size] + if not chunk: + break + await self.stop_ttfb_metrics() + frame = TTSAudioRawFrame(chunk, self.sample_rate, 1) + yield frame + await asyncio.sleep(0) # Allow other tasks to run + + await self.push_frame(TTSStoppedFrame()) + + except Exception as e: + logger.exception(f"{self} error generating TTS: {e}") + error_message = f"TTS generation error: {str(e)}" + yield ErrorFrame(error=error_message) + finally: + await self.push_frame(TTSStoppedFrame()) diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt.py index 1ac24d731..8f18002c5 100644 --- a/src/pipecat/services/lmnt.py +++ b/src/pipecat/services/lmnt.py @@ -20,7 +20,7 @@ TTSStartedFrame, TTSStoppedFrame, ) -from pipecat.services.ai_services import AsyncTTSService +from pipecat.services.ai_services import TTSService from loguru import logger @@ -35,7 +35,7 @@ raise Exception(f"Missing module: {e}") -class LmntTTSService(AsyncTTSService): +class LmntTTSService(TTSService): def __init__( self, *, @@ -47,7 +47,7 @@ def __init__( ): # Let TTSService produce TTSStoppedFrames after a short delay of # no activity. - super().__init__(sync=False, push_stop_frames=True, sample_rate=sample_rate, **kwargs) + super().__init__(push_stop_frames=True, sample_rate=sample_rate, **kwargs) self._api_key = api_key self._voice_id = voice_id diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index e54898525..49fd04371 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -4,42 +4,40 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import aiohttp import base64 import io import json -import httpx from dataclasses import dataclass - from typing import Any, AsyncGenerator, Dict, List, Literal, Optional -from pydantic import BaseModel, Field +import aiohttp +import httpx from loguru import logger from PIL import Image +from pydantic import BaseModel, Field from pipecat.frames.frames import ( ErrorFrame, Frame, + FunctionCallInProgressFrame, + FunctionCallResultFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesFrame, - LLMModelUpdateFrame, + LLMUpdateSettingsFrame, + StartInterruptionFrame, + TextFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, - TextFrame, URLImageRawFrame, VisionImageRawFrame, - FunctionCallResultFrame, - FunctionCallInProgressFrame, - StartInterruptionFrame, ) from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.llm_response import ( - LLMUserContextAggregator, LLMAssistantContextAggregator, + LLMUserContextAggregator, ) - from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContext, OpenAILLMContextFrame, @@ -48,7 +46,13 @@ from pipecat.services.ai_services import ImageGenService, LLMService, TTSService try: - from openai import AsyncOpenAI, AsyncStream, DefaultAsyncHttpxClient, BadRequestError, NOT_GIVEN + from openai import ( + NOT_GIVEN, + AsyncOpenAI, + AsyncStream, + BadRequestError, + DefaultAsyncHttpxClient, + ) from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam except ModuleNotFoundError as e: logger.error(f"Exception: {e}") @@ -201,6 +205,10 
@@ -201,6 +205,10 @@ async def _stream_chat_completions(
         return chunks
 
     async def _process_context(self, context: OpenAILLMContext):
+        functions_list = []
+        arguments_list = []
+        tool_id_list = []
+        func_idx = 0
         function_name = ""
         arguments = ""
         tool_call_id = ""
@@ -238,6 +246,14 @@ async def _process_context(self, context: OpenAILLMContext):
                 # yield a frame containing the function name and the arguments.
 
                 tool_call = chunk.choices[0].delta.tool_calls[0]
+                if tool_call.index != func_idx:
+                    functions_list.append(function_name)
+                    arguments_list.append(arguments)
+                    tool_id_list.append(tool_call_id)
+                    function_name = ""
+                    arguments = ""
+                    tool_call_id = ""
+                    func_idx += 1
                 if tool_call.function and tool_call.function.name:
                     function_name += tool_call.function.name
                     tool_call_id = tool_call.id
@@ -253,21 +269,46 @@ async def _process_context(self, context: OpenAILLMContext):
         # the context, and re-prompt to get a chat answer. If we don't have a registered
         # handler, raise an exception.
         if function_name and arguments:
-            if self.has_function(function_name):
-                await self._handle_function_call(context, tool_call_id, function_name, arguments)
-            else:
-                raise OpenAIUnhandledFunctionException(
-                    f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
-                )
+            # Append the last function name and arguments, since they haven't been added to the lists yet.
+            functions_list.append(function_name)
+            arguments_list.append(arguments)
+            tool_id_list.append(tool_call_id)
+
+            total_items = len(functions_list)
+            for index, (function_name, arguments, tool_id) in enumerate(
+                zip(functions_list, arguments_list, tool_id_list), start=1
+            ):
+                if self.has_function(function_name):
+                    run_llm = index == total_items
+                    arguments = json.loads(arguments)
+                    await self.call_function(
+                        context=context,
+                        function_name=function_name,
+                        arguments=arguments,
+                        tool_call_id=tool_id,
+                        run_llm=run_llm,
+                    )
+                else:
+                    raise OpenAIUnhandledFunctionException(
+                        f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
+                    )
 
-    async def _handle_function_call(self, context, tool_call_id, function_name, arguments):
-        arguments = json.loads(arguments)
-        await self.call_function(
-            context=context,
-            tool_call_id=tool_call_id,
-            function_name=function_name,
-            arguments=arguments,
-        )
+    async def _update_settings(self, frame: LLMUpdateSettingsFrame):
+        if frame.model is not None:
+            logger.debug(f"Switching LLM model to: [{frame.model}]")
+            self.set_model_name(frame.model)
+        if frame.frequency_penalty is not None:
+            await self.set_frequency_penalty(frame.frequency_penalty)
+        if frame.presence_penalty is not None:
+            await self.set_presence_penalty(frame.presence_penalty)
+        if frame.seed is not None:
+            await self.set_seed(frame.seed)
+        if frame.temperature is not None:
+            await self.set_temperature(frame.temperature)
+        if frame.top_p is not None:
+            await self.set_top_p(frame.top_p)
+        if frame.extra:
+            await self.set_extra(frame.extra)
 
     async def process_frame(self, frame: Frame, direction: FrameDirection):
         await super().process_frame(frame, direction)
 
@@ -279,9 +320,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
             context = OpenAILLMContext.from_messages(frame.messages)
         elif isinstance(frame, VisionImageRawFrame):
             context = OpenAILLMContext.from_image_frame(frame)
-        elif isinstance(frame, LLMModelUpdateFrame):
-            logger.debug(f"Switching LLM model to: [{frame.model}]")
-            self.set_model_name(frame.model)
+        elif isinstance(frame, LLMUpdateSettingsFrame):
+            await self._update_settings(frame)
         else:
             await self.push_frame(frame, direction)
 
@@ -316,9 +356,13 @@ def __init__(
         super().__init__(model=model, params=params, **kwargs)
 
     @staticmethod
-    def create_context_aggregator(context: OpenAILLMContext) -> OpenAIContextAggregatorPair:
+    def create_context_aggregator(
+        context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
+    ) -> OpenAIContextAggregatorPair:
         user = OpenAIUserContextAggregator(context)
-        assistant = OpenAIAssistantContextAggregator(user)
+        assistant = OpenAIAssistantContextAggregator(
+            user, expect_stripped_words=assistant_expect_stripped_words
+        )
         return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
 
@@ -438,34 +482,30 @@ def __init__(self, context: OpenAILLMContext):
 
 
 class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
-    def __init__(self, user_context_aggregator: OpenAIUserContextAggregator):
-        super().__init__(context=user_context_aggregator._context)
+    def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
+        super().__init__(context=user_context_aggregator._context, **kwargs)
         self._user_context_aggregator = user_context_aggregator
-        self._function_call_in_progress = None
+        self._function_calls_in_progress = {}
         self._function_call_result = None
 
     async def process_frame(self, frame, direction):
         await super().process_frame(frame, direction)
         # See note above about not calling push_frame() here.
         if isinstance(frame, StartInterruptionFrame):
-            self._function_call_in_progress = None
+            self._function_calls_in_progress.clear()
             self._function_call_finished = None
         elif isinstance(frame, FunctionCallInProgressFrame):
-            self._function_call_in_progress = frame
+            self._function_calls_in_progress[frame.tool_call_id] = frame
         elif isinstance(frame, FunctionCallResultFrame):
-            if (
-                self._function_call_in_progress
-                and self._function_call_in_progress.tool_call_id == frame.tool_call_id
-            ):
-                self._function_call_in_progress = None
+            if frame.tool_call_id in self._function_calls_in_progress:
+                del self._function_calls_in_progress[frame.tool_call_id]
                 self._function_call_result = frame
                 # TODO-CB: Kwin wants us to refactor this out of here but I REFUSE
                 await self._push_aggregation()
             else:
                 logger.warning(
-                    f"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
+                    "FunctionCallResultFrame tool_call_id does not match any function call in progress"
                 )
-                self._function_call_in_progress = None
                 self._function_call_result = None
 
     async def _push_aggregation(self):
@@ -475,7 +515,7 @@ async def _push_aggregation(self):
         run_llm = False
 
         aggregation = self._aggregation
-        self._aggregation = ""
+        self._reset()
 
         try:
             if self._function_call_result:
@@ -504,12 +544,15 @@ async def _push_aggregation(self):
                         "tool_call_id": frame.tool_call_id,
                     }
                 )
-                run_llm = True
+                run_llm = frame.run_llm
             else:
                 self._context.add_message({"role": "assistant", "content": aggregation})
 
             if run_llm:
                 await self._user_context_aggregator.push_context_frame()
 
+            frame = OpenAILLMContextFrame(self._context)
+            await self.push_frame(frame)
+
         except Exception as e:
             logger.error(f"Error processing frame: {e}")
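The `_process_context()` changes above let `OpenAILLMService` accumulate several tool calls from a single streamed completion and dispatch them one by one, re-prompting the model only after the last result comes back (`run_llm` is true only for the final item). Below is a minimal sketch of registering two callbacks that can now be invoked from the same response; the function names and payloads are illustrative, and the `register_function()` API and callback signature are assumed from the framework's existing function-calling examples.

```
from pipecat.services.openai import OpenAILLMService

llm = OpenAILLMService(api_key="...", model="gpt-4o")


# Each callback reports its result through `result_callback`. With the change
# above, `arguments` arrives already parsed into a dict and a new completion is
# only requested once the last pending tool call of the batch has returned.
async def fetch_weather(function_name, tool_call_id, arguments, llm, context, result_callback):
    await result_callback({"conditions": "sunny", "temperature": "22C"})


async def fetch_time(function_name, tool_call_id, arguments, llm, context, result_callback):
    await result_callback({"time": "10:15"})


llm.register_function("get_weather", fetch_weather)
llm.register_function("get_time", fetch_time)
```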
diff --git a/src/pipecat/services/together.py b/src/pipecat/services/together.py
index b1365bc69..3f4d97964 100644
--- a/src/pipecat/services/together.py
+++ b/src/pipecat/services/together.py
@@ -7,37 +7,36 @@
 import json
 import re
 import uuid
-from pydantic import BaseModel, Field
-
-from typing import Any, Dict, List, Optional
-from dataclasses import dataclass
 from asyncio import CancelledError
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+from loguru import logger
+from pydantic import BaseModel, Field
 
 from pipecat.frames.frames import (
     Frame,
-    LLMModelUpdateFrame,
-    TextFrame,
-    UserImageRequestFrame,
-    LLMMessagesFrame,
-    LLMFullResponseStartFrame,
-    LLMFullResponseEndFrame,
-    FunctionCallResultFrame,
     FunctionCallInProgressFrame,
+    FunctionCallResultFrame,
+    LLMFullResponseEndFrame,
+    LLMFullResponseStartFrame,
+    LLMMessagesFrame,
+    LLMUpdateSettingsFrame,
     StartInterruptionFrame,
+    TextFrame,
+    UserImageRequestFrame,
 )
 from pipecat.metrics.metrics import LLMTokenUsage
-from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.ai_services import LLMService
+from pipecat.processors.aggregators.llm_response import (
+    LLMAssistantContextAggregator,
+    LLMUserContextAggregator,
+)
 from pipecat.processors.aggregators.openai_llm_context import (
     OpenAILLMContext,
     OpenAILLMContextFrame,
 )
-from pipecat.processors.aggregators.llm_response import (
-    LLMUserContextAggregator,
-    LLMAssistantContextAggregator,
-)
-
-from loguru import logger
+from pipecat.processors.frame_processor import FrameDirection
+from pipecat.services.ai_services import LLMService
 
 try:
     from together import AsyncTogether
@@ -96,9 +95,13 @@ def can_generate_metrics(self) -> bool:
         return True
 
     @staticmethod
-    def create_context_aggregator(context: OpenAILLMContext) -> TogetherContextAggregatorPair:
+    def create_context_aggregator(
+        context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
+    ) -> TogetherContextAggregatorPair:
         user = TogetherUserContextAggregator(context)
-        assistant = TogetherAssistantContextAggregator(user)
+        assistant = TogetherAssistantContextAggregator(
+            user, expect_stripped_words=assistant_expect_stripped_words
+        )
         return TogetherContextAggregatorPair(_user=user, _assistant=assistant)
 
     async def set_frequency_penalty(self, frequency_penalty: float):
@@ -129,6 +132,25 @@ async def set_extra(self, extra: Dict[str, Any]):
         logger.debug(f"Switching LLM extra to: [{extra}]")
         self._extra = extra
 
+    async def _update_settings(self, frame: LLMUpdateSettingsFrame):
+        if frame.model is not None:
+            logger.debug(f"Switching LLM model to: [{frame.model}]")
+            self.set_model_name(frame.model)
+        if frame.frequency_penalty is not None:
+            await self.set_frequency_penalty(frame.frequency_penalty)
+        if frame.max_tokens is not None:
+            await self.set_max_tokens(frame.max_tokens)
+        if frame.presence_penalty is not None:
+            await self.set_presence_penalty(frame.presence_penalty)
+        if frame.temperature is not None:
+            await self.set_temperature(frame.temperature)
+        if frame.top_k is not None:
+            await self.set_top_k(frame.top_k)
+        if frame.top_p is not None:
+            await self.set_top_p(frame.top_p)
+        if frame.extra:
+            await self.set_extra(frame.extra)
+
     async def _process_context(self, context: OpenAILLMContext):
         try:
             await self.push_frame(LLMFullResponseStartFrame())
@@ -188,7 +210,7 @@ async def _process_context(self, context: OpenAILLMContext):
                 if chunk.choices[0].finish_reason == "eos" and accumulating_function_call:
                     await self._extract_function_call(context, function_call_accumulator)
 
-        except CancelledError as e:
+        except CancelledError:
             # todo: implement token counting estimates for use when the user interrupts a long generation
             # we do this in the anthropic.py service
             raise
@@ -206,9 +228,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
             context = frame.context
         elif isinstance(frame, LLMMessagesFrame):
            context = TogetherLLMContext.from_messages(frame.messages)
-        elif isinstance(frame, LLMModelUpdateFrame):
-            logger.debug(f"Switching LLM model to: [{frame.model}]")
-            self.set_model_name(frame.model)
+        elif isinstance(frame, LLMUpdateSettingsFrame):
+            await self._update_settings(frame)
         else:
             await self.push_frame(frame, direction)
 
@@ -314,8 +335,8 @@ async def process_frame(self, frame, direction):
 
 
 class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
-    def __init__(self, user_context_aggregator: TogetherUserContextAggregator):
-        super().__init__(context=user_context_aggregator._context)
+    def __init__(self, user_context_aggregator: TogetherUserContextAggregator, **kwargs):
+        super().__init__(context=user_context_aggregator._context, **kwargs)
         self._user_context_aggregator = user_context_aggregator
         self._function_call_in_progress = None
         self._function_call_result = None
@@ -338,7 +359,7 @@ async def process_frame(self, frame, direction):
                 await self._push_aggregation()
             else:
                 logger.warning(
-                    f"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
+                    "FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
                 )
                 self._function_call_in_progress = None
                 self._function_call_result = None
@@ -353,7 +374,7 @@ async def _push_aggregation(self):
         run_llm = False
 
         aggregation = self._aggregation
-        self._aggregation = ""
+        self._reset()
 
         try:
             if self._function_call_result:
@@ -373,5 +394,8 @@ async def _push_aggregation(self):
             if run_llm:
                 await self._user_context_aggregator.push_messages_frame()
 
+            frame = OpenAILLMContextFrame(self._context)
+            await self.push_frame(frame)
+
         except Exception as e:
             logger.error(f"Error processing frame: {e}")
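Both LLM services above now react to a single `LLMUpdateSettingsFrame` instead of the old `LLMModelUpdateFrame`. The sketch below shows one way to adjust generation settings on a running pipeline; it assumes the frame exposes the keyword fields read by `_update_settings()` (with unset fields defaulting to `None`) and uses `PipelineTask.queue_frame()` to deliver it.

```
from pipecat.frames.frames import LLMUpdateSettingsFrame
from pipecat.pipeline.task import PipelineTask


async def make_answers_more_conservative(task: PipelineTask):
    # Only the fields that are set get applied; everything else keeps its
    # current value in the service.
    await task.queue_frame(LLMUpdateSettingsFrame(temperature=0.2, top_p=0.9))
```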
diff --git a/src/pipecat/services/xtts.py b/src/pipecat/services/xtts.py
index 5161efcf6..2c47d59e8 100644
--- a/src/pipecat/services/xtts.py
+++ b/src/pipecat/services/xtts.py
@@ -18,10 +18,10 @@
 )
 from pipecat.services.ai_services import TTSService
 
-from loguru import logger
-
 import numpy as np
 
+from loguru import logger
+
 try:
     import resampy
 except ModuleNotFoundError as e:
diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py
index 73ad3f5e3..710f8108a 100644
--- a/src/pipecat/transports/base_input.py
+++ b/src/pipecat/transports/base_input.py
@@ -31,12 +31,16 @@ class BaseInputTransport(FrameProcessor):
     def __init__(self, params: TransportParams, **kwargs):
-        super().__init__(sync=False, **kwargs)
+        super().__init__(**kwargs)
 
         self._params = params
 
         self._executor = ThreadPoolExecutor(max_workers=5)
 
+        # Task to process incoming audio (VAD) and push audio frames downstream
+        # if passthrough is enabled.
+        self._audio_task = None
+
     async def start(self, frame: StartFrame):
         # Create audio input queue and task if needed.
         if self._params.audio_in_enabled or self._params.vad_enabled:
@@ -45,16 +49,17 @@ async def start(self, frame: StartFrame):
 
     async def stop(self, frame: EndFrame):
         # Cancel and wait for the audio input task to finish.
-        if self._params.audio_in_enabled or self._params.vad_enabled:
+        if self._audio_task and (self._params.audio_in_enabled or self._params.vad_enabled):
             self._audio_task.cancel()
             await self._audio_task
+            self._audio_task = None
 
     async def cancel(self, frame: CancelFrame):
-        # Cancel all the tasks and wait for them to finish.
-
-        if self._params.audio_in_enabled or self._params.vad_enabled:
+        # Cancel and wait for the audio input task to finish.
+        if self._audio_task and (self._params.audio_in_enabled or self._params.vad_enabled):
             self._audio_task.cancel()
             await self._audio_task
+            self._audio_task = None
 
     def vad_analyzer(self) -> VADAnalyzer | None:
         return self._params.vad_analyzer
diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py
index 5423b122f..c3b9c792b 100644
--- a/src/pipecat/transports/base_output.py
+++ b/src/pipecat/transports/base_output.py
@@ -43,10 +43,22 @@ class BaseOutputTransport(FrameProcessor):
     def __init__(self, params: TransportParams, **kwargs):
-        super().__init__(sync=False, **kwargs)
+        super().__init__(**kwargs)
 
         self._params = params
 
+        # Task to process incoming frames so we don't block upstream elements.
+        self._sink_task = None
+
+        # Task to process incoming frames using a clock.
+        self._sink_clock_task = None
+
+        # Task to write/send audio frames.
+        self._audio_out_task = None
+
+        # Task to write/send image frames.
+        self._camera_out_task = None
+
         # These are the images that we should send to the camera at our desired
         # framerate.
         self._camera_images = None
@@ -88,36 +100,53 @@ async def stop(self, frame: EndFrame):
         # that EndFrame to be processed by the sink tasks. We also need to wait
         # for these tasks before cancelling the camera and audio tasks below
         # because they might be still rendering.
-        await self._sink_task
-        await self._sink_clock_task
+        if self._sink_task:
+            await self._sink_task
+        if self._sink_clock_task:
+            await self._sink_clock_task
 
         # Cancel and wait for the camera output task to finish.
-        if self._params.camera_out_enabled:
+        if self._camera_out_task and self._params.camera_out_enabled:
             self._camera_out_task.cancel()
             await self._camera_out_task
+            self._camera_out_task = None
 
         # Cancel and wait for the audio output task to finish.
-        if self._params.audio_out_enabled and self._params.audio_out_is_live:
+        if (
+            self._audio_out_task
+            and self._params.audio_out_enabled
+            and self._params.audio_out_is_live
+        ):
             self._audio_out_task.cancel()
             await self._audio_out_task
+            self._audio_out_task = None
 
     async def cancel(self, frame: CancelFrame):
         # Since we are cancelling everything it doesn't matter if we cancel sink
         # tasks first or not.
-        self._sink_task.cancel()
-        self._sink_clock_task.cancel()
-        await self._sink_task
-        await self._sink_clock_task
+        if self._sink_task:
+            self._sink_task.cancel()
+            await self._sink_task
+            self._sink_task = None
+
+        if self._sink_clock_task:
+            self._sink_clock_task.cancel()
+            await self._sink_clock_task
+            self._sink_clock_task = None
 
         # Cancel and wait for the camera output task to finish.
-        if self._params.camera_out_enabled:
+        if self._camera_out_task and self._params.camera_out_enabled:
             self._camera_out_task.cancel()
             await self._camera_out_task
+            self._camera_out_task = None
 
         # Cancel and wait for the audio output task to finish.
-        if self._params.audio_out_enabled and self._params.audio_out_is_live:
+        if self._audio_out_task and (
+            self._params.audio_out_enabled and self._params.audio_out_is_live
+        ):
             self._audio_out_task.cancel()
             await self._audio_out_task
+            self._audio_out_task = None
 
     async def send_message(self, frame: TransportMessageFrame):
         pass
 
@@ -183,11 +212,13 @@ async def _handle_interruptions(self, frame: Frame):
         if isinstance(frame, StartInterruptionFrame):
             # Stop sink tasks.
-            self._sink_task.cancel()
-            await self._sink_task
+            if self._sink_task:
+                self._sink_task.cancel()
+                await self._sink_task
             # Stop sink clock tasks.
-            self._sink_clock_task.cancel()
-            await self._sink_clock_task
+            if self._sink_clock_task:
+                self._sink_clock_task.cancel()
+                await self._sink_clock_task
             # Create sink tasks.
             self._create_sink_tasks()
             # Let's send a bot stopped speaking if we have to.
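The input and output transports above now guard every internal task handle: a task is only cancelled and awaited if it was actually created, and the handle is reset to `None` afterwards so `stop()` and `cancel()` can run safely even when a feature (audio, camera, sink clock) was never enabled. Here is a self-contained sketch of that shutdown pattern in plain asyncio; the `Worker` class is illustrative, and unlike it, the transports await their tasks directly because those tasks handle cancellation themselves.

```
import asyncio


class Worker:
    def __init__(self, enabled: bool):
        self._enabled = enabled
        self._task: asyncio.Task | None = None

    async def start(self):
        # The task only exists when the feature is enabled.
        if self._enabled:
            self._task = asyncio.create_task(self._run())

    async def stop(self):
        # Guard the handle: there may be nothing to cancel.
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None

    async def _run(self):
        while True:
            await asyncio.sleep(1)


async def main():
    worker = Worker(enabled=False)
    await worker.start()
    await worker.stop()  # Safe even though no task was ever created.
    await worker.stop()  # Safe to call twice because the handle was reset.


asyncio.run(main())
```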
diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py
index 48b59d8ff..50c2ae085 100644
--- a/src/pipecat/transports/services/daily.py
+++ b/src/pipecat/transports/services/daily.py
@@ -575,6 +575,9 @@ def __init__(self, client: DailyTransportClient, params: DailyParams, **kwargs):
         self._client = client
 
         self._video_renderers = {}
+
+        # Task that gets audio data from a device or the network and queues it
+        # internally to be processed.
         self._audio_in_task = None
 
         self._vad_analyzer: VADAnalyzer | None = params.vad_analyzer
@@ -603,6 +606,7 @@ async def stop(self, frame: EndFrame):
         if self._audio_in_task and (self._params.audio_in_enabled or self._params.vad_enabled):
             self._audio_in_task.cancel()
             await self._audio_in_task
+            self._audio_in_task = None
 
     async def cancel(self, frame: CancelFrame):
         # Parent stop.
@@ -613,6 +617,7 @@ async def cancel(self, frame: CancelFrame):
         if self._audio_in_task and (self._params.audio_in_enabled or self._params.vad_enabled):
             self._audio_in_task.cancel()
             await self._audio_in_task
+            self._audio_in_task = None
 
     async def cleanup(self):
         await super().cleanup()
diff --git a/src/pipecat/transports/services/livekit.py b/src/pipecat/transports/services/livekit.py
index ef4a6bc0a..658994112 100644
--- a/src/pipecat/transports/services/livekit.py
+++ b/src/pipecat/transports/services/livekit.py
@@ -1,10 +1,17 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
 import asyncio
+
 from dataclasses import dataclass
 from typing import Any, Awaitable, Callable, List
 
-import numpy as np
-from loguru import logger
 from pydantic import BaseModel
+
+import numpy as np
 from scipy import signal
 
 from pipecat.frames.frames import (
@@ -28,6 +35,8 @@
 from pipecat.transports.base_transport import BaseTransport, TransportParams
 from pipecat.vad.vad_analyzer import VADAnalyzer
 
+from loguru import logger
+
 try:
     from livekit import rtc
     from tenacity import retry, stop_after_attempt, wait_exponential
diff --git a/src/pipecat/utils/string.py b/src/pipecat/utils/string.py
index a47db6c5c..936764345 100644
--- a/src/pipecat/utils/string.py
+++ b/src/pipecat/utils/string.py
@@ -6,7 +6,6 @@
 
 import re
 
-
 ENDOFSENTENCE_PATTERN_STR = r"""
     (?
-def match_endofsentence(text: str) -> bool:
-    return ENDOFSENTENCE_PATTERN.search(text.rstrip()) is not None
+def match_endofsentence(text: str) -> int:
+    match = ENDOFSENTENCE_PATTERN.search(text.rstrip())
+    return match.end() if match else 0
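`match_endofsentence()` now returns the index just past the matched sentence-ending punctuation (or `0` when there is no complete sentence) instead of a boolean, so callers can split accumulated text at the exact boundary. A small sketch of how an aggregation loop might use the returned index; the token stream is illustrative:

```
from pipecat.utils.string import match_endofsentence

buffer = ""
for token in ["Hello", " there.", " How are you?"]:
    buffer += token
    end = match_endofsentence(buffer)
    if end:  # 0 means no complete sentence yet
        sentence, buffer = buffer[:end], buffer[end:]
        print(f"complete sentence: {sentence!r}")
```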
diff --git a/test-requirements.txt b/test-requirements.txt
index 78280b139..07ef45054 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -1,11 +1,13 @@
 aiohttp~=3.10.3
 anthropic~=0.30.0
 azure-cognitiveservices-speech~=1.40.0
-daily-python~=0.10.1
+boto3~=1.35.27
+daily-python~=0.11.0
 deepgram-sdk~=3.5.0
 fal-client~=0.4.1
 fastapi~=0.112.1
 faster-whisper~=1.0.3
+google-cloud-texttospeech~=2.17.2
 google-generativeai~=0.7.2
 langchain~=0.2.14
 livekit~=0.13.1
diff --git a/tests/test_ai_services.py b/tests/test_ai_services.py
index c52b0cb56..975f7e20c 100644
--- a/tests/test_ai_services.py
+++ b/tests/test_ai_services.py
@@ -29,6 +29,7 @@ async def test_endofsentence(self):
         assert match_endofsentence("This is a sentence! ")
         assert match_endofsentence("This is a sentence?")
         assert match_endofsentence("This is a sentence:")
+        assert match_endofsentence("This is a sentence;")
         assert not match_endofsentence("This is not a sentence")
         assert not match_endofsentence("This is not a sentence,")
         assert not match_endofsentence("This is not a sentence, ")
@@ -40,6 +41,18 @@ async def test_endofsentence(self):
         assert not match_endofsentence("America, or the U.")  # U.S.A.
         assert not match_endofsentence("It still early, it's 3:00 a.")  # 3:00 a.m.
 
+    async def test_endofsentence_zh(self):
+        chinese_sentences = [
+            "你好。",
+            "你好!",
+            "吃了吗?",
+            "安全第一;",
+            "他说:",
+        ]
+        for i in chinese_sentences:
+            assert match_endofsentence(i)
+        assert not match_endofsentence("你好,")
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_langchain.py b/tests/test_langchain.py
index fb222205b..d30d213bd 100644
--- a/tests/test_langchain.py
+++ b/tests/test_langchain.py
@@ -7,9 +7,9 @@
 import unittest
 
 from pipecat.frames.frames import (
+    EndFrame,
     LLMFullResponseEndFrame,
     LLMFullResponseStartFrame,
-    StopTaskFrame,
     TextFrame,
     TranscriptionFrame,
     UserStartedSpeakingFrame,
@@ -32,6 +32,7 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
     class MockProcessor(FrameProcessor):
         def __init__(self, name):
+            super().__init__()
             self.name = name
             self.token: list[str] = []
             # Start collecting tokens when we see the start frame
@@ -55,13 +56,13 @@ async def process_frame(self, frame, direction):
     def setUp(self):
         self.expected_response = "Hello dear human"
         self.fake_llm = FakeStreamingListLLM(responses=[self.expected_response])
-        self.mock_proc = self.MockProcessor("token_collector")
 
     async def test_langchain(self):
         messages = [("system", "Say hello to {name}"), ("human", "{input}")]
         prompt = ChatPromptTemplate.from_messages(messages).partial(name="Thomas")
         chain = prompt | self.fake_llm
         proc = LangchainProcessor(chain=chain)
+        self.mock_proc = self.MockProcessor("token_collector")
 
         tma_in = LLMUserResponseAggregator(messages)
         tma_out = LLMAssistantResponseAggregator(messages)
@@ -81,7 +82,7 @@ async def test_langchain(self):
                 UserStartedSpeakingFrame(),
                 TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
                 UserStoppedSpeakingFrame(),
-                StopTaskFrame(),
+                EndFrame(),
             ]
         )