Skip to content

Commit

Permalink
LMNT TTS
Browse files Browse the repository at this point in the history
  • Loading branch information
sharvil committed Aug 22, 2024
1 parent 21de8e0 commit f4fd7b7
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 0 deletions.
4 changes: 4 additions & 0 deletions dot-env.template
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ FIREWORKS_API_KEY=...
# Gladia
GLADIA_API_KEY=...

# LMNT
LMNT_API_KEY=...
LMNT_VOICE_ID=...

# PlayHT
PLAY_HT_USER_ID=...
PLAY_HT_API_KEY=...
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ google = [ "google-generativeai~=0.7.2" ]
gstreamer = [ "pygobject~=3.48.2" ]
fireworks = [ "openai~=1.37.2" ]
langchain = [ "langchain~=0.2.14", "langchain-community~=0.2.12", "langchain-openai~=0.1.20" ]
lmnt = [ "lmnt~=1.1.4" ]
local = [ "pyaudio~=0.2.14" ]
moondream = [ "einops~=0.8.0", "timm~=1.0.8", "transformers~=4.44.0" ]
openai = [ "openai~=1.37.2" ]
Expand Down
159 changes: 159 additions & 0 deletions src/pipecat/services/lmnt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD-2-Clause
#

import json
import uuid
import base64
import asyncio
import time

from typing import AsyncGenerator

