From f319d55b53b0c8eafecc550c6e6ffe588298c010 Mon Sep 17 00:00:00 2001
From: Mark Backman
Date: Fri, 20 Dec 2024 14:41:57 -0500
Subject: [PATCH] Add Fish Audio TTS service

---
 .../foundational/07t-interruptible-fish.py    |  99 ++++++
 pyproject.toml                                |   1 +
 src/pipecat/services/fish.py                  | 306 ++++++++++++++++++
 3 files changed, 406 insertions(+)
 create mode 100644 examples/foundational/07t-interruptible-fish.py
 create mode 100644 src/pipecat/services/fish.py

diff --git a/examples/foundational/07t-interruptible-fish.py b/examples/foundational/07t-interruptible-fish.py
new file mode 100644
index 000000000..e710e25c3
--- /dev/null
+++ b/examples/foundational/07t-interruptible-fish.py
@@ -0,0 +1,99 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import os
+import sys
+
+import aiohttp
+from dotenv import load_dotenv
+from loguru import logger
+from runner import configure
+
+from pipecat.audio.vad.silero import SileroVADAnalyzer
+from pipecat.frames.frames import LLMMessagesFrame
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineParams, PipelineTask
+from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
+from pipecat.services.fish import FishAudioTTSService
+from pipecat.services.openai import OpenAILLMService
+from pipecat.transports.services.daily import DailyParams, DailyTransport
+
+load_dotenv(override=True)
+
+logger.remove(0)
+logger.add(sys.stderr, level="DEBUG")
+
+
+async def main():
+    async with aiohttp.ClientSession() as session:
+        (room_url, token) = await configure(session)
+
+        transport = DailyTransport(
+            room_url,
+            token,
+            "Respond bot",
+            DailyParams(
+                audio_out_enabled=True,
+                transcription_enabled=True,
+                vad_enabled=True,
+                vad_analyzer=SileroVADAnalyzer(),
+            ),
+        )
+
+        tts = FishAudioTTSService(
+            api_key=os.getenv("FISH_API_KEY"),
+            model="4ce7e917cedd4bc2bb2e6ff3a46acaa1",  # Barack Obama
+        )
+
+        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
+
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
+            },
+        ]
+
+        context = OpenAILLMContext(messages)
+        context_aggregator = llm.create_context_aggregator(context)
+
+        pipeline = Pipeline(
+            [
+                transport.input(),  # Transport user input
+                context_aggregator.user(),  # User responses
+                llm,  # LLM
+                tts,  # TTS
+                transport.output(),  # Transport bot output
+                context_aggregator.assistant(),  # Assistant spoken responses
+            ]
+        )
+
+        task = PipelineTask(
+            pipeline,
+            PipelineParams(
+                allow_interruptions=True,
+                enable_metrics=True,
+                enable_usage_metrics=True,
+                report_only_initial_ttfb=True,
+            ),
+        )
+
+        @transport.event_handler("on_first_participant_joined")
+        async def on_first_participant_joined(transport, participant):
+            await transport.capture_participant_transcription(participant["id"])
+            # Kick off the conversation.
+            messages.append({"role": "system", "content": "Please introduce yourself to the user."})
+            await task.queue_frames([LLMMessagesFrame(messages)])
+
+        runner = PipelineRunner()
+
+        await runner.run(task)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
index 85db84a84..dae895ded 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,7 @@ daily = [ "daily-python~=0.14.0" ]
 deepgram = [ "deepgram-sdk~=3.7.7" ]
 elevenlabs = [ "websockets~=13.1" ]
 fal = [ "fal-client~=0.4.1" ]
+fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
 gladia = [ "websockets~=13.1" ]
 google = [ "google-generativeai~=0.8.3", "google-cloud-texttospeech~=2.21.1" ]
 grok = [ "openai~=1.57.2" ]
