Skip to content

Commit

Permalink
Merge pull request #391 from sharvil/pr/add-lmnt
Browse files Browse the repository at this point in the history
LMNT TTS
  • Loading branch information
aconchillo authored Aug 28, 2024
2 parents e038767 + 87c4a1b commit 79aca81
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 1 deletion.
4 changes: 4 additions & 0 deletions dot-env.template
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ FIREWORKS_API_KEY=...
# Gladia
GLADIA_API_KEY=...

# LMNT
LMNT_API_KEY=...
LMNT_VOICE_ID=...

# PlayHT
PLAY_HT_USER_ID=...
PLAY_HT_API_KEY=...
Expand Down
95 changes: 95 additions & 0 deletions examples/foundational/07k-interruptible-lmnt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import aiohttp
import asyncio
import os
import sys

from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_response import (
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
from pipecat.services.lmnt import LmntTTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer

from runner import configure

from loguru import logger

from dotenv import load_dotenv
load_dotenv(override=True)

logger.remove(0)
logger.add(sys.stderr, level="DEBUG")


async def main():
    """Run an interruptible voice bot that uses LMNT for text-to-speech.

    Joins a Daily room, transcribes the user, feeds the transcript to an
    OpenAI LLM and speaks the response through LMNT. Runs until the
    pipeline is stopped.
    """
    async with aiohttp.ClientSession() as session:
        (room_url, token) = await configure(session)

        transport = DailyTransport(
            room_url,
            token,
            "Respond bot",
            DailyParams(
                audio_out_enabled=True,
                # Match LmntTTSService's default output rate (24 kHz) so no
                # resampling is needed.
                audio_out_sample_rate=24000,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer()
            )
        )

        # Fix: LmntTTSService takes a keyword-only `voice_id` (it has no
        # `voice` parameter, so the original `voice="morgan"` raised
        # TypeError). Read the voice from LMNT_VOICE_ID — added to
        # dot-env.template by this change — keeping "morgan" as fallback.
        tts = LmntTTSService(
            api_key=os.getenv("LMNT_API_KEY"),
            voice_id=os.getenv("LMNT_VOICE_ID", "morgan")
        )

        llm = OpenAILLMService(
            api_key=os.getenv("OPENAI_API_KEY"),
            model="gpt-4o")

        messages = [
            {
                "role": "system",
                "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
            },
        ]

        # Both aggregators share `messages` so the conversation context grows
        # with every user utterance and assistant reply.
        tma_in = LLMUserResponseAggregator(messages)
        tma_out = LLMAssistantResponseAggregator(messages)

        pipeline = Pipeline([
            transport.input(),   # Transport user input
            tma_in,              # User responses
            llm,                 # LLM
            tts,                 # TTS
            transport.output(),  # Transport bot output
            tma_out              # Assistant spoken responses
        ])

        task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            transport.capture_participant_transcription(participant["id"])
            # Kick off the conversation.
            messages.append(
                {"role": "system", "content": "Please introduce yourself to the user."})
            await task.queue_frames([LLMMessagesFrame(messages)])

        runner = PipelineRunner()

        await runner.run(task)


# Script entry point: start the asyncio event loop and run the bot until done.
if __name__ == "__main__":
    asyncio.run(main())
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ gstreamer = [ "pygobject~=3.48.2" ]
fireworks = [ "openai~=1.37.2" ]
langchain = [ "langchain~=0.2.14", "langchain-community~=0.2.12", "langchain-openai~=0.1.20" ]
livekit = [ "livekit~=0.13.1" ]
lmnt = [ "lmnt~=1.1.4" ]
local = [ "pyaudio~=0.2.14" ]
moondream = [ "einops~=0.8.0", "timm~=1.0.8", "transformers~=4.44.0" ]
openai = [ "openai~=1.37.2" ]
Expand Down
60 changes: 59 additions & 1 deletion src/pipecat/services/ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import io
import wave

from abc import abstractmethod
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional

from pipecat.frames.frames import (
AudioRawFrame,
Expand All @@ -20,6 +21,8 @@
StartFrame,
StartInterruptionFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSVoiceUpdateFrame,
TextFrame,
UserImageRequestFrame,
Expand Down Expand Up @@ -156,10 +159,18 @@ def __init__(
aggregate_sentences: bool = True,
# if True, subclass is responsible for pushing TextFrames and LLMFullResponseEndFrames
push_text_frames: bool = True,
# if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it
push_stop_frames: bool = False,
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
stop_frame_timeout_s: float = 0.8,
**kwargs):
super().__init__(**kwargs)
self._aggregate_sentences: bool = aggregate_sentences
self._push_text_frames: bool = push_text_frames
self._push_stop_frames: bool = push_stop_frames
self._stop_frame_timeout_s: float = stop_frame_timeout_s
self._stop_frame_task: Optional[asyncio.Task] = None
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
self._current_sentence: str = ""

@abstractmethod
Expand Down Expand Up @@ -227,6 +238,53 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
else:
await self.push_frame(frame, direction)

async def start(self, frame: StartFrame):
    """Start the service; spawn the idle-watchdog task if enabled."""
    await super().start(frame)
    if not self._push_stop_frames:
        return
    loop = self.get_event_loop()
    self._stop_frame_task = loop.create_task(self._stop_frame_handler())

async def stop(self, frame: EndFrame):
    """Stop the service and tear down the idle-watchdog task, if any."""
    await super().stop(frame)
    task = self._stop_frame_task
    if task is None:
        return
    self._stop_frame_task = None
    task.cancel()
    # Safe to await: the handler swallows CancelledError.
    await task

async def cancel(self, frame: CancelFrame):
    """Cancel the service and tear down the idle-watchdog task, if any."""
    await super().cancel(frame)
    task = self._stop_frame_task
    if task is None:
        return
    self._stop_frame_task = None
    task.cancel()
    # Safe to await: the handler swallows CancelledError.
    await task

async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
    """Push a frame downstream, mirroring TTS-activity frames to the watchdog queue."""
    await super().push_frame(frame, direction)

    if not self._push_stop_frames:
        return

    # Every frame that indicates TTS activity feeds the idle watchdog, which
    # emits a TTSStoppedFrame once this stream goes quiet.
    activity_frames = (StartInterruptionFrame, TTSStartedFrame, AudioRawFrame, TTSStoppedFrame)
    if isinstance(frame, activity_frames):
        await self._stop_frame_queue.put(frame)

async def _stop_frame_handler(self):
    """Watchdog: emit a TTSStoppedFrame after a quiet period following TTS audio.

    Consumes activity frames from the queue; a timeout with no traffic while
    speech was in progress is treated as end-of-utterance.
    """
    try:
        speaking = False
        while True:
            try:
                frame = await asyncio.wait_for(self._stop_frame_queue.get(),
                                               self._stop_frame_timeout_s)
            except asyncio.TimeoutError:
                # Queue went idle: if speech was underway, declare it over.
                if speaking:
                    await self.push_frame(TTSStoppedFrame())
                    speaking = False
                continue
            if isinstance(frame, TTSStartedFrame):
                speaking = True
            elif isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
                speaking = False
            # AudioRawFrames are drained without changing state: their
            # arrival alone resets the timeout.
    except asyncio.CancelledError:
        # Normal shutdown path from stop()/cancel().
        pass


class STTService(AIService):
"""STTService is a base class for speech-to-text services."""
Expand Down
169 changes: 169 additions & 0 deletions src/pipecat/services/lmnt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import json
import uuid
import base64
import asyncio
import time

from typing import AsyncGenerator

from pipecat.processors.frame_processor import FrameDirection
from pipecat.frames.frames import (
CancelFrame,
ErrorFrame,
Frame,
AudioRawFrame,
StartFrame,
StartInterruptionFrame,
EndFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.ai_services import TTSService

from loguru import logger

# See dot-env.template for LMNT configuration needed
try:
from lmnt.api import Speech
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use LMNT, you need to `pip install pipecat-ai[lmnt]`. Also, set `LMNT_API_KEY` environment variable.")
raise Exception(f"Missing module: {e}")


class LmntTTSService(TTSService):
    """Streaming text-to-speech service backed by LMNT.

    Opens a single websocket via the lmnt SDK's `Speech.synthesize_streaming`
    and pushes the returned raw 16-bit PCM audio downstream as
    `AudioRawFrame`s from a background reader task. End-of-utterance is not
    signaled by LMNT, so `_push_stop_frames` is enabled and the TTSService
    base class emits `TTSStoppedFrame`s after an idle timeout.
    """

    def __init__(
            self,
            *,
            api_key: str,
            voice_id: str,
            sample_rate: int = 24000,
            language: str = "en",
            **kwargs):
        super().__init__(**kwargs)

        # Let TTSService produce TTSStoppedFrames after a short delay of
        # no activity.
        self._push_stop_frames = True

        # NOTE(review): api_key is stored but Speech() in _connect() is
        # constructed without it — presumably the lmnt SDK reads
        # LMNT_API_KEY from the environment; confirm.
        self._api_key = api_key
        self._voice_id = voice_id
        # Raw little-endian 16-bit PCM. Only sample_rate is actually
        # forwarded to LMNT in _connect(); container/encoding are implied
        # by format="raw" there.
        self._output_format = {
            "container": "raw",
            "encoding": "pcm_s16le",
            "sample_rate": sample_rate,
        }
        # NOTE(review): language is stored but never passed to
        # synthesize_streaming — confirm whether it should be.
        self._language = language

        self._speech = None        # lmnt.api.Speech client
        self._connection = None    # active streaming-synthesis connection
        self._receive_task = None  # background task draining audio messages
        self._started = False      # True between TTSStartedFrame and stop/interruption

    def can_generate_metrics(self) -> bool:
        # This service reports TTFB and usage metrics (see run_tts()).
        return True

    async def set_voice(self, voice: str):
        """Switch the voice id used for synthesis.

        Takes effect on the next (re)connect; the currently open websocket
        keeps the voice it was opened with.
        """
        logger.debug(f"Switching TTS voice to: [{voice}]")
        self._voice_id = voice

    async def start(self, frame: StartFrame):
        """Connect eagerly at pipeline start so the first utterance is fast."""
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Close the connection on normal pipeline end."""
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Close the connection on pipeline cancellation."""
        await super().cancel(frame)
        await self._disconnect()

    async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
        await super().push_frame(frame, direction)
        # Reset the "speaking" flag so the next run_tts() emits a fresh
        # TTSStartedFrame after the bot finishes or is interrupted.
        if isinstance(frame, (TTSStoppedFrame, StartInterruptionFrame)):
            self._started = False

    async def _connect(self):
        """Open the streaming-synthesis websocket and start the reader task.

        On failure, logs and leaves self._connection as None; callers
        (run_tts) retry the connect lazily.
        """
        try:
            self._speech = Speech()
            self._connection = await self._speech.synthesize_streaming(self._voice_id, format="raw", sample_rate=self._output_format["sample_rate"])
            self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
        except Exception as e:
            logger.exception(f"{self} initialization error: {e}")
            self._connection = None

    async def _disconnect(self):
        """Tear down the reader task, websocket, and client; reset state."""
        try:
            await self.stop_all_metrics()

            if self._receive_task:
                self._receive_task.cancel()
                # Safe to await: the handler swallows CancelledError.
                await self._receive_task
                self._receive_task = None
            if self._connection:
                await self._connection.socket.close()
                self._connection = None
            if self._speech:
                await self._speech.close()
                self._speech = None
            self._started = False
        except Exception as e:
            logger.exception(f"{self} error closing websocket: {e}")

    async def _receive_task_handler(self):
        """Background task: forward LMNT websocket messages downstream as frames."""
        try:
            async for msg in self._connection:
                if "error" in msg:
                    logger.error(f'{self} error: {msg["error"]}')
                    await self.push_frame(TTSStoppedFrame())
                    await self.stop_all_metrics()
                    await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
                elif "audio" in msg:
                    # First audio chunk closes the time-to-first-byte window.
                    await self.stop_ttfb_metrics()
                    frame = AudioRawFrame(
                        audio=msg["audio"],
                        sample_rate=self._output_format["sample_rate"],
                        num_channels=1
                    )
                    await self.push_frame(frame)
                else:
                    logger.error(f"LMNT error, unknown message type: {msg}")
        except asyncio.CancelledError:
            # Normal shutdown path from _disconnect().
            pass
        except Exception as e:
            logger.exception(f"{self} exception: {e}")

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Queue *text* for synthesis on the shared websocket.

        Audio is pushed asynchronously by _receive_task_handler() rather than
        yielded here; this generator yields a single None so the base class
        still drives it. On a send failure it pushes TTSStoppedFrame and
        reconnects so the next utterance can succeed.
        """
        logger.debug(f"Generating TTS: [{text}]")

        try:
            if not self._connection:
                await self._connect()

            if not self._started:
                await self.push_frame(TTSStartedFrame())
                await self.start_ttfb_metrics()
                self._started = True

            try:
                await self._connection.append_text(text)
                # flush() asks LMNT to synthesize everything appended so far.
                await self._connection.flush()
                await self.start_tts_usage_metrics(text)
            except Exception as e:
                logger.error(f"{self} error sending message: {e}")
                await self.push_frame(TTSStoppedFrame())
                await self._disconnect()
                await self._connect()
                return
            yield None
        except Exception as e:
            logger.exception(f"{self} exception: {e}")

0 comments on commit 79aca81

Please sign in to comment.