Skip to content

Commit

Permalink
Refactor LMNTTTSService to make a websocket connection directly, then use the WebsocketService base class
Browse files Browse the repository at this point in the history
  • Loading branch information
markbackman committed Jan 10, 2025
1 parent dc6f634 commit c4c7bc0
Showing 1 changed file with 74 additions and 81 deletions.
155 changes: 74 additions & 81 deletions src/pipecat/services/lmnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import json
from typing import AsyncGenerator

from loguru import logger
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential

from pipecat.frames.frames import (
CancelFrame,
Expand All @@ -23,11 +22,12 @@
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import TTSService
from pipecat.services.websocket_service import WebsocketService
from pipecat.transcriptions.language import Language

# See .env.example for LMNT configuration needed
try:
from lmnt.api import Speech
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
Expand Down Expand Up @@ -60,7 +60,7 @@ def language_to_lmnt_language(language: Language) -> str | None:
return result


class LmntTTSService(TTSService):
class LmntTTSService(TTSService, WebsocketService):
def __init__(
self,
*,
Expand All @@ -70,27 +70,21 @@ def __init__(
language: Language = Language.EN,
**kwargs,
):
# Let TTSService produce TTSStoppedFrames after a short delay of
# no activity.
super().__init__(push_stop_frames=True, sample_rate=sample_rate, **kwargs)
TTSService.__init__(
self,
push_stop_frames=True,
sample_rate=sample_rate,
**kwargs,
)
WebsocketService.__init__(self)

self._api_key = api_key
self._voice_id = voice_id
self._settings = {
"output_format": {
"container": "raw",
"encoding": "pcm_s16le",
"sample_rate": sample_rate,
},
"sample_rate": sample_rate,
"language": self.language_to_service_language(language),
"format": "raw", # Use raw format for direct PCM data
}

self.set_voice(voice_id)

self._speech = None
self._connection = None
self._receive_task = None
# Indicates if we have sent TTSStartedFrame. It will reset to False when
# there's an interruption or TTSStoppedFrame.
self._started = False

def can_generate_metrics(self) -> bool:
Expand All @@ -117,106 +111,105 @@ async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirect
self._started = False

async def _connect(self):
await self._connect_lmnt()
await self._connect_websocket()

self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())

async def _disconnect(self):
await self._disconnect_lmnt()
await self._disconnect_websocket()

if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None

async def _connect_lmnt(self):
async def _connect_websocket(self):
"""Connect to LMNT websocket."""
try:
logger.debug("Connecting to LMNT")

self._speech = Speech()
self._connection = await self._speech.synthesize_streaming(
self._voice_id,
format="raw",
sample_rate=self._settings["output_format"]["sample_rate"],
language=self._settings["language"],
)
# Build initial connection message
init_msg = {
"X-API-Key": self._api_key,
"voice": self._voice_id,
"format": self._settings["format"],
"sample_rate": self._settings["sample_rate"],
"language": self._settings["language"],
}

# Connect to LMNT's websocket directly
self._websocket = await websockets.connect("wss://api.lmnt.com/v1/ai/speech/stream")

# Send initialization message
await self._websocket.send(json.dumps(init_msg))

except Exception as e:
logger.error(f"{self} initialization error: {e}")
self._connection = None
self._websocket = None

async def _disconnect_lmnt(self):
async def _disconnect_websocket(self):
"""Disconnect from LMNT websocket."""
try:
await self.stop_all_metrics()

if self._connection:
if self._websocket:
logger.debug("Disconnecting from LMNT")
await self._connection.socket.close()
self._connection = None
if self._speech:
await self._speech.close()
self._speech = None
# Send EOF message before closing
await self._websocket.send(json.dumps({"eof": True}))
await self._websocket.close()
self._websocket = None

self._started = False
except Exception as e:
logger.error(f"{self} error closing connection: {e}")
logger.error(f"{self} error closing websocket: {e}")

def _get_websocket(self):
if self._websocket:
return self._websocket
raise Exception("Websocket not connected")

async def _receive_messages(self):
async for msg in self._connection:
if "error" in msg:
logger.error(f'{self} error: {msg["error"]}')
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
elif "audio" in msg:
"""Receive messages from LMNT websocket."""
async for message in self._get_websocket():
if isinstance(message, bytes):
# Raw audio data
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=msg["audio"],
sample_rate=self._settings["output_format"]["sample_rate"],
audio=message,
sample_rate=self._settings["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
else:
logger.error(f"{self}: LMNT error, unknown message type: {msg}")

async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_lmnt()
await self._connect_lmnt()

async def _receive_task_handler(self):
while True:
try:
async for attempt in AsyncRetrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
before_sleep=self._reconnect_websocket,
reraise=True,
):
with attempt:
await self._receive_messages()
except asyncio.CancelledError:
break
except Exception as e:
message = f"{self} error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
break
try:
msg = json.loads(message)
if "error" in msg:
logger.error(f'{self} error: {msg["error"]}')
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
return
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate TTS audio from text."""
logger.debug(f"Generating TTS: [{text}]")

try:
if not self._connection:
if not self._websocket:
await self._connect()

if not self._started:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
self._started = True

try:
await self._connection.append_text(text)
await self._connection.flush()
if not self._started:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
self._started = True

# Send text to LMNT
await self._get_websocket().send(json.dumps({"text": text}))
# Force synthesis
await self._get_websocket().send(json.dumps({"flush": True}))
await self.start_tts_usage_metrics(text)
except Exception as e:
logger.error(f"{self} error sending message: {e}")
Expand Down

0 comments on commit c4c7bc0

Please sign in to comment.