From 653fbb7e3ecd7607d137d0d4974bd5413511f220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 16 Dec 2024 14:55:33 -0800 Subject: [PATCH] services: fix infinite websocket-based TTS services retries Fixes #871 --- CHANGELOG.md | 7 +++ pyproject.toml | 3 +- src/pipecat/services/cartesia.py | 94 +++++++++++++++++------------- src/pipecat/services/elevenlabs.py | 50 +++++++++++----- src/pipecat/services/lmnt.py | 55 +++++++++++------ src/pipecat/services/playht.py | 63 ++++++++++++-------- 6 files changed, 173 insertions(+), 99 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f4691396..8b4482c1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to **Pipecat** will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.0.51] - 2024-12-16 + +### Fixed + +- Fixed an issue in websocket-based TTS services that was causing infinite + reconnections (Cartesia, ElevenLabs, PlayHT and LMNT). 
+ ## [0.0.50] - 2024-12-11 ### Added diff --git a/pyproject.toml b/pyproject.toml index 695f47ed1..acb682b8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "pydantic~=2.8.2", "pyloudnorm~=0.1.1", "resampy~=0.4.3", + "tenacity~=9.0.0" ] [project.urls] @@ -55,7 +56,7 @@ gstreamer = [ "pygobject~=3.48.2" ] fireworks = [ "openai~=1.50.2" ] krisp = [ "pipecat-ai-krisp~=0.3.0" ] langchain = [ "langchain~=0.2.14", "langchain-community~=0.2.12", "langchain-openai~=0.1.20" ] -livekit = [ "livekit~=0.17.5", "livekit-api~=0.7.1", "tenacity~=8.5.0" ] +livekit = [ "livekit~=0.17.5", "livekit-api~=0.7.1" ] lmnt = [ "lmnt~=1.1.4" ] local = [ "pyaudio~=0.2.14" ] moondream = [ "einops~=0.8.0", "timm~=1.0.8", "transformers~=4.44.0" ] diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index 8683fd29a..88fbf29d4 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -12,6 +12,8 @@ from loguru import logger from pydantic import BaseModel +from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential + from pipecat.frames.frames import ( BotStoppedSpeakingFrame, @@ -239,52 +241,64 @@ async def flush_audio(self): msg = self._build_msg(text="", continue_transcript=False) await self._websocket.send(msg) + async def _receive_messages(self): + async for message in self._get_websocket(): + msg = json.loads(message) + if not msg or msg["context_id"] != self._context_id: + continue + if msg["type"] == "done": + await self.stop_ttfb_metrics() + # Unset _context_id but not the _context_id_start_timestamp + # because we are likely still playing out audio and need the + # timestamp to set send context frames. 
+ self._context_id = None + await self.add_word_timestamps( + [("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)] + ) + elif msg["type"] == "timestamps": + await self.add_word_timestamps( + list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["start"])) + ) + elif msg["type"] == "chunk": + await self.stop_ttfb_metrics() + self.start_word_timestamps() + frame = TTSAudioRawFrame( + audio=base64.b64decode(msg["data"]), + sample_rate=self._settings["output_format"]["sample_rate"], + num_channels=1, + ) + await self.push_frame(frame) + elif msg["type"] == "error": + logger.error(f"{self} error: {msg}") + await self.push_frame(TTSStoppedFrame()) + await self.stop_all_metrics() + await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) + else: + logger.error(f"{self} error, unknown message type: {msg}") + + async def _reconnect_websocket(self, retry_state: RetryCallState): + logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})") + await self._disconnect_websocket() + await self._connect_websocket() + async def _receive_task_handler(self): while True: try: - async for message in self._get_websocket(): - msg = json.loads(message) - if not msg or msg["context_id"] != self._context_id: - continue - if msg["type"] == "done": - await self.stop_ttfb_metrics() - # Unset _context_id but not the _context_id_start_timestamp - # because we are likely still playing out audio and need the - # timestamp to set send context frames. 
- self._context_id = None - await self.add_word_timestamps( - [("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)] - ) - elif msg["type"] == "timestamps": - await self.add_word_timestamps( - list( - zip( - msg["word_timestamps"]["words"], msg["word_timestamps"]["start"] - ) - ) - ) - elif msg["type"] == "chunk": - await self.stop_ttfb_metrics() - self.start_word_timestamps() - frame = TTSAudioRawFrame( - audio=base64.b64decode(msg["data"]), - sample_rate=self._settings["output_format"]["sample_rate"], - num_channels=1, - ) - await self.push_frame(frame) - elif msg["type"] == "error": - logger.error(f"{self} error: {msg}") - await self.push_frame(TTSStoppedFrame()) - await self.stop_all_metrics() - await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) - else: - logger.error(f"{self} error, unknown message type: {msg}") + async for attempt in AsyncRetrying( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + before_sleep=self._reconnect_websocket, + reraise=True, + ): + with attempt: + await self._receive_messages() except asyncio.CancelledError: break except Exception as e: - logger.error(f"{self} exception: {e}") - await self._disconnect_websocket() - await self._connect_websocket() + message = f"{self} error receiving messages: {e}" + logger.error(message) + await self.push_error(ErrorFrame(message, fatal=True)) + break async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index 87d48b4fa..bc4ff90f0 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -11,11 +11,13 @@ from loguru import logger from pydantic import BaseModel, model_validator +from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential from pipecat.frames.frames import ( BotStoppedSpeakingFrame, CancelFrame, EndFrame, + 
ErrorFrame, Frame, LLMFullResponseEndFrame, StartFrame, @@ -348,28 +350,44 @@ async def _disconnect_websocket(self): except Exception as e: logger.error(f"{self} error closing websocket: {e}") + async def _receive_messages(self): + async for message in self._websocket: + msg = json.loads(message) + if msg.get("audio"): + await self.stop_ttfb_metrics() + self.start_word_timestamps() + + audio = base64.b64decode(msg["audio"]) + frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1) + await self.push_frame(frame) + if msg.get("alignment"): + word_times = calculate_word_times(msg["alignment"], self._cumulative_time) + await self.add_word_timestamps(word_times) + self._cumulative_time = word_times[-1][1] + + async def _reconnect_websocket(self, retry_state: RetryCallState): + logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})") + await self._disconnect_websocket() + await self._connect_websocket() + async def _receive_task_handler(self): while True: try: - async for message in self._websocket: - msg = json.loads(message) - if msg.get("audio"): - await self.stop_ttfb_metrics() - self.start_word_timestamps() - - audio = base64.b64decode(msg["audio"]) - frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1) - await self.push_frame(frame) - if msg.get("alignment"): - word_times = calculate_word_times(msg["alignment"], self._cumulative_time) - await self.add_word_timestamps(word_times) - self._cumulative_time = word_times[-1][1] + async for attempt in AsyncRetrying( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + before_sleep=self._reconnect_websocket, + reraise=True, + ): + with attempt: + await self._receive_messages() except asyncio.CancelledError: break except Exception as e: - logger.error(f"{self} exception: {e}") - await self._disconnect_websocket() - await self._connect_websocket() + message = f"{self} error receiving messages: {e}" + logger.error(message) + await 
self.push_error(ErrorFrame(message, fatal=True)) + break async def _keepalive_task_handler(self): while True: diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt.py index 04223a1a1..5393e6653 100644 --- a/src/pipecat/services/lmnt.py +++ b/src/pipecat/services/lmnt.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator from loguru import logger +from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential from pipecat.frames.frames import ( CancelFrame, @@ -159,31 +160,47 @@ async def _disconnect_lmnt(self): except Exception as e: logger.error(f"{self} error closing connection: {e}") + async def _receive_messages(self): + async for msg in self._connection: + if "error" in msg: + logger.error(f'{self} error: {msg["error"]}') + await self.push_frame(TTSStoppedFrame()) + await self.stop_all_metrics() + await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) + elif "audio" in msg: + await self.stop_ttfb_metrics() + frame = TTSAudioRawFrame( + audio=msg["audio"], + sample_rate=self._settings["output_format"]["sample_rate"], + num_channels=1, + ) + await self.push_frame(frame) + else: + logger.error(f"{self}: LMNT error, unknown message type: {msg}") + + async def _reconnect_websocket(self, retry_state: RetryCallState): + logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})") + await self._disconnect_lmnt() + await self._connect_lmnt() + async def _receive_task_handler(self): while True: try: - async for msg in self._connection: - if "error" in msg: - logger.error(f'{self} error: {msg["error"]}') - await self.push_frame(TTSStoppedFrame()) - await self.stop_all_metrics() - await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) - elif "audio" in msg: - await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame( - audio=msg["audio"], - sample_rate=self._settings["output_format"]["sample_rate"], - num_channels=1, - ) - await self.push_frame(frame) - else: - logger.error(f"{self}: LMNT 
error, unknown message type: {msg}") + async for attempt in AsyncRetrying( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + before_sleep=self._reconnect_websocket, + reraise=True, + ): + with attempt: + await self._receive_messages() except asyncio.CancelledError: break except Exception as e: - logger.error(f"{self} exception: {e}") - await self._disconnect_lmnt() - await self._connect_lmnt() + message = f"{self} error receiving messages: {e}" + logger.error(message) + await self.push_error(ErrorFrame(message, fatal=True)) + break async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating TTS: [{text}]") diff --git a/src/pipecat/services/playht.py b/src/pipecat/services/playht.py index 78be924af..06b493fc1 100644 --- a/src/pipecat/services/playht.py +++ b/src/pipecat/services/playht.py @@ -15,6 +15,7 @@ import websockets from loguru import logger from pydantic import BaseModel +from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential from pipecat.frames.frames import ( BotStoppedSpeakingFrame, @@ -217,35 +218,51 @@ async def _handle_interruption(self, frame: StartInterruptionFrame, direction: F await self.stop_all_metrics() self._request_id = None + async def _receive_messages(self): + async for message in self._get_websocket(): + if isinstance(message, bytes): + # Skip the WAV header message + if message.startswith(b"RIFF"): + continue + await self.stop_ttfb_metrics() + frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1) + await self.push_frame(frame) + else: + logger.debug(f"Received text message: {message}") + try: + msg = json.loads(message) + if "request_id" in msg and msg["request_id"] == self._request_id: + await self.push_frame(TTSStoppedFrame()) + self._request_id = None + elif "error" in msg: + logger.error(f"{self} error: {msg}") + await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) + except json.JSONDecodeError: + 
logger.error(f"Invalid JSON message: {message}") + + async def _reconnect_websocket(self, retry_state: RetryCallState): + logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})") + await self._disconnect_websocket() + await self._connect_websocket() + async def _receive_task_handler(self): while True: try: - async for message in self._get_websocket(): - if isinstance(message, bytes): - # Skip the WAV header message - if message.startswith(b"RIFF"): - continue - await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1) - await self.push_frame(frame) - else: - logger.debug(f"Received text message: {message}") - try: - msg = json.loads(message) - if "request_id" in msg and msg["request_id"] == self._request_id: - await self.push_frame(TTSStoppedFrame()) - self._request_id = None - elif "error" in msg: - logger.error(f"{self} error: {msg}") - await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) - except json.JSONDecodeError: - logger.error(f"Invalid JSON message: {message}") + async for attempt in AsyncRetrying( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + before_sleep=self._reconnect_websocket, + reraise=True, + ): + with attempt: + await self._receive_messages() except asyncio.CancelledError: break except Exception as e: - logger.error(f"{self} exception in receive task: {e}") - await self._disconnect_websocket() - await self._connect_websocket() + message = f"{self} error receiving messages: {e}" + logger.error(message) + await self.push_error(ErrorFrame(message, fatal=True)) + break async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction)