Commit

Merge pull request #304 from pipecat-ai/khk/cartesia-continue
Cartesia streaming (WebSocket) and word-level timestamps support
kwindla authored Jul 18, 2024
2 parents 5e8e11e + 355fe01 commit d1b62c5
Showing 4 changed files with 200 additions and 34 deletions.
6 changes: 4 additions & 2 deletions examples/foundational/07d-interruptible-cartesia.py
@@ -37,6 +37,7 @@ async def main(room_url: str, token):
token,
"Respond bot",
DailyParams(
audio_out_sample_rate=44100,
audio_out_enabled=True,
transcription_enabled=True,
vad_enabled=True,
@@ -47,6 +48,7 @@
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="a0e99841-438c-4a64-b679-ae501e7d6091", # Barbershop Man
sample_rate=44100,
)

llm = OpenAILLMService(
@@ -68,11 +70,11 @@ async def main(room_url: str, token):
tma_in, # User responses
llm, # LLM
tts, # TTS
tma_out, # Goes before the transport because Cartesia has word-level timestamps!
transport.output(), # Transport bot output
tma_out # Assistant spoken responses
])

task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True, enable_metrics=True))

@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
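Taken together, the example's changes pin the output sample rate to 44100 Hz on both the Daily transport and the Cartesia service, move the assistant aggregator (tma_out) ahead of the transport output so that Cartesia's word-level timestamps drive when assistant text lands in the context, and enable metrics. Below is a condensed sketch of that wiring; the hypothetical build_task helper, the import paths, and the transport.input()/aggregator setup are assumed from the rest of the example file, which is not shown in this hunk.

import os

from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport


def build_task(room_url: str, token: str, tma_in, llm, tma_out) -> PipelineTask:
    transport = DailyTransport(
        room_url,
        token,
        "Respond bot",
        DailyParams(
            audio_out_sample_rate=44100,  # must match the TTS sample_rate below
            audio_out_enabled=True,
            transcription_enabled=True,
            vad_enabled=True,
        ),
    )

    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="a0e99841-438c-4a64-b679-ae501e7d6091",  # Barbershop Man
        sample_rate=44100,
    )

    pipeline = Pipeline([
        transport.input(),   # Transport user input
        tma_in,              # User responses
        llm,                 # LLM
        tts,                 # TTS
        tma_out,             # Before the transport: Cartesia emits word-level timestamps
        transport.output(),  # Transport bot output
    ])

    return PipelineTask(pipeline, PipelineParams(allow_interruptions=True, enable_metrics=True))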
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -36,7 +36,7 @@ Website = "https://pipecat.ai"
[project.optional-dependencies]
anthropic = [ "anthropic~=0.28.1" ]
azure = [ "azure-cognitiveservices-speech~=1.38.0" ]
cartesia = [ "cartesia~=1.0.3" ]
cartesia = [ "websockets~=12.0" ]
daily = [ "daily-python~=0.10.1" ]
deepgram = [ "deepgram-sdk~=3.2.7" ]
examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ]
29 changes: 22 additions & 7 deletions src/pipecat/services/ai_services.py
@@ -136,9 +136,16 @@ async def call_start_function(self, function_name: str):


class TTSService(AIService):
def __init__(self, *, aggregate_sentences: bool = True, **kwargs):
def __init__(
self,
*,
aggregate_sentences: bool = True,
# if True, the base class pushes TextFrames and LLMFullResponseEndFrames;
# if False, the subclass is responsible for pushing them
push_text_frames: bool = True,
**kwargs):
super().__init__(**kwargs)
self._aggregate_sentences: bool = aggregate_sentences
self._push_text_frames: bool = push_text_frames
self._current_sentence: str = ""

# Converts the text to audio.
@@ -149,6 +156,10 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
async def say(self, text: str):
await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM)

async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
self._current_sentence = ""
await self.push_frame(frame, direction)

async def _process_text_frame(self, frame: TextFrame):
text: str | None = None
if not self._aggregate_sentences:
@@ -172,22 +183,26 @@ async def _push_tts_frames(self, text: str):
await self.process_generator(self.run_tts(text))
await self.stop_processing_metrics()
await self.push_frame(TTSStoppedFrame())
# We send the original text after the audio. This way, if we are
# interrupted, the text is not added to the assistant context.
await self.push_frame(TextFrame(text))
if self._push_text_frames:
# We send the original text after the audio. This way, if we are
# interrupted, the text is not added to the assistant context.
await self.push_frame(TextFrame(text))

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, TextFrame):
await self._process_text_frame(frame)
elif isinstance(frame, StartInterruptionFrame):
self._current_sentence = ""
await self.push_frame(frame, direction)
await self._handle_interruption(frame, direction)
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame):
self._current_sentence = ""
await self._push_tts_frames(self._current_sentence)
await self.push_frame(frame)
if isinstance(frame, LLMFullResponseEndFrame):
if self._push_text_frames:
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)

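The new push_text_frames flag changes the contract between TTSService and its subclasses: left at True, the base class pushes the original TextFrame after the audio and passes the LLMFullResponseEndFrame through; set to False, the subclass must emit both itself. The sketch below is a hypothetical subclass, not part of this commit, and its word-by-word backend (_synthesize_word_by_word) is invented purely to illustrate the flow.

from typing import AsyncGenerator

from pipecat.frames.frames import AudioRawFrame, Frame, TextFrame
from pipecat.services.ai_services import TTSService


class WordAlignedTTSService(TTSService):
    def __init__(self, **kwargs):
        # Opt out of the base class pushing TextFrames after the audio; this
        # subclass pushes them itself, interleaved with the audio it generates.
        super().__init__(push_text_frames=False, **kwargs)

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        # A backend returning (word, audio_chunk) pairs is assumed here. Pushing
        # each word right after its audio keeps the assistant context aligned
        # with what has actually been spoken, which matters when the bot is
        # interrupted mid-sentence. A real subclass would also push an
        # LLMFullResponseEndFrame once the full response has played out.
        for word, audio in self._synthesize_word_by_word(text):
            yield AudioRawFrame(audio=audio, sample_rate=16000, num_channels=1)
            yield TextFrame(word)

    def _synthesize_word_by_word(self, text: str):
        # Placeholder "synthesis": 10ms of silence per word, for illustration only.
        for word in text.split():
            yield word, b"\x00\x00" * 160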
197 changes: 173 additions & 24 deletions src/pipecat/services/cartesia.py
@@ -4,15 +4,37 @@
# SPDX-License-Identifier: BSD 2-Clause License
#

from cartesia import AsyncCartesia
import json
import uuid
import base64
import asyncio
import time

from typing import AsyncGenerator

from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, Frame, StartFrame
from pipecat.processors.frame_processor import FrameDirection
from pipecat.frames.frames import (
Frame,
AudioRawFrame,
StartInterruptionFrame,
StartFrame,
EndFrame,
TextFrame,
LLMFullResponseEndFrame
)
from pipecat.services.ai_services import TTSService

from loguru import logger

# See .env.example for Cartesia configuration needed
try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`. Also, set `CARTESIA_API_KEY` environment variable.")
raise Exception(f"Missing module: {e}")


class CartesiaTTSService(TTSService):

@@ -21,56 +21,183 @@ def __init__(
*,
api_key: str,
voice_id: str,
cartesia_version: str = "2024-06-10",
url: str = "wss://api.cartesia.ai/tts/websocket",
model_id: str = "sonic-english",
encoding: str = "pcm_s16le",
sample_rate: int = 16000,
language: str = "en",
**kwargs):
super().__init__(**kwargs)

# Aggregating sentences still gives cleaner-sounding results and fewer
# artifacts than streaming one word at a time. On average, waiting for
# a full sentence should only "cost" us 15ms or so with GPT-4o or a Llama 3
# model, and it's worth it for the better audio quality.
self._aggregate_sentences = True

# We don't want to automatically push LLM response text frames, because the
# context aggregators would add them to the LLM context even if we're
# interrupted. Cartesia gives us word-by-word timestamps, so we can generate
# text frames ourselves, aligned with the playout timing of the audio!
self._push_text_frames = False

self._api_key = api_key
self._cartesia_version = cartesia_version
self._url = url
self._voice_id = voice_id
self._model_id = model_id
self._output_format = {
"container": "raw",
"encoding": encoding,
"sample_rate": sample_rate,
}
self._client = None
self._language = language

self._websocket = None
self._context_id = None
self._context_id_start_timestamp = None
self._timestamped_words_buffer = []
self._receive_task = None
self._context_appending_task = None
self._waiting_for_ttfb = False

def can_generate_metrics(self) -> bool:
return True

async def start(self, frame: StartFrame):
await super().start(frame)
await self._connect()

async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()

async def _connect(self):
try:
self._client = AsyncCartesia(api_key=self._api_key)
self._voice = self._client.voices.get(id=self._voice_id)
self._websocket = await websockets.connect(
f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
self._context_appending_task = self.get_event_loop().create_task(self._context_appending_task_handler())
except Exception as e:
logger.exception(f"{self} initialization error: {e}")
self._websocket = None

async def stop(self, frame: EndFrame):
if self._client:
await self._client.close()
async def _disconnect(self):
try:
if self._context_appending_task:
self._context_appending_task.cancel()
self._context_appending_task = None
if self._receive_task:
self._receive_task.cancel()
self._receive_task = None
if self._websocket:
ws = self._websocket
self._websocket = None
await ws.close()
self._context_id = None
self._context_id_start_timestamp = None
self._timestamped_words_buffer = []
self._waiting_for_ttfb = False
await self.stop_all_metrics()
except Exception as e:
logger.exception(f"{self} error closing websocket: {e}")

async def cancel(self, frame: CancelFrame):
if self._client:
await self._client.close()
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
self._context_id = None
self._context_id_start_timestamp = None
self._timestamped_words_buffer = []
await self.stop_all_metrics()
await self.push_frame(LLMFullResponseEndFrame())

async def _receive_task_handler(self):
try:
async for message in self._websocket:
msg = json.loads(message)
# logger.debug(f"Received message: {msg['type']} {msg['context_id']}")
if not msg or msg["context_id"] != self._context_id:
continue
if msg["type"] == "done":
# Unset _context_id, but not _context_id_start_timestamp, because we are
# likely still playing out audio and need the timestamp to send context frames
self._context_id = None
self._timestamped_words_buffer.append(("LLMFullResponseEndFrame", 0))
elif msg["type"] == "timestamps":
# logger.debug(f"TIMESTAMPS: {msg}")
self._timestamped_words_buffer.extend(
list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["end"]))
)
elif msg["type"] == "chunk":
if not self._context_id_start_timestamp:
self._context_id_start_timestamp = time.time()
if self._waiting_for_ttfb:
await self.stop_ttfb_metrics()
self._waiting_for_ttfb = False
frame = AudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self._output_format["sample_rate"],
num_channels=1
)
await self.push_frame(frame)
except Exception as e:
logger.exception(f"{self} exception: {e}")

async def _context_appending_task_handler(self):
try:
while True:
await asyncio.sleep(0.1)
if not self._context_id_start_timestamp:
continue
elapsed_seconds = time.time() - self._context_id_start_timestamp
# Pop all words from self._timestamped_words_buffer that are older than the
# elapsed time and push them downstream as TextFrames
while self._timestamped_words_buffer and self._timestamped_words_buffer[0][1] <= elapsed_seconds:
word, timestamp = self._timestamped_words_buffer.pop(0)
if word == "LLMFullResponseEndFrame" and timestamp == 0:
await self.push_frame(LLMFullResponseEndFrame())
continue
# print(f"Word '{word}' with timestamp {timestamp:.2f}s has been spoken.")
await self.push_frame(TextFrame(word))
except Exception as e:
logger.exception(f"{self} exception: {e}")

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")

try:
await self.start_ttfb_metrics()

chunk_generator = await self._client.tts.sse(
stream=True,
transcript=text,
voice_embedding=self._voice["embedding"],
model_id=self._model_id,
output_format=self._output_format,
)

async for chunk in chunk_generator:
await self.stop_ttfb_metrics()
yield AudioRawFrame(chunk["audio"], self._output_format["sample_rate"], 1)
if not self._websocket:
await self._connect()

if not self._waiting_for_ttfb:
await self.start_ttfb_metrics()
self._waiting_for_ttfb = True

if not self._context_id:
self._context_id = str(uuid.uuid4())

msg = {
"transcript": text + " ",
"continue": True,
"context_id": self._context_id,
"model_id": self._model_id,
"voice": {
"mode": "id",
"id": self._voice_id
},
"output_format": self._output_format,
"language": self._language,
"add_timestamps": True,
}
# logger.debug(f"SENDING MESSAGE {json.dumps(msg)}")
try:
await self._websocket.send(json.dumps(msg))
except Exception as e:
logger.exception(f"{self} error sending message: {e}")
await self._disconnect()
await self._connect()
return
yield None
except Exception as e:
logger.exception(f"{self} exception: {e}")
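The core of the word-level timestamp support is _context_appending_task_handler: Cartesia's "timestamps" messages carry per-word end times relative to the start of audio playout, and the task replays them against wall-clock time so each TextFrame is pushed only once its audio has (roughly) been spoken. The following is a standalone, simplified illustration of that replay loop, not the service code itself; the replay_words helper and the example timestamps are invented.

import asyncio
import time


async def replay_words(timestamped_words: list[tuple[str, float]]):
    # Each entry is (word, end_time_in_seconds) relative to playout start,
    # mirroring what Cartesia's "timestamps" messages provide.
    start = time.time()
    buffer = list(timestamped_words)
    while buffer:
        await asyncio.sleep(0.1)
        elapsed = time.time() - start
        # Release every word whose end timestamp has already elapsed; the real
        # service pushes a TextFrame here instead of printing.
        while buffer and buffer[0][1] <= elapsed:
            word, end_time = buffer.pop(0)
            print(f"{elapsed:5.2f}s  spoken: {word!r} (ended at {end_time:.2f}s)")


if __name__ == "__main__":
    # Invented timestamps for a short utterance, purely for illustration.
    asyncio.run(replay_words([("Hello", 0.35), ("there,", 0.70), ("world!", 1.20)]))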
