diff --git a/examples/foundational/07d-interruptible-cartesia.py b/examples/foundational/07d-interruptible-cartesia.py
index 610fdb5b8..df52294b1 100644
--- a/examples/foundational/07d-interruptible-cartesia.py
+++ b/examples/foundational/07d-interruptible-cartesia.py
@@ -37,6 +37,7 @@ async def main(room_url: str, token):
         token,
         "Respond bot",
         DailyParams(
+            audio_out_sample_rate=44100,
             audio_out_enabled=True,
             transcription_enabled=True,
             vad_enabled=True,
@@ -47,6 +48,7 @@ async def main(room_url: str, token):
     tts = CartesiaTTSService(
         api_key=os.getenv("CARTESIA_API_KEY"),
         voice_id="a0e99841-438c-4a64-b679-ae501e7d6091",  # Barbershop Man
+        sample_rate=44100,
     )

     llm = OpenAILLMService(
@@ -68,11 +70,11 @@ async def main(room_url: str, token):
         tma_in,              # User responses
         llm,                 # LLM
         tts,                 # TTS
+        tma_out,             # Goes before the transport because cartesia has word-level timestamps!
         transport.output(),  # Transport bot output
-        tma_out              # Assistant spoken responses
     ])

-    task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
+    task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True, enable_metrics=True))

     @transport.event_handler("on_first_participant_joined")
     async def on_first_participant_joined(transport, participant):
diff --git a/pyproject.toml b/pyproject.toml
index bf4693786..a54b5539d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@ Website = "https://pipecat.ai"
 [project.optional-dependencies]
 anthropic = [ "anthropic~=0.28.1" ]
 azure = [ "azure-cognitiveservices-speech~=1.38.0" ]
-cartesia = [ "cartesia~=1.0.3" ]
+cartesia = [ "websockets~=12.0" ]
 daily = [ "daily-python~=0.10.1" ]
 deepgram = [ "deepgram-sdk~=3.2.7" ]
 examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ]
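The two 44100 values above need to stay in sync: CartesiaTTSService emits raw PCM at the `sample_rate` it is constructed with, and the Daily transport plays bot audio at `audio_out_sample_rate`, which is presumably why the example sets both to 44100. A minimal configuration sketch, using only values that appear in the example above:

```python
import os

from pipecat.services.cartesia import CartesiaTTSService
from pipecat.transports.services.daily import DailyParams

# Keep the TTS output rate and the transport output rate equal; with mismatched
# rates the PCM would play back at the wrong speed and pitch.
params = DailyParams(
    audio_out_sample_rate=44100,
    audio_out_enabled=True,
)
tts = CartesiaTTSService(
    api_key=os.getenv("CARTESIA_API_KEY"),
    voice_id="a0e99841-438c-4a64-b679-ae501e7d6091",  # Barbershop Man
    sample_rate=44100,
)
```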
diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py
index 46bba673e..8647d3f77 100644
--- a/src/pipecat/services/ai_services.py
+++ b/src/pipecat/services/ai_services.py
@@ -136,9 +136,16 @@ async def call_start_function(self, function_name: str):


 class TTSService(AIService):
-    def __init__(self, *, aggregate_sentences: bool = True, **kwargs):
+    def __init__(
+            self,
+            *,
+            aggregate_sentences: bool = True,
+            # if False, the subclass is responsible for pushing TextFrames and LLMFullResponseEndFrames
+            push_text_frames: bool = True,
+            **kwargs):
         super().__init__(**kwargs)
         self._aggregate_sentences: bool = aggregate_sentences
+        self._push_text_frames: bool = push_text_frames
         self._current_sentence: str = ""

     # Converts the text to audio.
@@ -149,6 +156,10 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
     async def say(self, text: str):
         await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM)

+    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
+        self._current_sentence = ""
+        await self.push_frame(frame, direction)
+
     async def _process_text_frame(self, frame: TextFrame):
         text: str | None = None
         if not self._aggregate_sentences:
@@ -172,9 +183,10 @@ async def _push_tts_frames(self, text: str):
         await self.process_generator(self.run_tts(text))
         await self.stop_processing_metrics()
         await self.push_frame(TTSStoppedFrame())
-        # We send the original text after the audio. This way, if we are
-        # interrupted, the text is not added to the assistant context.
-        await self.push_frame(TextFrame(text))
+        if self._push_text_frames:
+            # We send the original text after the audio. This way, if we are
+            # interrupted, the text is not added to the assistant context.
+            await self.push_frame(TextFrame(text))

     async def process_frame(self, frame: Frame, direction: FrameDirection):
         await super().process_frame(frame, direction)
@@ -182,12 +194,15 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
         if isinstance(frame, TextFrame):
             await self._process_text_frame(frame)
         elif isinstance(frame, StartInterruptionFrame):
-            self._current_sentence = ""
-            await self.push_frame(frame, direction)
+            await self._handle_interruption(frame, direction)
         elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame):
             await self._push_tts_frames(self._current_sentence)
             self._current_sentence = ""
-            await self.push_frame(frame)
+            if isinstance(frame, LLMFullResponseEndFrame):
+                if self._push_text_frames:
+                    await self.push_frame(frame, direction)
+            else:
+                await self.push_frame(frame, direction)
         else:
             await self.push_frame(frame, direction)
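With the new flag, a subclass that sets `push_text_frames=False` takes over responsibility for emitting TextFrames (and the LLMFullResponseEndFrame) itself, which is exactly what the Cartesia service below does using word-level timestamps. A minimal sketch of that contract (the class name and body are hypothetical, not part of this patch):

```python
from typing import AsyncGenerator

from pipecat.frames.frames import Frame, TextFrame
from pipecat.services.ai_services import TTSService


class WordByWordTTSService(TTSService):
    """Hypothetical subclass illustrating push_text_frames=False."""

    def __init__(self, **kwargs):
        # TTSService will no longer re-emit the original TextFrame (or push the
        # LLMFullResponseEndFrame) after the audio; this class must do it instead.
        super().__init__(push_text_frames=False, **kwargs)

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        # ... yield AudioRawFrames for `text` here ...
        # Then emit the text ourselves, e.g. word by word (CartesiaTTSService instead
        # pushes words from a timestamp buffer so they track audio playout).
        for word in text.split():
            yield TextFrame(word)
```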
diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py
index c3b6b905b..be0d80f0b 100644
--- a/src/pipecat/services/cartesia.py
+++ b/src/pipecat/services/cartesia.py
@@ -4,15 +4,37 @@
 # SPDX-License-Identifier: BSD 2-Clause License
 #

-from cartesia import AsyncCartesia
+import json
+import uuid
+import base64
+import asyncio
+import time

 from typing import AsyncGenerator

-from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, Frame, StartFrame
+from pipecat.processors.frame_processor import FrameDirection
+from pipecat.frames.frames import (
+    Frame,
+    AudioRawFrame,
+    StartInterruptionFrame,
+    StartFrame,
+    EndFrame,
+    TextFrame,
+    LLMFullResponseEndFrame
+)
 from pipecat.services.ai_services import TTSService

 from loguru import logger

+# See .env.example for Cartesia configuration needed
+try:
+    import websockets
+except ModuleNotFoundError as e:
+    logger.error(f"Exception: {e}")
+    logger.error(
+        "In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`. Also, set `CARTESIA_API_KEY` environment variable.")
+    raise Exception(f"Missing module: {e}")
+

 class CartesiaTTSService(TTSService):
@@ -21,13 +43,30 @@ def __init__(
             *,
             api_key: str,
             voice_id: str,
+            cartesia_version: str = "2024-06-10",
+            url: str = "wss://api.cartesia.ai/tts/websocket",
             model_id: str = "sonic-english",
             encoding: str = "pcm_s16le",
             sample_rate: int = 16000,
+            language: str = "en",
             **kwargs):
         super().__init__(**kwargs)

+        # Aggregating sentences still gives cleaner-sounding results and fewer
+        # artifacts than streaming one word at a time. On average, waiting for
+        # a full sentence should only "cost" us 15ms or so with GPT-4o or a Llama 3
+        # model, and it's worth it for the better audio quality.
+        self._aggregate_sentences = True
+
+        # We don't want to automatically push LLM response text frames, because the
+        # context aggregators will add them to the LLM context even if we're
+        # interrupted. Cartesia gives us word-by-word timestamps, and we can use those
+        # to generate text frames ourselves, aligned with the playout timing of the audio!
+        self._push_text_frames = False
+
         self._api_key = api_key
+        self._cartesia_version = cartesia_version
+        self._url = url
         self._voice_id = voice_id
         self._model_id = model_id
         self._output_format = {
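For orientation before the implementation below: every aggregated sentence is written to the websocket as one JSON message, and consecutive sentences of the same LLM response reuse a single `context_id` with `"continue": True` so Cartesia treats them as one continuous utterance. Roughly, with illustrative values (the exact fields are in `run_tts` below):

```python
# Shape of the per-sentence request sent over the websocket (values illustrative).
request = {
    "transcript": "Here is one aggregated sentence. ",
    "continue": True,                    # more sentences may follow in this context
    "context_id": "example-context-id",  # shared across sentences of one response
    "model_id": "sonic-english",
    "voice": {"mode": "id", "id": "a0e99841-438c-4a64-b679-ae501e7d6091"},
    "output_format": {"container": "raw", "encoding": "pcm_s16le", "sample_rate": 44100},
    "language": "en",
    "add_timestamps": True,              # ask for word-level timestamps with the audio
}
```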
@@ -35,42 +74,152 @@ def __init__(
             "encoding": encoding,
             "sample_rate": sample_rate,
         }
-        self._client = None
+        self._language = language
+
+        self._websocket = None
+        self._context_id = None
+        self._context_id_start_timestamp = None
+        self._timestamped_words_buffer = []
+        self._receive_task = None
+        self._context_appending_task = None
+        self._waiting_for_ttfb = False

     def can_generate_metrics(self) -> bool:
         return True

     async def start(self, frame: StartFrame):
+        await super().start(frame)
+        await self._connect()
+
+    async def stop(self, frame: EndFrame):
+        await super().stop(frame)
+        await self._disconnect()
+
+    async def _connect(self):
         try:
-            self._client = AsyncCartesia(api_key=self._api_key)
-            self._voice = self._client.voices.get(id=self._voice_id)
+            self._websocket = await websockets.connect(
+                f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
+            )
+            self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
+            self._context_appending_task = self.get_event_loop().create_task(self._context_appending_task_handler())
         except Exception as e:
             logger.exception(f"{self} initialization error: {e}")
+            self._websocket = None

-    async def stop(self, frame: EndFrame):
-        if self._client:
-            await self._client.close()
+    async def _disconnect(self):
+        try:
+            if self._context_appending_task:
+                self._context_appending_task.cancel()
+                self._context_appending_task = None
+            if self._receive_task:
+                self._receive_task.cancel()
+                self._receive_task = None
+            if self._websocket:
+                ws = self._websocket
+                self._websocket = None
+                await ws.close()
+            self._context_id = None
+            self._context_id_start_timestamp = None
+            self._timestamped_words_buffer = []
+            self._waiting_for_ttfb = False
+            await self.stop_all_metrics()
+        except Exception as e:
+            logger.exception(f"{self} error closing websocket: {e}")

-    async def cancel(self, frame: CancelFrame):
-        if self._client:
-            await self._client.close()
+    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
+        await super()._handle_interruption(frame, direction)
+        self._context_id = None
+        self._context_id_start_timestamp = None
+        self._timestamped_words_buffer = []
+        await self.stop_all_metrics()
+        await self.push_frame(LLMFullResponseEndFrame())
+
+    async def _receive_task_handler(self):
+        try:
+            async for message in self._websocket:
+                msg = json.loads(message)
+                # logger.debug(f"Received message: {msg['type']} {msg['context_id']}")
+                if not msg or msg["context_id"] != self._context_id:
+                    continue
+                if msg["type"] == "done":
+                    # Unset _context_id, but keep _context_id_start_timestamp: we are likely
+                    # still playing out audio and need the timestamp to send context frames.
+                    self._context_id = None
+                    self._timestamped_words_buffer.append(("LLMFullResponseEndFrame", 0))
+                elif msg["type"] == "timestamps":
+                    # logger.debug(f"TIMESTAMPS: {msg}")
+                    self._timestamped_words_buffer.extend(
+                        list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["end"]))
+                    )
+                elif msg["type"] == "chunk":
+                    if not self._context_id_start_timestamp:
+                        self._context_id_start_timestamp = time.time()
+                    if self._waiting_for_ttfb:
+                        await self.stop_ttfb_metrics()
+                        self._waiting_for_ttfb = False
+                    frame = AudioRawFrame(
+                        audio=base64.b64decode(msg["data"]),
+                        sample_rate=self._output_format["sample_rate"],
+                        num_channels=1
+                    )
+                    await self.push_frame(frame)
+        except Exception as e:
+            logger.exception(f"{self} exception: {e}")
+
+    async def _context_appending_task_handler(self):
+        try:
+            while True:
+                await asyncio.sleep(0.1)
+                if not self._context_id_start_timestamp:
+                    continue
+                elapsed_seconds = time.time() - self._context_id_start_timestamp
+                # Pop all words from self._timestamped_words_buffer that are older than the
+                # elapsed time and push them downstream as TextFrames.
+                while self._timestamped_words_buffer and self._timestamped_words_buffer[0][1] <= elapsed_seconds:
+                    word, timestamp = self._timestamped_words_buffer.pop(0)
+                    if word == "LLMFullResponseEndFrame" and timestamp == 0:
+                        await self.push_frame(LLMFullResponseEndFrame())
+                        continue
+                    # print(f"Word '{word}' with timestamp {timestamp:.2f}s has been spoken.")
+                    await self.push_frame(TextFrame(word))
+        except Exception as e:
+            logger.exception(f"{self} exception: {e}")

     async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
         logger.debug(f"Generating TTS: [{text}]")

         try:
-            await self.start_ttfb_metrics()
-
-            chunk_generator = await self._client.tts.sse(
-                stream=True,
-                transcript=text,
-                voice_embedding=self._voice["embedding"],
-                model_id=self._model_id,
-                output_format=self._output_format,
-            )
-
-            async for chunk in chunk_generator:
-                await self.stop_ttfb_metrics()
-                yield AudioRawFrame(chunk["audio"], self._output_format["sample_rate"], 1)
+            if not self._websocket:
+                await self._connect()
+
+            if not self._waiting_for_ttfb:
+                await self.start_ttfb_metrics()
+                self._waiting_for_ttfb = True
+
+            if not self._context_id:
+                self._context_id = str(uuid.uuid4())
+
+            msg = {
+                "transcript": text + " ",
+                "continue": True,
+                "context_id": self._context_id,
+                "model_id": self._model_id,
+                "voice": {
+                    "mode": "id",
+                    "id": self._voice_id
+                },
+                "output_format": self._output_format,
+                "language": self._language,
+                "add_timestamps": True,
+            }
+            # logger.debug(f"SENDING MESSAGE {json.dumps(msg)}")
+            try:
+                await self._websocket.send(json.dumps(msg))
+            except Exception as e:
+                logger.exception(f"{self} error sending message: {e}")
+                await self._disconnect()
+                await self._connect()
+                return
+            yield None
         except Exception as e:
             logger.exception(f"{self} exception: {e}")
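The word-by-word alignment above boils down to one idea: buffer (word, end-timestamp) pairs as they arrive, and only release each word downstream once that much audio has had time to play out since the first chunk. A self-contained sketch of just that pacing loop (names and timings are illustrative, not part of the patch):

```python
import asyncio
import time


async def pace_words(timestamped_words, start_timestamp):
    """Release each (word, end_seconds) pair once that much audio has played out."""
    buffer = list(timestamped_words)
    while buffer:
        await asyncio.sleep(0.1)
        elapsed = time.time() - start_timestamp
        while buffer and buffer[0][1] <= elapsed:
            word, _ = buffer.pop(0)
            # In CartesiaTTSService this is `await self.push_frame(TextFrame(word))`,
            # which keeps the assistant context aligned with what was actually spoken.
            print(word)


asyncio.run(pace_words([("Hello", 0.35), ("there,", 0.70), ("world!", 1.10)], time.time()))
```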