Skip to content

Commit

Permalink
Merge pull request #475 from pipecat-ai/mb/tts-sample-rate
Browse files Browse the repository at this point in the history
Add sample_rate setting to TTS services
  • Loading branch information
markbackman authored Sep 18, 2024
2 parents 13a4a05 + eadd68d commit 6f3c421
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 18 deletions.
6 changes: 6 additions & 0 deletions src/pipecat/services/ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def __init__(
push_stop_frames: bool = False,
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
stop_frame_timeout_s: float = 1.0,
sample_rate: int = 16000,
**kwargs):
super().__init__(**kwargs)
self._aggregate_sentences: bool = aggregate_sentences
Expand All @@ -180,6 +181,11 @@ def __init__(
self._stop_frame_task: Optional[asyncio.Task] = None
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
self._current_sentence: str = ""
self._sample_rate: int = sample_rate

@property
def sample_rate(self) -> int:
    """Return the sample rate stored on this instance.

    Read-only view of ``self._sample_rate``.
    """
    return self._sample_rate

@abstractmethod
async def set_model(self, model: str):
Expand Down
14 changes: 11 additions & 3 deletions src/pipecat/services/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,21 @@ def create_client(self, api_key=None, base_url=None, **kwargs):


class AzureTTSService(TTSService):
def __init__(self, *, api_key: str, region: str, voice="en-US-SaraNeural", **kwargs):
super().__init__(**kwargs)
def __init__(
self,
*,
api_key: str,
region: str,
voice="en-US-SaraNeural",
sample_rate: int = 16000,
**kwargs):
super().__init__(sample_rate=sample_rate, **kwargs)

speech_config = SpeechConfig(subscription=api_key, region=region)
self._speech_synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None)

self._voice = voice
self._sample_rate = sample_rate

def can_generate_metrics(self) -> bool:
return True
Expand Down Expand Up @@ -109,7 +117,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
await self.stop_ttfb_metrics()
await self.push_frame(TTSStartedFrame())
# Azure always sends a 44-byte header. Strip it off.
yield AudioRawFrame(audio=result.audio_data[44:], sample_rate=16000, num_channels=1)
yield AudioRawFrame(audio=result.audio_data[44:], sample_rate=self._sample_rate, num_channels=1)
await self.push_frame(TTSStoppedFrame())
elif result.reason == ResultReason.Canceled:
cancellation_details = result.cancellation_details
Expand Down
2 changes: 1 addition & 1 deletion src/pipecat/services/cartesia.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
# if we're interrupted. Cartesia gives us word-by-word timestamps. We
# can use those to generate text frames ourselves aligned with the
# playout timing of the audio!
super().__init__(aggregate_sentences=True, push_text_frames=False, **kwargs)
super().__init__(aggregate_sentences=True, push_text_frames=False, sample_rate=sample_rate, **kwargs)

self._api_key = api_key
self._cartesia_version = cartesia_version
Expand Down
4 changes: 2 additions & 2 deletions src/pipecat/services/elevenlabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
push_text_frames=False,
push_stop_frames=True,
stop_frame_timeout_s=2.0,
sample_rate=sample_rate_from_output_format(params.output_format),
**kwargs
)

Expand All @@ -109,7 +110,6 @@ def __init__(
self._model = model
self._url = url
self._params = params
self._sample_rate = sample_rate_from_output_format(params.output_format)

# Websocket connection to ElevenLabs.
self._websocket = None
Expand Down Expand Up @@ -209,7 +209,7 @@ async def _receive_task_handler(self):
self.start_word_timestamps()

audio = base64.b64decode(msg["audio"])
frame = AudioRawFrame(audio, self._sample_rate, 1)
frame = AudioRawFrame(audio, self.sample_rate, 1)
await self.push_frame(frame)

if msg.get("alignment"):
Expand Down
2 changes: 1 addition & 1 deletion src/pipecat/services/lmnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
**kwargs):
# Let TTSService produce TTSStoppedFrames after a short delay of
# no activity.
super().__init__(push_stop_frames=True, **kwargs)
super().__init__(push_stop_frames=True, sample_rate=sample_rate, **kwargs)

self._api_key = api_key
self._voice_id = voice_id
Expand Down
33 changes: 25 additions & 8 deletions src/pipecat/services/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import httpx
from dataclasses import dataclass

