Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Fish Audio tts #827

Closed
wants to merge 6 commits into from
Closed
105 changes: 105 additions & 0 deletions examples/foundational/07t-xinterruptible-fish-audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os
import sys

import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure

from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai import OpenAILLMService
from pipecat.services.fish import FishAudioTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport

# Load environment variables (FISH_API_KEY, OPENAI_API_KEY, Daily settings)
# before any service reads them; override any values already set in the env.
load_dotenv(override=True)

logger.remove(0)  # Remove loguru's default sink so we fully control logging.
logger.add(sys.stderr, level="DEBUG")


async def main():
    """Run an interruptible voice bot that speaks via fish.audio TTS.

    Joins a Daily room, transcribes the user, generates replies with an
    OpenAI LLM, and synthesizes speech with FishAudioTTSService.
    """
    async with aiohttp.ClientSession() as http_session:
        room_url, token = await configure(http_session)

        transport = DailyTransport(
            room_url,
            token,
            "Respond bot",
            DailyParams(
                audio_out_enabled=True,
                transcription_enabled=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(),
            ),
        )

        tts = FishAudioTTSService(
            api_key=os.getenv("FISH_API_KEY"),
            model_id="e58b0d7efca34eb38d5c4985e378abcb",  # Trump
            params=FishAudioTTSService.InputParams(
                # language=Language.EN_US, # Use the Language enum
                latency="normal",  # Optional, defaults to "normal"
                prosody_speed=1.0,  # Speech speed multiplier
                prosody_volume=0,  # Volume adjustment in dB
            ),
        )

        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")

        messages = [
            {
                "role": "system",
                "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
            },
        ]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        # Frames flow top-to-bottom: user audio in, assistant audio out.
        pipeline = Pipeline(
            [
                transport.input(),  # Transport user input
                context_aggregator.user(),  # User responses
                llm,  # LLM
                tts,  # TTS
                transport.output(),  # Transport bot output
                context_aggregator.assistant(),  # Assistant spoken responses
            ]
        )

        task = PipelineTask(
            pipeline,
            PipelineParams(
                allow_interruptions=True,
                enable_metrics=True,
                enable_usage_metrics=True,
                report_only_initial_ttfb=True,
            ),
        )

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            # Start transcribing the first participant, then kick off the
            # conversation by asking the bot to introduce itself.
            await transport.capture_participant_transcription(participant["id"])
            messages.append({"role": "system", "content": "Please introduce yourself to the user."})
            await task.queue_frames([LLMMessagesFrame(messages)])

        runner = PipelineRunner()

        await runner.run(task)


if __name__ == "__main__":
    # Script entry point: run the bot until the pipeline finishes.
    asyncio.run(main())
247 changes: 247 additions & 0 deletions src/pipecat/services/fish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
import asyncio
import base64
from typing import Any, AsyncGenerator, Dict, Optional, Literal

import websockets
from loguru import logger
from pydantic import BaseModel
import ormsgpack # Import ormsgpack for MessagePack encoding/decoding

from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import TTSService
from pipecat.transcriptions.language import Language

# FishAudio supports various output formats
# Codec/container names accepted by the fish.audio TTS API "format" field.
FishAudioOutputFormat = Literal["opus", "mp3", "wav"]

def language_to_fishaudio_language(language: Language) -> str:
    """Translate a pipecat ``Language`` value into a fish.audio locale code.

    Unmapped languages fall back to ``"en-US"``.
    """
    # (accepted Language values, fish.audio locale) pairs; extend as needed.
    locale_table = (
        ((Language.EN, Language.EN_US), "en-US"),
        ((Language.EN_GB,), "en-GB"),
        ((Language.ES,), "es-ES"),
        ((Language.FR,), "fr-FR"),
        ((Language.DE,), "de-DE"),
    )
    for candidates, locale in locale_table:
        if language in candidates:
            return locale
    return "en-US"

