From f686df05f219fc5c08d59057c78fdb650843ba95 Mon Sep 17 00:00:00 2001 From: MoeMamdouh Date: Wed, 11 Dec 2024 10:44:23 +0200 Subject: [PATCH 1/5] feat(fish): add FishAudioTTSService for real-time TTS processing via WebSocket --- examples/simple-chatbot/server/bot.py | 22 ++- src/pipecat/services/fish.py | 250 ++++++++++++++++++++++++++ 2 files changed, 260 insertions(+), 12 deletions(-) create mode 100644 src/pipecat/services/fish.py diff --git a/examples/simple-chatbot/server/bot.py b/examples/simple-chatbot/server/bot.py index 5bfb4a8fe..e2553dbbb 100644 --- a/examples/simple-chatbot/server/bot.py +++ b/examples/simple-chatbot/server/bot.py @@ -33,7 +33,7 @@ RTVIBotTranscriptionProcessor, RTVIUserTranscriptionProcessor, ) -from pipecat.services.elevenlabs import ElevenLabsTTSService +from pipecat.services.fish import FishAudioTTSService from pipecat.services.openai import OpenAILLMService from pipecat.transports.services.daily import DailyParams, DailyTransport @@ -113,17 +113,15 @@ async def main(): ), ) - tts = ElevenLabsTTSService( - api_key=os.getenv("ELEVENLABS_API_KEY"), - # - # English - # - voice_id="pNInz6obpgDQGcFmaJgB", - # - # Spanish - # - # model="eleven_multilingual_v2", - # voice_id="gD1IexrzCvsXPHUuT0s3", + tts = FishAudioTTSService( + api_key=os.getenv("FISH_API_KEY"), + model_id="e58b0d7efca34eb38d5c4985e378abcb", # Trump voice + params=FishAudioTTSService.InputParams( + # language=Language.EN_US, # Use the Language enum + latency="normal", # Optional, defaults to "normal" + prosody_speed=4.0, # Use prosody_speed instead of speed + prosody_volume=2.0 # Use prosody_volume instead of pitch + ) ) llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") diff --git a/src/pipecat/services/fish.py b/src/pipecat/services/fish.py new file mode 100644 index 000000000..f29c6cb26 --- /dev/null +++ b/src/pipecat/services/fish.py @@ -0,0 +1,250 @@ +import asyncio +import base64 +import json +from typing import Any, AsyncGenerator, Dict, List, Optional + +import websockets +from loguru import logger +from pydantic import BaseModel + +from pipecat.frames.frames import ( + BotStoppedSpeakingFrame, + CancelFrame, + EndFrame, + ErrorFrame, + Frame, + LLMFullResponseEndFrame, + StartFrame, + StartInterruptionFrame, + TTSAudioRawFrame, + TTSSpeakFrame, + TTSStartedFrame, + TTSStoppedFrame, +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import TTSService +from pipecat.transcriptions.language import Language + +# FishAudio supports various output formats +FishAudioOutputFormat = Literal["opus", "mp3", "wav"] + +def language_to_fishaudio_language(language: Language) -> str: + # Map Language enum to fish.audio language codes + language_map = { + Language.EN: "en-US", + Language.EN_US: "en-US", + Language.EN_GB: "en-GB", + Language.ES: "es-ES", + Language.FR: "fr-FR", + Language.DE: "de-DE", + # Add other mappings as needed + } + return language_map.get(language, "en-US") # Default to 'en-US' if not found + +def sample_rate_from_output_format(output_format: str) -> int: + # FishAudio might have specific sample rates per format + format_sample_rates = { + "opus": 24000, + "mp3": 44100, + "wav": 44100, + } + return format_sample_rates.get(output_format, 24000) # Default to 24kHz + +class FishAudioTTSService(TTSService): + class InputParams(BaseModel): + language: Optional[Language] = Language.EN + latency: Optional[str] = "normal" # "normal" or "balanced" + prosody_speed: Optional[float] = 1.0 # Speech speed (0.5-2.0) + 
prosody_volume: Optional[int] = 0 # Volume adjustment in dB + + def __init__( + self, + *, + api_key: str, + model_id: str, + output_format: FishAudioOutputFormat = "opus", + params: InputParams = InputParams(), + **kwargs, + ): + super().__init__( + sample_rate=sample_rate_from_output_format(output_format), + **kwargs, + ) + + self._api_key = "api_key" + self._model_id = model_id + self._url = "wss://api.fish.audio/v1/tts/live" + self._output_format = output_format + + self._settings = { + # "sample_rate": sample_rate_from_output_format(output_format), + # "language": self.language_to_service_language(params.language) + # if params.language else "en-US", + "latency": params.latency, + "prosody": { + "speed": params.prosody_speed, + "volume": params.prosody_volume, + }, + "format": output_format, + "reference_id": model_id, + } + + self._websocket = None + self._receive_task = None + self._started = False + + def can_generate_metrics(self) -> bool: + return True + + def language_to_service_language(self, language: Language) -> str: + return language_to_fishaudio_language(language) + + async def start(self, frame: StartFrame): + await super().start(frame) + await self._connect() + + async def stop(self, frame: EndFrame): + await super().stop(frame) + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + await self._disconnect() + + async def _connect(self): + try: + # headers = { + # "Authorization": f"Bearer {self._api_key}", + # } + + self._websocket = await websockets.connect(self._url, extra_headers=headers) + self._receive_task = asyncio.create_task(self._receive_task_handler()) + logger.debug("Connected to fish.audio WebSocket") + + # Send 'start' event to initialize the session + start_message = { + "event": "start", + "request": { + "text": "", # Initial empty text + "latency": self._settings["latency"], + "format": self._output_format, + "prosody": self._settings["prosody"], + "reference_id": self._settings["reference_id"], + }, + "debug": True, # Added debug flag + + } + await self._websocket.send(json.dumps(start_message)) + logger.debug("Sent start event to fish.audio WebSocket") + + except Exception as e: + logger.error(f"Error connecting to fish.audio WebSocket: {e}") + self._websocket = None + + async def _disconnect(self): + try: + await self.stop_all_metrics() + + if self._websocket: + # Send 'stop' event to end the session + stop_message = { + "event": "stop" + } + await self._websocket.send(json.dumps(stop_message)) + await self._websocket.close() + self._websocket = None + + if self._receive_task: + self._receive_task.cancel() + await self._receive_task + self._receive_task = None + + self._started = False + except Exception as e: + logger.error(f"Error disconnecting from fish.audio WebSocket: {e}") + + async def _receive_task_handler(self): + try: + async for message in self._websocket: + # Messages can be text or binary + if isinstance(message, str): + msg = json.loads(message) + event = msg.get("event") + + if event == "audio": + await self.stop_ttfb_metrics() + audio_data_base64 = msg.get("audio") + audio_data = base64.b64decode(audio_data_base64) + frame = TTSAudioRawFrame( + audio_data, self._settings["sample_rate"], 1) + await self.push_frame(frame) + elif event == "finish": + reason = msg.get("reason") + if reason == "stop": + await self.push_frame(TTSStoppedFrame()) + self._started = False + elif reason == "error": + error_msg = msg.get("error", "Unknown error") + logger.error(f"fish.audio error: {error_msg}") + 
await self.push_error(ErrorFrame(f"fish.audio error: {error_msg}")) + self._started = False + elif event == "error": + error_msg = msg.get("error", "Unknown error") + logger.error(f"fish.audio error: {error_msg}") + await self.push_error(ErrorFrame(f"fish.audio error: {error_msg}")) + else: + logger.warning(f"Received unexpected binary message: {message}") + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"Exception in receive task: {e}") + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, TTSSpeakFrame): + await self.pause_processing_frames() + elif isinstance(frame, LLMFullResponseEndFrame) and self._started: + await self.pause_processing_frames() + elif isinstance(frame, BotStoppedSpeakingFrame): + await self.resume_processing_frames() + + async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection): + await super()._handle_interruption(frame, direction) + await self.stop_all_metrics() + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"Generating TTS: [{text}]") + + try: + if not self._websocket or self._websocket.closed: + await self._connect() + + if not self._started: + await self.start_ttfb_metrics() + yield TTSStartedFrame() + self._started = True + + # Send 'text' event to stream text chunks + text_message = { + "event": "text", + "text": text + " " # Ensure a space at the end + } + + try: + await self._websocket.send(json.dumps(text_message)) + await self.start_tts_usage_metrics(text) + except Exception as e: + logger.error(f"Error sending text to fish.audio WebSocket: {e}") + yield TTSStoppedFrame() + await self._disconnect() + return + + # The audio frames will be received in _receive_task_handler + yield None + + except Exception as e: + logger.error(f"Error in run_tts: {e}") + yield ErrorFrame(f"Error in run_tts: {str(e)}") + From e7949d83b14e5c64a61437d6c502f99f3adb531f Mon Sep 17 00:00:00 2001 From: MoeMamdouh Date: Tue, 17 Dec 2024 15:01:14 +0200 Subject: [PATCH 2/5] feat(fish): update FishAudioTTSService to use MessagePack for WebSocket communication and adjust audio format settings --- src/pipecat/services/fish.py | 109 +++++++++++++++++------------------ 1 file changed, 54 insertions(+), 55 deletions(-) diff --git a/src/pipecat/services/fish.py b/src/pipecat/services/fish.py index f29c6cb26..d09245aa3 100644 --- a/src/pipecat/services/fish.py +++ b/src/pipecat/services/fish.py @@ -1,11 +1,11 @@ import asyncio import base64 -import json -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, Optional, Literal import websockets from loguru import logger from pydantic import BaseModel +import ormsgpack # Import ormsgpack for MessagePack encoding/decoding from pipecat.frames.frames import ( BotStoppedSpeakingFrame, @@ -45,8 +45,8 @@ def sample_rate_from_output_format(output_format: str) -> int: # FishAudio might have specific sample rates per format format_sample_rates = { "opus": 24000, - "mp3": 44100, - "wav": 44100, + "mp3": 24000, + "wav": 24000, } return format_sample_rates.get(output_format, 24000) # Default to 24kHz @@ -62,7 +62,7 @@ def __init__( *, api_key: str, model_id: str, - output_format: FishAudioOutputFormat = "opus", + output_format: FishAudioOutputFormat = "wav", params: InputParams = InputParams(), **kwargs, ): @@ -71,13 +71,13 @@ def __init__( **kwargs, ) - self._api_key = "api_key" + self._api_key = 
api_key self._model_id = model_id self._url = "wss://api.fish.audio/v1/tts/live" self._output_format = output_format self._settings = { - # "sample_rate": sample_rate_from_output_format(output_format), + "sample_rate": sample_rate_from_output_format(output_format), # "language": self.language_to_service_language(params.language) # if params.language else "en-US", "latency": params.latency, @@ -113,13 +113,12 @@ async def cancel(self, frame: CancelFrame): async def _connect(self): try: - # headers = { - # "Authorization": f"Bearer {self._api_key}", - # } + headers = { + "Authorization": f"Bearer {self._api_key}", + } self._websocket = await websockets.connect(self._url, extra_headers=headers) self._receive_task = asyncio.create_task(self._receive_task_handler()) - logger.debug("Connected to fish.audio WebSocket") # Send 'start' event to initialize the session start_message = { @@ -130,15 +129,15 @@ async def _connect(self): "format": self._output_format, "prosody": self._settings["prosody"], "reference_id": self._settings["reference_id"], + "sample_rate": self._settings["sample_rate"], }, "debug": True, # Added debug flag - } - await self._websocket.send(json.dumps(start_message)) + await self._websocket.send(ormsgpack.packb(start_message)) logger.debug("Sent start event to fish.audio WebSocket") except Exception as e: - logger.error(f"Error connecting to fish.audio WebSocket: {e}") + logger.exception(f"Error connecting to fish.audio WebSocket: {e}") self._websocket = None async def _disconnect(self): @@ -150,7 +149,7 @@ async def _disconnect(self): stop_message = { "event": "stop" } - await self._websocket.send(json.dumps(stop_message)) + await self._websocket.send(ormsgpack.packb(stop_message)) await self._websocket.close() self._websocket = None @@ -165,40 +164,45 @@ async def _disconnect(self): async def _receive_task_handler(self): try: - async for message in self._websocket: - # Messages can be text or binary - if isinstance(message, str): - msg = json.loads(message) - event = msg.get("event") - - if event == "audio": - await self.stop_ttfb_metrics() - audio_data_base64 = msg.get("audio") - audio_data = base64.b64decode(audio_data_base64) - frame = TTSAudioRawFrame( - audio_data, self._settings["sample_rate"], 1) - await self.push_frame(frame) - elif event == "finish": - reason = msg.get("reason") - if reason == "stop": - await self.push_frame(TTSStoppedFrame()) - self._started = False - elif reason == "error": + while True: + try: + message = await self._websocket.recv() + if isinstance(message, bytes): + msg = ormsgpack.unpackb(message) + event = msg.get("event") + + if event == "audio": + await self.stop_ttfb_metrics() + audio_data = msg.get("audio") + # Audio data is binary, no need to base64 decode + frame = TTSAudioRawFrame( + audio_data, self._settings["sample_rate"], 1) + await self.push_frame(frame) + elif event == "finish": + reason = msg.get("reason") + if reason == "stop": + await self.push_frame(TTSStoppedFrame()) + self._started = False + elif reason == "error": + error_msg = msg.get("error", "Unknown error") + logger.error(f"fish.audio error: {error_msg}") + await self.push_error(ErrorFrame(f"fish.audio error: {error_msg}")) + self._started = False + elif event == "error": error_msg = msg.get("error", "Unknown error") logger.error(f"fish.audio error: {error_msg}") await self.push_error(ErrorFrame(f"fish.audio error: {error_msg}")) - self._started = False - elif event == "error": - error_msg = msg.get("error", "Unknown error") - logger.error(f"fish.audio error: 
{error_msg}") - await self.push_error(ErrorFrame(f"fish.audio error: {error_msg}")) - else: - logger.warning(f"Received unexpected binary message: {message}") - - except asyncio.CancelledError: - pass + else: + logger.warning(f"Unhandled event from fish.audio: {event}") + else: + logger.warning(f"Received unexpected message type: {type(message)}") + except asyncio.TimeoutError: + logger.warning("No message received from fish.audio within timeout period") + except websockets.ConnectionClosed as e: + logger.error(f"WebSocket connection closed: {e}") + break except Exception as e: - logger.error(f"Exception in receive task: {e}") + logger.exception(f"Exception in receive task: {e}") async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -215,7 +219,7 @@ async def _handle_interruption(self, frame: StartInterruptionFrame, direction: F await self.stop_all_metrics() async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: - logger.debug(f"Generating TTS: [{text}]") + logger.debug(f"Generating Fish TTS: [{text}]") try: if not self._websocket or self._websocket.closed: @@ -231,15 +235,11 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: "event": "text", "text": text + " " # Ensure a space at the end } + logger.debug(f"Sending text message: {text_message}") + await self._websocket.send(ormsgpack.packb(text_message)) + logger.debug("Sent text message to fish.audio WebSocket") - try: - await self._websocket.send(json.dumps(text_message)) - await self.start_tts_usage_metrics(text) - except Exception as e: - logger.error(f"Error sending text to fish.audio WebSocket: {e}") - yield TTSStoppedFrame() - await self._disconnect() - return + await self.start_tts_usage_metrics(text) # The audio frames will be received in _receive_task_handler yield None @@ -247,4 +247,3 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: except Exception as e: logger.error(f"Error in run_tts: {e}") yield ErrorFrame(f"Error in run_tts: {str(e)}") - From baa46f1910c8b7b55329e3f8e0d495e0c265c136 Mon Sep 17 00:00:00 2001 From: MoeMamdouh Date: Thu, 19 Dec 2024 17:17:24 +0200 Subject: [PATCH 3/5] feat(fish): add example for FishAudioTTSService --- .../07t-xinterruptible-fish-audio.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 examples/foundational/07t-xinterruptible-fish-audio.py diff --git a/examples/foundational/07t-xinterruptible-fish-audio.py b/examples/foundational/07t-xinterruptible-fish-audio.py new file mode 100644 index 000000000..63dee387f --- /dev/null +++ b/examples/foundational/07t-xinterruptible-fish-audio.py @@ -0,0 +1,105 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from runner import configure + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import LLMMessagesFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.openai import OpenAILLMService +from pipecat.services.fish import FishAudioTTSService +from pipecat.transports.services.daily import DailyParams, DailyTransport + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") 
+ + +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, token) = await configure(session) + + transport = DailyTransport( + room_url, + token, + "Respond bot", + DailyParams( + audio_out_enabled=True, + transcription_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + ) + + tts = FishAudioTTSService( + api_key=os.getenv("FISH_API_KEY"), + model_id="e58b0d7efca34eb38d5c4985e378abcb", # Trump + params=FishAudioTTSService.InputParams( + # language=Language.EN_US, # Use the Language enum + latency="normal", # Optional, defaults to "normal" + prosody_speed=1.0, # Use prosody_speed instead of speed + prosody_volume=0 # Use prosody_volume instead of pitch + ) + ) + + llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", + }, + ] + + context = OpenAILLMContext(messages) + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + context_aggregator.user(), # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + context_aggregator.assistant(), # Assistant spoken responses + ] + ) + + task = PipelineTask( + pipeline, + PipelineParams( + allow_interruptions=True, + enable_metrics=True, + enable_usage_metrics=True, + report_only_initial_ttfb=True, + ), + ) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + await transport.capture_participant_transcription(participant["id"]) + # Kick off the conversation. 
+ messages.append({"role": "system", "content": "Please introduce yourself to the user."}) + await task.queue_frames([LLMMessagesFrame(messages)]) + + runner = PipelineRunner() + + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) From 468cc900f68e40d3ffd8fcd5db3ef059e5c5f761 Mon Sep 17 00:00:00 2001 From: MoeMamdouh Date: Thu, 19 Dec 2024 17:22:17 +0200 Subject: [PATCH 4/5] refactor(fish): comment out debug and error logging in FishAudioTTSService --- src/pipecat/services/fish.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/pipecat/services/fish.py b/src/pipecat/services/fish.py index d09245aa3..166970aa9 100644 --- a/src/pipecat/services/fish.py +++ b/src/pipecat/services/fish.py @@ -134,10 +134,10 @@ async def _connect(self): "debug": True, # Added debug flag } await self._websocket.send(ormsgpack.packb(start_message)) - logger.debug("Sent start event to fish.audio WebSocket") + # logger.debug("Sent start event to fish.audio WebSocket") except Exception as e: - logger.exception(f"Error connecting to fish.audio WebSocket: {e}") + # logger.exception(f"Error connecting to fish.audio WebSocket: {e}") self._websocket = None async def _disconnect(self): @@ -192,8 +192,6 @@ async def _receive_task_handler(self): error_msg = msg.get("error", "Unknown error") logger.error(f"fish.audio error: {error_msg}") await self.push_error(ErrorFrame(f"fish.audio error: {error_msg}")) - else: - logger.warning(f"Unhandled event from fish.audio: {event}") else: logger.warning(f"Received unexpected message type: {type(message)}") except asyncio.TimeoutError: From ee1b871404a0dd4bb54132d22062ae12a3ec6ffc Mon Sep 17 00:00:00 2001 From: MoeMamdouh Date: Thu, 19 Dec 2024 17:33:07 +0200 Subject: [PATCH 5/5] refactor(bot-openai): replace FishAudioTTSService with ElevenLabsTTSService --- examples/simple-chatbot/server/bot-openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/simple-chatbot/server/bot-openai.py b/examples/simple-chatbot/server/bot-openai.py index 003709291..a3a68c839 100644 --- a/examples/simple-chatbot/server/bot-openai.py +++ b/examples/simple-chatbot/server/bot-openai.py @@ -48,7 +48,7 @@ RTVISpeakingProcessor, RTVIUserTranscriptionProcessor, ) -from pipecat.services.fish import FishAudioTTSService +from pipecat.services.elevenlabs import ElevenLabsTTSService from pipecat.services.openai import OpenAILLMService from pipecat.transports.services.daily import DailyParams, DailyTransport
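
For reference, the wire protocol used by FishAudioTTSService above reduces to three MessagePack-encoded events sent over a single WebSocket ("start" carrying the request settings, "text" for each streamed chunk, "stop" to end the session), with "audio" and "finish" events coming back. The standalone sketch below illustrates that exchange outside of pipecat; the endpoint, event names, and field layout are taken from the patches above, while the placeholder voice reference id, the stop-then-drain flow, and collecting audio into one buffer are assumptions made for illustration, not the service's exact behaviour.

import asyncio
import os

import ormsgpack
import websockets


async def fish_tts_sketch(text: str) -> bytes:
    """Rough sketch of the fish.audio live TTS exchange assumed by the service above."""
    headers = {"Authorization": f"Bearer {os.getenv('FISH_API_KEY')}"}
    audio = b""
    async with websockets.connect(
        "wss://api.fish.audio/v1/tts/live", extra_headers=headers
    ) as ws:
        # 1. Open the session with the synthesis settings (mirrors the service's start event).
        await ws.send(ormsgpack.packb({
            "event": "start",
            "request": {
                "text": "",
                "latency": "normal",
                "format": "wav",
                "prosody": {"speed": 1.0, "volume": 0},
                "reference_id": "your-voice-reference-id",  # hypothetical placeholder
            },
        }))
        # 2. Stream a text chunk; the service appends a trailing space to each chunk.
        await ws.send(ormsgpack.packb({"event": "text", "text": text + " "}))
        # 3. End the session (assumption: the server flushes pending audio before finishing).
        await ws.send(ormsgpack.packb({"event": "stop"}))

        # 4. Collect binary audio payloads until a finish event arrives.
        async for message in ws:
            if not isinstance(message, bytes):
                continue
            msg = ormsgpack.unpackb(message)
            if msg.get("event") == "audio":
                audio += msg["audio"]
            elif msg.get("event") == "finish":
                break
    return audio


if __name__ == "__main__":
    asyncio.run(fish_tts_sketch("Hello from fish.audio."))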