from pipecat.processors.frame_processor import FrameDirection
from pipecat.frames.frames import (
CancelFrame,
ErrorFrame,
Frame,
AudioRawFrame,
StartFrame,
EndFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.ai_services import TTSService

from loguru import logger

# LMNT support is an optional extra. The environment variables it uses
# (LMNT_API_KEY, LMNT_VOICE_ID) are listed in dot-env.template.
try:
    from lmnt.api import Speech
except ModuleNotFoundError as missing_module:
    install_hint = (
        "In order to use LMNT, you need to `pip install pipecat-ai[lmnt]`. "
        "Also, set `LMNT_API_KEY` environment variable."
    )
    logger.error(f"Exception: {missing_module}")
    logger.error(install_hint)
    # Fail the import of this module outright: nothing below works without the SDK.
    raise Exception(f"Missing module: {missing_module}")


class LmntTTSService(TTSService):
    """Text-to-speech service backed by LMNT's streaming synthesis API.

    A streaming synthesis connection is opened in `start()` (or lazily by
    `run_tts()` if none exists). Text is appended and flushed to that
    connection, and a background task reads the messages it produces and
    pushes the audio downstream as `AudioRawFrame`s.
    """

    def __init__(
            self,
            *,
            api_key: str,
            voice_id: str,
            sample_rate: int = 16000,
            language: str = "en",
            **kwargs):
        """Store the LMNT settings; no connection is made here.

        Args:
            api_key: LMNT API key handed to the LMNT client on connect.
            voice_id: LMNT voice used when the streaming connection is opened.
            sample_rate: Sample rate requested from LMNT and stamped on the
                audio frames pushed downstream.
            language: Language code. Stored on the instance only; it is not
                currently forwarded to LMNT.
            **kwargs: Passed through to `TTSService.__init__`.
        """
        super().__init__(**kwargs)

        self._api_key = api_key
        self._voice_id = voice_id
        # Only "sample_rate" is read back from this dict by the rest of the class.
        self._output_format = {
            "container": "raw",
            "encoding": "pcm_s16le",
            "sample_rate": sample_rate,
        }
        self._language = language

        self._speech = None        # LMNT client (lmnt.api.Speech)
        self._connection = None    # streaming synthesis connection
        self._receive_task = None  # background task draining the connection
        self._started = False      # True once TTSStartedFrame has been pushed

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics."""
        return True

    async def set_voice(self, voice: str):
        """Change the stored voice id.

        The voice is passed to LMNT when the streaming connection is created,
        so the new value is used by the next `_connect()` call.
        """
        logger.debug(f"Switching TTS voice to: [{voice}]")
        self._voice_id = voice

    async def start(self, frame: StartFrame):
        """Run the base-class start handling, then open the LMNT connection."""
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Run the base-class stop handling, then tear down the LMNT connection."""
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Run the base-class cancel handling, then tear down the LMNT connection."""
        await super().cancel(frame)
        await self._disconnect()

    async def _connect(self):
        """Create the LMNT client, open a streaming connection and start the reader.

        Errors are logged rather than raised. On failure `_connection` is left
        as None and any client created during this attempt is closed.
        """
        try:
            # Hand the key over explicitly; otherwise the `api_key` constructor
            # argument would be silently ignored.
            self._speech = Speech(api_key=self._api_key)
            self._connection = await self._speech.synthesize_streaming(
                self._voice_id,
                format="raw",
                sample_rate=self._output_format["sample_rate"])
            self._receive_task = self.get_event_loop().create_task(
                self._receive_task_handler())
        except Exception as e:
            logger.exception(f"{self} initialization error: {e}")
            self._connection = None
            # Do not leak a half-initialised client: the next _connect() would
            # overwrite the reference without ever closing it.
            if self._speech:
                try:
                    await self._speech.close()
                except Exception as close_error:
                    logger.error(f"{self} error closing LMNT client: {close_error}")
                self._speech = None

    async def _disconnect(self):
        """Stop metrics, cancel the reader task and close the connection and client.

        Errors are logged rather than raised.
        """
        try:
            await self.stop_all_metrics()

            if self._receive_task:
                self._receive_task.cancel()
                # The handler swallows CancelledError, so this await returns
                # normally once the task has finished.
                await self._receive_task
                self._receive_task = None
            if self._connection:
                await self._connection.socket.close()
                self._connection = None
            if self._speech:
                await self._speech.close()
                self._speech = None
            self._started = False
        except Exception as e:
            logger.exception(f"{self} error closing websocket: {e}")

    async def _receive_task_handler(self):
        """Forward messages from the streaming connection as frames.

        Messages containing "error" end the current utterance and are reported
        upstream; messages containing "audio" become `AudioRawFrame`s. Anything
        else is logged as an unknown message type.
        """
        try:
            async for msg in self._connection:
                if "error" in msg:
                    logger.error(f'{self} error: {msg["error"]}')
                    await self.push_frame(TTSStoppedFrame())
                    await self.stop_all_metrics()
                    # A stopped frame was just pushed, so the next run_tts()
                    # must announce a fresh start and restart TTFB metrics.
                    self._started = False
                    await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
                elif "audio" in msg:
                    await self.stop_ttfb_metrics()
                    frame = AudioRawFrame(
                        audio=msg["audio"],
                        sample_rate=self._output_format["sample_rate"],
                        num_channels=1
                    )
                    await self.push_frame(frame)
                else:
                    logger.error(f"LMNT error, unknown message type: {msg}")
        except asyncio.CancelledError:
            # Normal shutdown path: _disconnect() cancels this task.
            pass
        except Exception as e:
            logger.exception(f"{self} exception: {e}")

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Send `text` to LMNT for synthesis.

        Audio is not yielded from here; it is pushed by the background reader
        task. The generator yields a single `None` after a successful send. If
        sending fails, a `TTSStoppedFrame` is pushed, the connection is
        re-established, and nothing is yielded.
        """
        logger.debug(f"Generating TTS: [{text}]")

        try:
            if not self._connection:
                await self._connect()

            if not self._started:
                await self.push_frame(TTSStartedFrame())
                await self.start_ttfb_metrics()
                self._started = True

            try:
                await self._connection.append_text(text)
                await self._connection.flush()
                await self.start_tts_usage_metrics(text)
            except Exception as e:
                logger.error(f"{self} error sending message: {e}")
                await self.push_frame(TTSStoppedFrame())
                # Reset the link so the next call starts from a clean connection.
                await self._disconnect()
                await self._connect()
                return
            # The yield is what makes this method an async generator.
            yield None
        except Exception as e:
            logger.exception(f"{self} exception: {e}")

0 comments on commit f4fd7b7

Please sign in to comment.