def sample_rate_from_output_format(output_format: str) -> int:
    """Return the audio sample rate (Hz) used for a fish.audio output format.

    All currently supported formats ("opus", "mp3", "wav") stream at 24 kHz,
    and unknown formats fall back to 24000 as well.
    """
    default_rate = 24000
    rates = dict(opus=24000, mp3=24000, wav=24000)
    return rates.get(output_format, default_rate)

class FishAudioTTSService(TTSService):
    """Streaming text-to-speech service backed by the fish.audio live WebSocket API.

    Text chunks are sent over a persistent WebSocket session as
    msgpack-encoded events, and raw audio frames are pushed downstream as
    they arrive in the receive task.
    """

    class InputParams(BaseModel):
        # Optional session tuning forwarded to fish.audio.
        language: Optional[Language] = Language.EN
        latency: Optional[str] = "normal"  # "normal" or "balanced"
        prosody_speed: Optional[float] = 1.0  # Speech speed (0.5-2.0)
        prosody_volume: Optional[int] = 0  # Volume adjustment in dB

    def __init__(
        self,
        *,
        api_key: str,
        model_id: str,
        output_format: FishAudioOutputFormat = "wav",
        params: InputParams = InputParams(),
        **kwargs,
    ):
        """Create the service.

        Args:
            api_key: fish.audio API key, sent as a Bearer token.
            model_id: Reference id of the fish.audio voice model to use.
            output_format: Audio codec/container ("opus", "mp3" or "wav").
            params: Optional latency/prosody settings.
        """
        super().__init__(
            sample_rate=sample_rate_from_output_format(output_format),
            **kwargs,
        )

        self._api_key = api_key
        self._model_id = model_id
        self._url = "wss://api.fish.audio/v1/tts/live"
        self._output_format = output_format

        # Settings replayed in the "start" event on every (re)connect.
        self._settings = {
            "sample_rate": sample_rate_from_output_format(output_format),
            # "language": self.language_to_service_language(params.language)
            # if params.language else "en-US",
            "latency": params.latency,
            "prosody": {
                "speed": params.prosody_speed,
                "volume": params.prosody_volume,
            },
            "format": output_format,
            "reference_id": model_id,
        }

        self._websocket = None  # Active websocket connection, if any.
        self._receive_task = None  # Background task draining the websocket.
        self._started = False  # True between TTSStartedFrame and TTSStoppedFrame.

    def can_generate_metrics(self) -> bool:
        """This service reports TTFB and usage metrics."""
        return True

    def language_to_service_language(self, language: Language) -> str:
        """Map a pipecat Language to the fish.audio locale string."""
        return language_to_fishaudio_language(language)

    async def start(self, frame: StartFrame):
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        await super().cancel(frame)
        await self._disconnect()

    async def _connect(self):
        """Open the WebSocket session and send the initial 'start' event."""
        try:
            headers = {
                "Authorization": f"Bearer {self._api_key}",
            }

            self._websocket = await websockets.connect(self._url, extra_headers=headers)
            self._receive_task = asyncio.create_task(self._receive_task_handler())

            # Send 'start' event to initialize the session
            start_message = {
                "event": "start",
                "request": {
                    "text": "",  # Initial empty text
                    "latency": self._settings["latency"],
                    "format": self._output_format,
                    "prosody": self._settings["prosody"],
                    "reference_id": self._settings["reference_id"],
                    "sample_rate": self._settings["sample_rate"],
                },
                "debug": True,
            }
            await self._websocket.send(ormsgpack.packb(start_message))
        except Exception as e:
            # Log instead of silently swallowing: run_tts() checks
            # self._websocket and will retry the connection later.
            logger.exception(f"Error connecting to fish.audio WebSocket: {e}")
            self._websocket = None

    async def _disconnect(self):
        """Send the 'stop' event, close the socket and reap the receive task."""
        try:
            await self.stop_all_metrics()

            if self._websocket:
                # Send 'stop' event to end the session
                stop_message = {
                    "event": "stop"
                }
                await self._websocket.send(ormsgpack.packb(stop_message))
                await self._websocket.close()
                self._websocket = None

            if self._receive_task:
                self._receive_task.cancel()
                try:
                    await self._receive_task
                except asyncio.CancelledError:
                    # Expected: we just cancelled it. Without this the
                    # CancelledError (a BaseException) would escape
                    # _disconnect() past the `except Exception` below.
                    pass
                self._receive_task = None

            self._started = False
        except Exception as e:
            logger.error(f"Error disconnecting from fish.audio WebSocket: {e}")

    async def _receive_task_handler(self):
        """Drain websocket messages, pushing audio/stop/error frames downstream."""
        try:
            while True:
                try:
                    message = await self._websocket.recv()
                    if not isinstance(message, bytes):
                        # fish.audio speaks msgpack over binary frames only;
                        # anything else is unexpected.
                        logger.warning(f"Received unexpected message type: {type(message)}")
                        continue

                    msg = ormsgpack.unpackb(message)
                    event = msg.get("event")

                    if event == "audio":
                        await self.stop_ttfb_metrics()
                        audio_data = msg.get("audio")
                        # Audio data is binary, no need to base64 decode
                        frame = TTSAudioRawFrame(
                            audio_data, self._settings["sample_rate"], 1)
                        await self.push_frame(frame)
                    elif event == "finish":
                        reason = msg.get("reason")
                        if reason == "stop":
                            await self.push_frame(TTSStoppedFrame())
                            self._started = False
                        elif reason == "error":
                            error_msg = msg.get("error", "Unknown error")
                            logger.error(f"fish.audio error: {error_msg}")
                            await self.push_error(ErrorFrame(f"fish.audio error: {error_msg}"))
                            self._started = False
                    elif event == "error":
                        error_msg = msg.get("error", "Unknown error")
                        logger.error(f"fish.audio error: {error_msg}")
                        await self.push_error(ErrorFrame(f"fish.audio error: {error_msg}"))
                except asyncio.TimeoutError:
                    logger.warning("No message received from fish.audio within timeout period")
                except websockets.ConnectionClosed as e:
                    logger.error(f"WebSocket connection closed: {e}")
                    break
        except asyncio.CancelledError:
            # Normal shutdown path triggered by _disconnect().
            raise
        except Exception as e:
            logger.exception(f"Exception in receive task: {e}")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Pause frame processing while the bot is speaking so interruptions work."""
        await super().process_frame(frame, direction)

        if isinstance(frame, TTSSpeakFrame):
            await self.pause_processing_frames()
        elif isinstance(frame, LLMFullResponseEndFrame) and self._started:
            await self.pause_processing_frames()
        elif isinstance(frame, BotStoppedSpeakingFrame):
            await self.resume_processing_frames()

    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
        await super()._handle_interruption(frame, direction)
        await self.stop_all_metrics()

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Stream `text` to fish.audio; audio arrives via the receive task."""
        logger.debug(f"Generating Fish TTS: [{text}]")

        try:
            if not self._websocket or self._websocket.closed:
                await self._connect()

            if not self._started:
                await self.start_ttfb_metrics()
                yield TTSStartedFrame()
                self._started = True

            # Send 'text' event to stream text chunks
            text_message = {
                "event": "text",
                "text": text + " "  # Ensure a space at the end
            }
            logger.debug(f"Sending text message: {text_message}")
            await self._websocket.send(ormsgpack.packb(text_message))
            logger.debug("Sent text message to fish.audio WebSocket")

            await self.start_tts_usage_metrics(text)

            # The audio frames will be received in _receive_task_handler
            yield None

        except Exception as e:
            logger.error(f"Error in run_tts: {e}")
            yield ErrorFrame(f"Error in run_tts: {str(e)}")
Loading