Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

services: fix infinite websocket-based TTS services retries #872

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.0.51] - 2024-12-16

### Fixed

- Fixed an issue in websocket-based TTS services (Cartesia, ElevenLabs, PlayHT
and LMNT) that was causing infinite reconnections.

## [0.0.50] - 2024-12-11

### Added
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"pydantic~=2.8.2",
"pyloudnorm~=0.1.1",
"resampy~=0.4.3",
"tenacity~=9.0.0"
]

[project.urls]
Expand All @@ -55,7 +56,7 @@ gstreamer = [ "pygobject~=3.48.2" ]
fireworks = [ "openai~=1.50.2" ]
krisp = [ "pipecat-ai-krisp~=0.3.0" ]
langchain = [ "langchain~=0.2.14", "langchain-community~=0.2.12", "langchain-openai~=0.1.20" ]
livekit = [ "livekit~=0.17.5", "livekit-api~=0.7.1", "tenacity~=8.5.0" ]
livekit = [ "livekit~=0.17.5", "livekit-api~=0.7.1" ]
lmnt = [ "lmnt~=1.1.4" ]
local = [ "pyaudio~=0.2.14" ]
moondream = [ "einops~=0.8.0", "timm~=1.0.8", "transformers~=4.44.0" ]
Expand Down
94 changes: 54 additions & 40 deletions src/pipecat/services/cartesia.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from loguru import logger
from pydantic import BaseModel
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential


from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
Expand Down Expand Up @@ -239,52 +241,64 @@ async def flush_audio(self):
msg = self._build_msg(text="", continue_transcript=False)
await self._websocket.send(msg)

async def _receive_messages(self):
async for message in self._get_websocket():
msg = json.loads(message)
if not msg or msg["context_id"] != self._context_id:
continue
if msg["type"] == "done":
await self.stop_ttfb_metrics()
# Unset _context_id but not the _context_id_start_timestamp
# because we are likely still playing out audio and need the
# timestamp to set send context frames.
self._context_id = None
await self.add_word_timestamps(
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)]
)
elif msg["type"] == "timestamps":
await self.add_word_timestamps(
list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]))
)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
else:
logger.error(f"{self} error, unknown message type: {msg}")

async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_websocket()
await self._connect_websocket()

async def _receive_task_handler(self):
while True:
try:
async for message in self._get_websocket():
msg = json.loads(message)
if not msg or msg["context_id"] != self._context_id:
continue
if msg["type"] == "done":
await self.stop_ttfb_metrics()
# Unset _context_id but not the _context_id_start_timestamp
# because we are likely still playing out audio and need the
# timestamp to set send context frames.
self._context_id = None
await self.add_word_timestamps(
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)]
)
elif msg["type"] == "timestamps":
await self.add_word_timestamps(
list(
zip(
msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]
)
)
)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
else:
logger.error(f"{self} error, unknown message type: {msg}")
async for attempt in AsyncRetrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
before_sleep=self._reconnect_websocket,
reraise=True,
):
with attempt:
await self._receive_messages()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self._disconnect_websocket()
await self._connect_websocket()
message = f"{self} error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
break

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
Expand Down
50 changes: 34 additions & 16 deletions src/pipecat/services/elevenlabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@

from loguru import logger
from pydantic import BaseModel, model_validator
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential

from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
StartFrame,
Expand Down Expand Up @@ -348,28 +350,44 @@ async def _disconnect_websocket(self):
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")

async def _receive_messages(self):
async for message in self._websocket:
msg = json.loads(message)
if msg.get("audio"):
await self.stop_ttfb_metrics()
self.start_word_timestamps()

audio = base64.b64decode(msg["audio"])
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
await self.push_frame(frame)
if msg.get("alignment"):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]

async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_websocket()
await self._connect_websocket()

async def _receive_task_handler(self):
while True:
try:
async for message in self._websocket:
msg = json.loads(message)
if msg.get("audio"):
await self.stop_ttfb_metrics()
self.start_word_timestamps()

audio = base64.b64decode(msg["audio"])
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
await self.push_frame(frame)
if msg.get("alignment"):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]
async for attempt in AsyncRetrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
before_sleep=self._reconnect_websocket,
reraise=True,
):
with attempt:
await self._receive_messages()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self._disconnect_websocket()
await self._connect_websocket()
message = f"{self} error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
break

async def _keepalive_task_handler(self):
while True:
Expand Down
55 changes: 36 additions & 19 deletions src/pipecat/services/lmnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import AsyncGenerator

from loguru import logger
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential

from pipecat.frames.frames import (
CancelFrame,
Expand Down Expand Up @@ -159,31 +160,47 @@ async def _disconnect_lmnt(self):
except Exception as e:
logger.error(f"{self} error closing connection: {e}")

async def _receive_messages(self):
async for msg in self._connection:
if "error" in msg:
logger.error(f'{self} error: {msg["error"]}')
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
elif "audio" in msg:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=msg["audio"],
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
else:
logger.error(f"{self}: LMNT error, unknown message type: {msg}")

async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_lmnt()
await self._connect_lmnt()

async def _receive_task_handler(self):
while True:
try:
async for msg in self._connection:
if "error" in msg:
logger.error(f'{self} error: {msg["error"]}')
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
elif "audio" in msg:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=msg["audio"],
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
else:
logger.error(f"{self}: LMNT error, unknown message type: {msg}")
async for attempt in AsyncRetrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
before_sleep=self._reconnect_websocket,
reraise=True,
):
with attempt:
await self._receive_messages()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self._disconnect_lmnt()
await self._connect_lmnt()
message = f"{self} error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
break

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")
Expand Down
63 changes: 40 additions & 23 deletions src/pipecat/services/playht.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import websockets
from loguru import logger
from pydantic import BaseModel
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential

from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
Expand Down Expand Up @@ -217,35 +218,51 @@ async def _handle_interruption(self, frame: StartInterruptionFrame, direction: F
await self.stop_all_metrics()
self._request_id = None

async def _receive_messages(self):
async for message in self._get_websocket():
if isinstance(message, bytes):
# Skip the WAV header message
if message.startswith(b"RIFF"):
continue
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1)
await self.push_frame(frame)
else:
logger.debug(f"Received text message: {message}")
try:
msg = json.loads(message)
if "request_id" in msg and msg["request_id"] == self._request_id:
await self.push_frame(TTSStoppedFrame())
self._request_id = None
elif "error" in msg:
logger.error(f"{self} error: {msg}")
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")

async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_websocket()
await self._connect_websocket()

async def _receive_task_handler(self):
while True:
try:
async for message in self._get_websocket():
if isinstance(message, bytes):
# Skip the WAV header message
if message.startswith(b"RIFF"):
continue
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1)
await self.push_frame(frame)
else:
logger.debug(f"Received text message: {message}")
try:
msg = json.loads(message)
if "request_id" in msg and msg["request_id"] == self._request_id:
await self.push_frame(TTSStoppedFrame())
self._request_id = None
elif "error" in msg:
logger.error(f"{self} error: {msg}")
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
async for attempt in AsyncRetrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
before_sleep=self._reconnect_websocket,
reraise=True,
):
with attempt:
await self._receive_messages()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception in receive task: {e}")
await self._disconnect_websocket()
await self._connect_websocket()
message = f"{self} error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
break

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
Expand Down
Loading