diff --git a/src/pipecat/services/fish.py b/src/pipecat/services/fish.py
new file mode 100644
index 000000000..4fbdee714
--- /dev/null
+++ b/src/pipecat/services/fish.py
@@ -0,0 +1,306 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import uuid
+from typing import AsyncGenerator, Literal, Optional
+
+from loguru import logger
+from pydantic import BaseModel
+from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential
+
+from pipecat.frames.frames import (
+    BotStoppedSpeakingFrame,
+    CancelFrame,
+    EndFrame,
+    ErrorFrame,
+    Frame,
+    LLMFullResponseEndFrame,
+    StartFrame,
+    StartInterruptionFrame,
+    TTSAudioRawFrame,
+    TTSSpeakFrame,
+    TTSStartedFrame,
+    TTSStoppedFrame,
+)
+from pipecat.processors.frame_processor import FrameDirection
+from pipecat.services.ai_services import TTSService
+from pipecat.transcriptions.language import Language
+
+try:
+    import ormsgpack
+    import websockets
+except ModuleNotFoundError as e:
+    logger.error(f"Exception: {e}")
+    logger.error(
+        "In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`. Also, set `FISH_API_KEY` environment variable."
+    )
+    raise Exception(f"Missing module: {e}")
+
+# FishAudio supports various output formats
+FishAudioOutputFormat = Literal["opus", "mp3", "pcm", "wav"]
+
+
+class FishAudioTTSService(TTSService):
+    """Text-to-speech service that streams text to Fish Audio over a websocket.
+
+    Messages in both directions are encoded with ormsgpack. Audio received
+    from the service is pushed as TTSAudioRawFrame objects by a background
+    receive task.
+    """
+
+    class InputParams(BaseModel):
+        language: Optional[Language] = Language.EN
+        latency: Optional[str] = "normal"  # "normal" or "balanced"
+        prosody_speed: Optional[float] = 1.0  # Speech speed (0.5-2.0)
+        prosody_volume: Optional[int] = 0  # Volume adjustment in dB
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        model: str,  # This is the reference_id
+        output_format: FishAudioOutputFormat = "pcm",
+        sample_rate: int = 24000,
+        params: InputParams = InputParams(),
+        **kwargs,
+    ):
+        """Store credentials and build the settings sent when a session starts.
+
+        Args:
+            api_key: Fish Audio API key, sent as a bearer token on connect.
+            model: Voice reference id, sent to the service as reference_id.
+            output_format: Audio format requested from the service.
+            sample_rate: Sample rate requested and used for emitted audio frames.
+            params: Latency and prosody settings. The language field is not
+                forwarded to the service by this class.
+            **kwargs: Passed through to TTSService.
+        """
+        super().__init__(sample_rate=sample_rate, **kwargs)
+
+        self._api_key = api_key
+        self._base_url = "wss://api.fish.audio/v1/tts/live"
+        self._websocket = None
+        self._receive_task = None
+        self._request_id = None
+        self._started = False
+
+        self._settings = {
+            "sample_rate": sample_rate,
+            "latency": params.latency,
+            "format": output_format,
+            "prosody": {
+                "speed": params.prosody_speed,
+                "volume": params.prosody_volume,
+            },
+            "reference_id": model,
+        }
+
+        self.set_model_name(model)
+
+    def can_generate_metrics(self) -> bool:
+        """Report that this service generates metrics."""
+        return True
+
+    async def set_model(self, model: str):
+        """Change the reference id that is sent with the next session start."""
+        self._settings["reference_id"] = model
+        await super().set_model(model)
+        logger.info(f"Switching TTS model to: [{model}]")
+
+    async def start(self, frame: StartFrame):
+        """Start the service and connect to Fish Audio."""
+        await super().start(frame)
+        await self._connect()
+
+    async def stop(self, frame: EndFrame):
+        """Stop the service and disconnect from Fish Audio."""
+        await super().stop(frame)
+        await self._disconnect()
+
+    async def cancel(self, frame: CancelFrame):
+        """Cancel the service and disconnect from Fish Audio."""
+        await super().cancel(frame)
+        await self._disconnect()
+
+    async def _connect(self):
+        """Open the websocket and ensure a single receive task is running."""
+        await self._connect_websocket()
+        # run_tts() calls this again when the socket has dropped; keep a live
+        # receive task instead of starting a second one alongside it.
+        if self._receive_task is None or self._receive_task.done():
+            self._receive_task = self.get_event_loop().create_task(
+                self._receive_task_handler()
+            )
+
+    async def _disconnect(self):
+        """Stop the receive task, then close the websocket."""
+        # Cancel first so the receive loop is not still reading from, or
+        # retrying on, a socket that is being closed on purpose.
+        if self._receive_task:
+            self._receive_task.cancel()
+            await self._receive_task
+            self._receive_task = None
+        await self._disconnect_websocket()
+
+    async def _connect_websocket(self):
+        """Connect to Fish Audio and send the session start message.
+
+        Errors are logged and leave self._websocket set to None; nothing is
+        raised to the caller.
+        """
+        try:
+            logger.debug("Connecting to Fish Audio")
+            headers = {"Authorization": f"Bearer {self._api_key}"}
+            self._websocket = await websockets.connect(self._base_url, extra_headers=headers)
+
+            # Send initial start message with ormsgpack
+            start_message = {"event": "start", "request": {"text": "", **self._settings}}
+            await self._websocket.send(ormsgpack.packb(start_message))
+            logger.debug("Sent start message to Fish Audio")
+        except Exception as e:
+            logger.error(f"Fish Audio initialization error: {e}")
+            self._websocket = None
+
+    async def _disconnect_websocket(self):
+        """Send the stop event, close the websocket and reset request state.
+
+        Errors are logged, not raised. The connection and request state are
+        reset even when sending or closing fails.
+        """
+        try:
+            await self.stop_all_metrics()
+            if self._websocket:
+                logger.debug("Disconnecting from Fish Audio")
+                # Send stop event with ormsgpack
+                stop_message = {"event": "stop"}
+                await self._websocket.send(ormsgpack.packb(stop_message))
+                await self._websocket.close()
+        except Exception as e:
+            logger.error(f"Error closing websocket: {e}")
+        finally:
+            # Reset unconditionally so a failed send/close cannot leave a dead
+            # socket behind for run_tts() or the receive loop to pick up.
+            self._websocket = None
+            self._request_id = None
+            self._started = False
+
+    def _get_websocket(self):
+        """Return the current websocket, raising if there is none."""
+        if self._websocket:
+            return self._websocket
+        raise Exception("Websocket not connected")
+
+    async def _receive_messages(self):
+        """Read messages from the websocket and push audio frames downstream.
+
+        Only binary messages that unpack to a dict with event "audio" are
+        handled; errors while handling one message are logged and skipped.
+        """
+        async for message in self._get_websocket():
+            try:
+                if isinstance(message, bytes):
+                    msg = ormsgpack.unpackb(message)
+                    if isinstance(msg, dict):
+                        event = msg.get("event")
+                        logger.trace(f"Received event: {event}")
+                        if event == "audio":
+                            await self.stop_ttfb_metrics()
+                            audio_data = msg.get("audio")
+                            # Only process larger chunks to remove msgpack overhead
+                            # NOTE(review): chunks of 1024 bytes or fewer are dropped,
+                            # not buffered - confirm this cannot discard real audio.
+                            if audio_data and len(audio_data) > 1024:
+                                frame = TTSAudioRawFrame(
+                                    audio_data, self._settings["sample_rate"], 1
+                                )
+                                await self.push_frame(frame)
+                                continue
+
+            except Exception as e:
+                logger.error(f"Error processing message: {e}")
+
+    async def _reconnect_websocket(self, retry_state: RetryCallState):
+        """Retry hook: log the attempt, then close and reopen the websocket."""
+        logger.warning(f"Fish Audio reconnecting (attempt: {retry_state.attempt_number})")
+        await self._disconnect_websocket()
+        await self._connect_websocket()
+
+    async def _receive_task_handler(self):
+        """Run the receive loop, reconnecting between failed attempts.
+
+        Each pass allows up to three attempts with exponential backoff. When
+        they are exhausted a fatal ErrorFrame is pushed and the task ends;
+        cancellation ends the task quietly.
+        """
+        while True:
+            try:
+                async for attempt in AsyncRetrying(
+                    stop=stop_after_attempt(3),
+                    wait=wait_exponential(multiplier=1, min=4, max=10),
+                    before_sleep=self._reconnect_websocket,
+                    reraise=True,
+                ):
+                    with attempt:
+                        await self._receive_messages()
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                message = f"Fish Audio error receiving messages: {e}"
+                logger.error(message)
+                await self.push_error(ErrorFrame(message, fatal=True))
+                break
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        """Process a frame, pausing and resuming around bot speech.
+
+        Processing is paused on a TTSSpeakFrame, or on LLMFullResponseEndFrame
+        while a request is active, and resumed on BotStoppedSpeakingFrame.
+        """
+        await super().process_frame(frame, direction)
+
+        if isinstance(frame, TTSSpeakFrame):
+            await self.pause_processing_frames()
+        elif isinstance(frame, LLMFullResponseEndFrame) and self._request_id:
+            await self.pause_processing_frames()
+        elif isinstance(frame, BotStoppedSpeakingFrame):
+            await self.resume_processing_frames()
+
+    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
+        """Handle an interruption: stop metrics and forget the active request."""
+        await super()._handle_interruption(frame, direction)
+        await self.stop_all_metrics()
+        self._request_id = None
+
+    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+        """Send text to Fish Audio for synthesis.
+
+        Audio is not yielded here; it arrives on the receive task. This
+        generator yields TTSStartedFrame for the first text of a request and
+        an ErrorFrame if connecting or sending fails.
+        """
+        logger.debug(f"Generating Fish TTS: [{text}]")
+        try:
+            if not self._websocket or self._websocket.closed:
+                await self._connect()
+
+            if not self._request_id:
+                await self.start_ttfb_metrics()
+                yield TTSStartedFrame()
+                self._request_id = str(uuid.uuid4())
+
+            text_message = {
+                "event": "text",
+                "text": text,
+            }
+            await self._get_websocket().send(ormsgpack.packb(text_message))
+            await self.start_tts_usage_metrics(text)
+
+            yield None
+
+        except Exception as e:
+            logger.error(f"Error generating TTS: {e}")
+            yield ErrorFrame(f"Error: {str(e)}")