from typing import AsyncGenerator, List, Literal
from typing import AsyncGenerator, Dict, List, Literal

from loguru import logger
from PIL import Image
Expand Down Expand Up @@ -55,6 +55,17 @@
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable.")
raise Exception(f"Missing module: {e}")

# Closed set of voice names, expressed as a Literal type.
ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]

# Identity lookup table from a plain ``str`` to the matching ``ValidVoice``.
# Looking a name up with ``.get(name, fallback)`` narrows an arbitrary string
# to one of the literal values, or to the fallback when the name is unknown.
VALID_VOICES: Dict[str, ValidVoice] = {
    "alloy": "alloy", "echo": "echo", "fable": "fable",
    "onyx": "onyx", "nova": "nova", "shimmer": "shimmer",
}


class OpenAIUnhandledFunctionException(Exception):
pass
Expand Down Expand Up @@ -182,8 +193,8 @@ async def _process_context(self, context: OpenAILLMContext):
if self.has_function(function_name):
await self._handle_function_call(context, tool_call_id, function_name, arguments)
else:
raise OpenAIUnhandledFunctionException(
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function.")
# Keep the `{function_name}` replacement field on a single line: a newline
# inside the braces of a single-quoted f-string is a SyntaxError on Python
# versions before 3.12. The message text itself is unchanged.
raise OpenAIUnhandledFunctionException(
    f"The LLM tried to call a function named '{function_name}', "
    "but there isn't a callback registered for that function.")

async def _handle_function_call(
self,
Expand Down Expand Up @@ -307,13 +318,15 @@ def __init__(
self,
*,
api_key: str | None = None,
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = "alloy",
voice: str = "alloy",
model: Literal["tts-1", "tts-1-hd"] = "tts-1",
sample_rate: int = 24000,
**kwargs):
super().__init__(**kwargs)
super().__init__(sample_rate=sample_rate, **kwargs)

self._voice = voice
self._voice: ValidVoice = VALID_VOICES.get(voice, "alloy")
self._model = model
self._sample_rate = sample_rate

self._client = AsyncOpenAI(api_key=api_key)

Expand All @@ -322,7 +335,11 @@ def can_generate_metrics(self) -> bool:

async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice = voice
self._voice = VALID_VOICES.get(voice, self._voice)

async def set_model(self, model: str):
    """Log the requested model at debug level, then store it in ``self._model``.

    Args:
        model: Name of the model to switch to. It is stored as given; no
            validation is performed here.
    """
    logger.debug(f"Switching TTS model to: [{model}]")
    self._model = model

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")
Expand All @@ -348,7 +365,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
async for chunk in r.iter_bytes(8192):
if len(chunk) > 0:
await self.stop_ttfb_metrics()
frame = AudioRawFrame(chunk, 24_000, 1)
frame = AudioRawFrame(chunk, self.sample_rate, 1)
yield frame
await self.push_frame(TTSStoppedFrame())
except BadRequestError as e:
Expand Down
17 changes: 14 additions & 3 deletions src/pipecat/services/playht.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,15 @@

class PlayHTTTSService(TTSService):

def __init__(self, *, api_key: str, user_id: str, voice_url: str, **kwargs):
super().__init__(**kwargs)
def __init__(
self,
*,
api_key: str,
user_id: str,
voice_url: str,
sample_rate: int = 16000,
**kwargs):
super().__init__(sample_rate=sample_rate, **kwargs)

self._user_id = user_id
self._speech_key = api_key
Expand All @@ -39,13 +46,17 @@ def __init__(self, *, api_key: str, user_id: str, voice_url: str, **kwargs):
)
self._options = TTSOptions(
voice=voice_url,
sample_rate=16000,
sample_rate=sample_rate,
quality="higher",
format=Format.FORMAT_WAV)

def can_generate_metrics(self) -> bool:
    """Always return ``True``."""
    return True

async def set_voice(self, voice: str):
    """Log the requested voice at debug level, then update the TTS options.

    Args:
        voice: New value for ``self._options.voice``. It is assigned as
            given; no validation is performed here.
    """
    logger.debug(f"Switching TTS voice to: [{voice}]")
    self._options.voice = voice

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")

Expand Down

0 comments on commit 6f3c421

Please sign in to comment.