Skip to content

Commit

Permalink
Add Fish Audio TTS service
Browse files Browse the repository at this point in the history
  • Loading branch information
markbackman committed Dec 20, 2024
1 parent 41d0769 commit f319d55
Show file tree
Hide file tree
Showing 3 changed files with 334 additions and 0 deletions.
99 changes: 99 additions & 0 deletions examples/foundational/07t-interruptible-fish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os
import sys

import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure

from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.fish import FishAudioTTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport

load_dotenv(override=True)

logger.remove(0)
logger.add(sys.stderr, level="DEBUG")


async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)

transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
),
)

tts = FishAudioTTSService(
api_key=os.getenv("FISH_API_KEY"),
model="4ce7e917cedd4bc2bb2e6ff3a46acaa1", # Barack Obama
)

llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")

messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]

context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)

pipeline = Pipeline(
[
transport.input(), # Transport user input
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)

task = PipelineTask(
pipeline,
PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
report_only_initial_ttfb=True,
),
)

@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMMessagesFrame(messages)])

runner = PipelineRunner()

await runner.run(task)


if __name__ == "__main__":
asyncio.run(main())
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ daily = [ "daily-python~=0.14.0" ]
deepgram = [ "deepgram-sdk~=3.7.7" ]
elevenlabs = [ "websockets~=13.1" ]
fal = [ "fal-client~=0.4.1" ]
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
gladia = [ "websockets~=13.1" ]
google = [ "google-generativeai~=0.8.3", "google-cloud-texttospeech~=2.21.1" ]
grok = [ "openai~=1.57.2" ]
Expand Down
234 changes: 234 additions & 0 deletions src/pipecat/services/fish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import uuid
from typing import AsyncGenerator, Literal, Optional

from loguru import logger
from pydantic import BaseModel
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential

from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import TTSService
from pipecat.transcriptions.language import Language

try:
import ormsgpack
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`. Also, set `FISH_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")

# FishAudio supports various output formats
FishAudioOutputFormat = Literal["opus", "mp3", "pcm", "wav"]


class FishAudioTTSService(TTSService):
class InputParams(BaseModel):
language: Optional[Language] = Language.EN
latency: Optional[str] = "normal" # "normal" or "balanced"
prosody_speed: Optional[float] = 1.0 # Speech speed (0.5-2.0)
prosody_volume: Optional[int] = 0 # Volume adjustment in dB

def __init__(
self,
*,
api_key: str,
model: str, # This is the reference_id
output_format: FishAudioOutputFormat = "pcm",
sample_rate: int = 24000,
params: InputParams = InputParams(),
**kwargs,
):
super().__init__(sample_rate=sample_rate, **kwargs)

self._api_key = api_key
self._base_url = "wss://api.fish.audio/v1/tts/live"
self._websocket = None
self._receive_task = None
self._request_id = None
self._started = False

self._settings = {
"sample_rate": sample_rate,
"latency": params.latency,
"format": output_format,
"prosody": {
"speed": params.prosody_speed,
"volume": params.prosody_volume,
},
"reference_id": model,
}

self.set_model_name(model)

def can_generate_metrics(self) -> bool:
return True

async def set_model(self, model: str):
self._settings["reference_id"] = model
await super().set_model(model)
logger.info(f"Switching TTS model to: [{model}]")

async def start(self, frame: StartFrame):
await super().start(frame)
await self._connect()

async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()

async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()

async def _connect(self):
await self._connect_websocket()
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())

async def _disconnect(self):
await self._disconnect_websocket()
if self._receive_task:
self._receive_task.cancel()
await self._receive_task
self._receive_task = None

async def _connect_websocket(self):
try:
logger.debug("Connecting to Fish Audio")
headers = {"Authorization": f"Bearer {self._api_key}"}
self._websocket = await websockets.connect(self._base_url, extra_headers=headers)

# Send initial start message with ormsgpack
start_message = {"event": "start", "request": {"text": "", **self._settings}}
await self._websocket.send(ormsgpack.packb(start_message))
logger.debug("Sent start message to Fish Audio")
except Exception as e:
logger.error(f"Fish Audio initialization error: {e}")
self._websocket = None

async def _disconnect_websocket(self):
try:
await self.stop_all_metrics()
if self._websocket:
logger.debug("Disconnecting from Fish Audio")
# Send stop event with ormsgpack
stop_message = {"event": "stop"}
await self._websocket.send(ormsgpack.packb(stop_message))
await self._websocket.close()
self._websocket = None
self._request_id = None
self._started = False
except Exception as e:
logger.error(f"Error closing websocket: {e}")

def _get_websocket(self):
if self._websocket:
return self._websocket
raise Exception("Websocket not connected")

async def _receive_messages(self):
async for message in self._get_websocket():
try:
if isinstance(message, bytes):
msg = ormsgpack.unpackb(message)
if isinstance(msg, dict):
event = msg.get("event")
print(f"Received event: {event}")
if event == "audio":
await self.stop_ttfb_metrics()
audio_data = msg.get("audio")
# Only process larger chunks to remove msgpack overhead
if audio_data and len(audio_data) > 1024:
frame = TTSAudioRawFrame(
audio_data, self._settings["sample_rate"], 1
)
await self.push_frame(frame)
continue

except Exception as e:
logger.error(f"Error processing message: {e}")

async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"Fish Audio reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_websocket()
await self._connect_websocket()

async def _receive_task_handler(self):
while True:
try:
async for attempt in AsyncRetrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
before_sleep=self._reconnect_websocket,
reraise=True,
):
with attempt:
await self._receive_messages()
except asyncio.CancelledError:
break
except Exception as e:
message = f"Fish Audio error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
break

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, TTSSpeakFrame):
await self.pause_processing_frames()
elif isinstance(frame, LLMFullResponseEndFrame) and self._request_id:
await self.pause_processing_frames()
elif isinstance(frame, BotStoppedSpeakingFrame):
await self.resume_processing_frames()

async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
await self.stop_all_metrics()
self._request_id = None

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating Fish TTS: [{text}]")
try:
if not self._websocket or self._websocket.closed:
await self._connect()

if not self._request_id:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
self._request_id = str(uuid.uuid4())

text_message = {
"event": "text",
"text": text,
}
await self._get_websocket().send(ormsgpack.packb(text_message))
await self.start_tts_usage_metrics(text)

yield None

except Exception as e:
logger.error(f"Error generating TTS: {e}")
yield ErrorFrame(f"Error: {str(e)}")

0 comments on commit f319d55

Please sign in to comment.