diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c8862d4f..d6655ed9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added a new `WebsocketService` based class for TTS services, containing + base functions and retry logic. + - Added `DeepSeekLLMService` for DeepSeek integration with an OpenAI-compatible interface. Added foundational example `14l-function-calling-deepseek.py`. @@ -61,6 +64,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed an issue where websocket based TTS services could incorrectly terminate + their connection due to a retry counter not resetting. + - Fixed a `PipelineTask` issue that would cause a dangling task after stopping the pipeline with an `EndFrame`. diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index 4e0bb111c..9712ce600 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio import base64 import json import uuid @@ -12,7 +11,6 @@ from loguru import logger from pydantic import BaseModel -from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential from pipecat.frames.frames import ( BotStoppedSpeakingFrame, @@ -30,6 +28,7 @@ ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_services import TTSService, WordTTSService +from pipecat.services.websocket_service import WebsocketService from pipecat.transcriptions.language import Language # See .env.example for Cartesia configuration needed @@ -76,7 +75,7 @@ def language_to_cartesia_language(language: Language) -> str | None: return result -class CartesiaTTSService(WordTTSService): +class CartesiaTTSService(WordTTSService, WebsocketService): class InputParams(BaseModel): language: Optional[Language] = Language.EN speed: Optional[Union[str, float]] = "" @@ -106,12 +105,14 @@ def __init__( # if we're interrupted. Cartesia gives us word-by-word timestamps. We # can use those to generate text frames ourselves aligned with the # playout timing of the audio! - super().__init__( + WordTTSService.__init__( + self, aggregate_sentences=True, push_text_frames=False, sample_rate=sample_rate, **kwargs, ) + WebsocketService.__init__(self) self._api_key = api_key self._cartesia_version = cartesia_version @@ -131,7 +132,6 @@ def __init__( self.set_model_name(model) self.set_voice(voice_id) - self._websocket = None self._context_id = None self._receive_task = None @@ -187,7 +187,9 @@ async def cancel(self, frame: CancelFrame): async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._receive_task = self.get_event_loop().create_task( + self._receive_task_handler(self.push_error) + ) async def _disconnect(self): await self._disconnect_websocket() @@ -275,30 +277,6 @@ async def _receive_messages(self): else: logger.error(f"{self} error, unknown message type: {msg}") - async def _reconnect_websocket(self, retry_state: RetryCallState): - logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})") - await self._disconnect_websocket() - await self._connect_websocket() - - async def _receive_task_handler(self): - while True: - try: - async for attempt in AsyncRetrying( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - before_sleep=self._reconnect_websocket, - reraise=True, - ): - with attempt: - await self._receive_messages() - except asyncio.CancelledError: - break - except Exception as e: - message = f"{self} error receiving messages: {e}" - logger.error(message) - await self.push_error(ErrorFrame(message, fatal=True)) - break - async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index 2b598550a..c1d326dfc 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -11,13 +11,11 @@ from loguru import logger from pydantic import BaseModel, model_validator -from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential from pipecat.frames.frames import ( BotStoppedSpeakingFrame, CancelFrame, EndFrame, - ErrorFrame, Frame, LLMFullResponseEndFrame, StartFrame, @@ -29,6 +27,7 @@ ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_services import WordTTSService +from pipecat.services.websocket_service import WebsocketService from pipecat.transcriptions.language import Language # See .env.example for ElevenLabs configuration needed @@ -133,7 +132,7 @@ def calculate_word_times( return word_times -class ElevenLabsTTSService(WordTTSService): +class ElevenLabsTTSService(WordTTSService, WebsocketService): class InputParams(BaseModel): language: Optional[Language] = Language.EN optimize_streaming_latency: Optional[str] = None @@ -178,7 +177,8 @@ def __init__( # Finally, ElevenLabs doesn't provide information on when the bot stops # speaking for a while, so we want the parent class to send TTSStopFrame # after a short period not receiving any audio. - super().__init__( + WordTTSService.__init__( + self, aggregate_sentences=True, push_text_frames=False, push_stop_frames=True, @@ -186,6 +186,7 @@ def __init__( sample_rate=sample_rate_from_output_format(output_format), **kwargs, ) + WebsocketService.__init__(self) self._api_key = api_key self._url = url @@ -206,8 +207,6 @@ def __init__( self.set_voice(voice_id) self._voice_settings = self._set_voice_settings() - # Websocket connection to ElevenLabs. - self._websocket = None # Indicates if we have sent TTSStartedFrame. It will reset to False when # there's an interruption or TTSStoppedFrame. self._started = False @@ -297,7 +296,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._receive_task = self.get_event_loop().create_task( + self._receive_task_handler(self.push_error) + ) self._keepalive_task = self.get_event_loop().create_task(self._keepalive_task_handler()) async def _disconnect(self): @@ -377,30 +378,6 @@ async def _receive_messages(self): await self.add_word_timestamps(word_times) self._cumulative_time = word_times[-1][1] - async def _reconnect_websocket(self, retry_state: RetryCallState): - logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})") - await self._disconnect_websocket() - await self._connect_websocket() - - async def _receive_task_handler(self): - while True: - try: - async for attempt in AsyncRetrying( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - before_sleep=self._reconnect_websocket, - reraise=True, - ): - with attempt: - await self._receive_messages() - except asyncio.CancelledError: - break - except Exception as e: - message = f"{self} error receiving messages: {e}" - logger.error(message) - await self.push_error(ErrorFrame(message, fatal=True)) - break - async def _keepalive_task_handler(self): while True: try: diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt.py index 2ce0f0096..633c24265 100644 --- a/src/pipecat/services/lmnt.py +++ b/src/pipecat/services/lmnt.py @@ -4,11 +4,10 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio +import json from typing import AsyncGenerator from loguru import logger -from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential from pipecat.frames.frames import ( CancelFrame, @@ -23,11 +22,12 @@ ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_services import TTSService +from pipecat.services.websocket_service import WebsocketService from pipecat.transcriptions.language import Language # See .env.example for LMNT configuration needed try: - from lmnt.api import Speech + import websockets except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error( @@ -60,7 +60,7 @@ def language_to_lmnt_language(language: Language) -> str | None: return result -class LmntTTSService(TTSService): +class LmntTTSService(TTSService, WebsocketService): def __init__( self, *, @@ -70,27 +70,21 @@ def __init__( language: Language = Language.EN, **kwargs, ): - # Let TTSService produce TTSStoppedFrames after a short delay of - # no activity. - super().__init__(push_stop_frames=True, sample_rate=sample_rate, **kwargs) + TTSService.__init__( + self, + push_stop_frames=True, + sample_rate=sample_rate, + **kwargs, + ) + WebsocketService.__init__(self) self._api_key = api_key + self._voice_id = voice_id self._settings = { - "output_format": { - "container": "raw", - "encoding": "pcm_s16le", - "sample_rate": sample_rate, - }, + "sample_rate": sample_rate, "language": self.language_to_service_language(language), + "format": "raw", # Use raw format for direct PCM data } - - self.set_voice(voice_id) - - self._speech = None - self._connection = None - self._receive_task = None - # Indicates if we have sent TTSStartedFrame. It will reset to False when - # there's an interruption or TTSStoppedFrame. self._started = False def can_generate_metrics(self) -> bool: @@ -117,106 +111,107 @@ async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirect self._started = False async def _connect(self): - await self._connect_lmnt() + await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._receive_task = self.get_event_loop().create_task( + self._receive_task_handler(self.push_error) + ) async def _disconnect(self): - await self._disconnect_lmnt() + await self._disconnect_websocket() if self._receive_task: self._receive_task.cancel() await self._receive_task self._receive_task = None - async def _connect_lmnt(self): + async def _connect_websocket(self): + """Connect to LMNT websocket.""" try: logger.debug("Connecting to LMNT") - self._speech = Speech() - self._connection = await self._speech.synthesize_streaming( - self._voice_id, - format="raw", - sample_rate=self._settings["output_format"]["sample_rate"], - language=self._settings["language"], - ) + # Build initial connection message + init_msg = { + "X-API-Key": self._api_key, + "voice": self._voice_id, + "format": self._settings["format"], + "sample_rate": self._settings["sample_rate"], + "language": self._settings["language"], + } + + # Connect to LMNT's websocket directly + self._websocket = await websockets.connect("wss://api.lmnt.com/v1/ai/speech/stream") + + # Send initialization message + await self._websocket.send(json.dumps(init_msg)) + except Exception as e: logger.error(f"{self} initialization error: {e}") - self._connection = None + self._websocket = None - async def _disconnect_lmnt(self): + async def _disconnect_websocket(self): + """Disconnect from LMNT websocket.""" try: await self.stop_all_metrics() - if self._connection: + if self._websocket: logger.debug("Disconnecting from LMNT") - await self._connection.socket.close() - self._connection = None - if self._speech: - await self._speech.close() - self._speech = None + # Send EOF message before closing + await self._websocket.send(json.dumps({"eof": True})) + await self._websocket.close() + self._websocket = None self._started = False except Exception as e: - logger.error(f"{self} error closing connection: {e}") + logger.error(f"{self} error closing websocket: {e}") + + def _get_websocket(self): + if self._websocket: + return self._websocket + raise Exception("Websocket not connected") async def _receive_messages(self): - async for msg in self._connection: - if "error" in msg: - logger.error(f'{self} error: {msg["error"]}') - await self.push_frame(TTSStoppedFrame()) - await self.stop_all_metrics() - await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) - elif "audio" in msg: + """Receive messages from LMNT websocket.""" + async for message in self._get_websocket(): + if isinstance(message, bytes): + # Raw audio data await self.stop_ttfb_metrics() frame = TTSAudioRawFrame( - audio=msg["audio"], - sample_rate=self._settings["output_format"]["sample_rate"], + audio=message, + sample_rate=self._settings["sample_rate"], num_channels=1, ) await self.push_frame(frame) else: - logger.error(f"{self}: LMNT error, unknown message type: {msg}") - - async def _reconnect_websocket(self, retry_state: RetryCallState): - logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})") - await self._disconnect_lmnt() - await self._connect_lmnt() - - async def _receive_task_handler(self): - while True: - try: - async for attempt in AsyncRetrying( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - before_sleep=self._reconnect_websocket, - reraise=True, - ): - with attempt: - await self._receive_messages() - except asyncio.CancelledError: - break - except Exception as e: - message = f"{self} error receiving messages: {e}" - logger.error(message) - await self.push_error(ErrorFrame(message, fatal=True)) - break + try: + msg = json.loads(message) + if "error" in msg: + logger.error(f'{self} error: {msg["error"]}') + await self.push_frame(TTSStoppedFrame()) + await self.stop_all_metrics() + await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}')) + return + except json.JSONDecodeError: + logger.error(f"Invalid JSON message: {message}") async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + """Generate TTS audio from text.""" logger.debug(f"Generating TTS: [{text}]") try: - if not self._connection: + if not self._websocket: await self._connect() - if not self._started: - await self.start_ttfb_metrics() - yield TTSStartedFrame() - self._started = True - try: - await self._connection.append_text(text) - await self._connection.flush() + if not self._started: + await self.start_ttfb_metrics() + yield TTSStartedFrame() + self._started = True + + # Send text to LMNT + await self._get_websocket().send(json.dumps({"text": text})) + # Force synthesis + await self._get_websocket().send(json.dumps({"flush": True})) await self.start_tts_usage_metrics(text) except Exception as e: logger.error(f"{self} error sending message: {e}") diff --git a/src/pipecat/services/playht.py b/src/pipecat/services/playht.py index 7454bb16a..a511e2456 100644 --- a/src/pipecat/services/playht.py +++ b/src/pipecat/services/playht.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio import io import json import struct @@ -15,7 +14,6 @@ import websockets from loguru import logger from pydantic import BaseModel -from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential from pipecat.frames.frames import ( BotStoppedSpeakingFrame, @@ -33,6 +31,7 @@ ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_services import TTSService +from pipecat.services.websocket_service import WebsocketService from pipecat.transcriptions.language import Language try: @@ -101,7 +100,7 @@ def language_to_playht_language(language: Language) -> str | None: return result -class PlayHTTTSService(TTSService): +class PlayHTTTSService(TTSService, WebsocketService): class InputParams(BaseModel): language: Optional[Language] = Language.EN speed: Optional[float] = 1.0 @@ -119,15 +118,16 @@ def __init__( params: InputParams = InputParams(), **kwargs, ): - super().__init__( + TTSService.__init__( + self, sample_rate=sample_rate, **kwargs, ) + WebsocketService.__init__(self) self._api_key = api_key self._user_id = user_id self._websocket_url = None - self._websocket = None self._receive_task = None self._request_id = None @@ -165,7 +165,9 @@ async def cancel(self, frame: CancelFrame): async def _connect(self): await self._connect_websocket() - self._receive_task = self.get_event_loop().create_task(self._receive_task_handler()) + self._receive_task = self.get_event_loop().create_task( + self._receive_task_handler(self.push_error) + ) async def _disconnect(self): await self._disconnect_websocket() @@ -271,30 +273,6 @@ async def _receive_messages(self): except json.JSONDecodeError: logger.error(f"Invalid JSON message: {message}") - async def _reconnect_websocket(self, retry_state: RetryCallState): - logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})") - await self._disconnect_websocket() - await self._connect_websocket() - - async def _receive_task_handler(self): - while True: - try: - async for attempt in AsyncRetrying( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - before_sleep=self._reconnect_websocket, - reraise=True, - ): - with attempt: - await self._receive_messages() - except asyncio.CancelledError: - break - except Exception as e: - message = f"{self} error receiving messages: {e}" - logger.error(message) - await self.push_error(ErrorFrame(message, fatal=True)) - break - async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) diff --git a/src/pipecat/services/websocket_service.py b/src/pipecat/services/websocket_service.py new file mode 100644 index 000000000..365f5a7c8 --- /dev/null +++ b/src/pipecat/services/websocket_service.py @@ -0,0 +1,124 @@ +# +# Copyright (c) 2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +from abc import ABC, abstractmethod +from typing import Awaitable, Callable, Optional + +import websockets +from loguru import logger + +from pipecat.frames.frames import ErrorFrame + + +class WebsocketService(ABC): + """Base class for websocket-based services with reconnection logic.""" + + def __init__(self): + """Initialize websocket attributes.""" + self._websocket: Optional[websockets.WebSocketClientProtocol] = None + + async def _verify_connection(self) -> bool: + """Verify websocket connection is working. + + Returns: + bool: True if connection is verified working, False otherwise + """ + try: + if not self._websocket: + return False + await self._websocket.ping() + return True + except Exception as e: + logger.error(f"{self} connection verification failed: {e}") + return False + + async def _reconnect_websocket(self, attempt_number: int) -> bool: + """Reconnect the websocket. + + Args: + attempt_number: Current retry attempt number + + Returns: + bool: True if reconnection and verification successful, False otherwise + """ + logger.warning(f"{self} reconnecting (attempt: {attempt_number})") + await self._disconnect_websocket() + await self._connect_websocket() + return await self._verify_connection() + + def _calculate_wait_time( + self, attempt: int, min_wait: float = 4, max_wait: float = 10, multiplier: float = 1 + ) -> float: + """Calculate exponential backoff wait time. + + Args: + attempt: Current attempt number (1-based) + min_wait: Minimum wait time in seconds + max_wait: Maximum wait time in seconds + multiplier: Base multiplier for exponential calculation + + Returns: + Wait time in seconds + """ + try: + exp = 2 ** (attempt - 1) * multiplier + result = max(0, min(exp, max_wait)) + return max(min_wait, result) + except (ValueError, ArithmeticError): + return max_wait + + async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Awaitable[None]]): + """Handles WebSocket message receiving with automatic retry logic. + + Args: + report_error: Callback to report errors + """ + retry_count = 0 + MAX_RETRIES = 3 + + while True: + try: + await self._receive_messages() + logger.debug(f"{self} connection established successfully") + retry_count = 0 # Reset counter on successful message receive + + except asyncio.CancelledError: + break + + except Exception as e: + retry_count += 1 + if retry_count >= MAX_RETRIES: + message = f"{self} error receiving messages: {e}" + logger.error(message) + await report_error(ErrorFrame(message, fatal=True)) + break + + logger.warning(f"{self} connection error, will retry: {e}") + + try: + if await self._reconnect_websocket(retry_count): + retry_count = 0 # Reset counter on successful reconnection + wait_time = self._calculate_wait_time(retry_count) + await asyncio.sleep(wait_time) + except Exception as reconnect_error: + logger.error(f"{self} reconnection failed: {reconnect_error}") + continue + + @abstractmethod + async def _connect_websocket(self): + """Implement service-specific websocket connection logic.""" + pass + + @abstractmethod + async def _disconnect_websocket(self): + """Implement service-specific websocket disconnection logic.""" + pass + + @abstractmethod + async def _receive_messages(self): + """Implement service-specific message receiving logic.""" + pass