Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add speed and emotion options for Cartesia. #435

Merged
merged 8 commits into from
Sep 26, 2024
Merged
4 changes: 3 additions & 1 deletion examples/foundational/12c-describe-video-anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ async def main():
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
sample_rate=16000,
params=CartesiaTTSService.InputParams(
sample_rate=16000,
),
)

@transport.event_handler("on_first_participant_joined")
Expand Down
4 changes: 3 additions & 1 deletion examples/studypal/studypal.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ async def main():
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id=os.getenv("CARTESIA_VOICE_ID", "4d2fd738-3b3d-4368-957a-bb4805275bd9"),
# British Narration Lady: 4d2fd738-3b3d-4368-957a-bb4805275bd9
sample_rate=44100,
params=CartesiaTTSService.InputParams(
sample_rate=44100,
),
)

llm = OpenAILLMService(
Expand Down
58 changes: 43 additions & 15 deletions src/pipecat/services/cartesia.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import asyncio
import time

from typing import AsyncGenerator
from typing import AsyncGenerator, Optional
from pydantic.main import BaseModel

from pipecat.frames.frames import (
CancelFrame,
Expand Down Expand Up @@ -61,6 +62,13 @@ def language_to_cartesia_language(language: Language) -> str | None:


class CartesiaTTSService(AsyncWordTTSService):
# Optional tuning parameters for CartesiaTTSService, passed to __init__ as
# `params`. Every field has a default, so a bare InputParams() is valid (it is
# the default value of the `params` argument).
class InputParams(BaseModel):
# Copied into the request's output_format["encoding"].
encoding: Optional[str] = "pcm_s16le"
# Copied into the request's output_format["sample_rate"].
sample_rate: Optional[int] = 16000
# Copied into the request's output_format["container"].
container: Optional[str] = "raw"
# Stored as the service language and sent as the request's "language" field.
language: Optional[str] = "en"
# When truthy, sent as voice["__experimental_controls"]["speed"]; None omits it.
# NOTE(review): the set of accepted speed strings is not shown here — confirm
# against the Cartesia API.
speed: Optional[str] = None
# When truthy, sent as voice["__experimental_controls"]["emotion"]; None or an
# empty list omits it.
emotion: Optional[list[str]] = []

def __init__(
self,
Expand All @@ -70,9 +78,7 @@ def __init__(
cartesia_version: str = "2024-06-10",
url: str = "wss://api.cartesia.ai/tts/websocket",
model_id: str = "sonic-english",
encoding: str = "pcm_s16le",
sample_rate: int = 16000,
language: str = "en",
params: InputParams = InputParams(),
**kwargs):
# Aggregating sentences still gives cleaner-sounding results and fewer
# artifacts than streaming one word at a time. On average, waiting for a
Expand All @@ -92,11 +98,13 @@ def __init__(
self._voice_id = voice_id
self._model_id = model_id
self._output_format = {
"container": "raw",
"encoding": encoding,
"sample_rate": sample_rate,
"container": params.container,
"encoding": params.encoding,
"sample_rate": params.sample_rate,
}
self._language = language
self._language = params.language
self._speed = params.speed
self._emotion = params.emotion

self._websocket = None
self._context_id = None
Expand All @@ -113,6 +121,14 @@ async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice_id = voice

async def set_speed(self, speed: str):
    """Replace the stored speed setting for this TTS service.

    The value is kept on the instance as ``_speed``; ``run_tts`` reads it
    when building the voice configuration and, if it is truthy, sends it
    under ``__experimental_controls``.

    Args:
        speed: The new speed value. It is stored as given, without
            validation.
    """
    notice = f"Switching TTS speed to: [{speed}]"
    logger.debug(notice)
    self._speed = speed

async def set_emotion(self, emotion: list[str]):
    """Replace the stored emotion list for this TTS service.

    The value is kept on the instance as ``_emotion``; ``run_tts`` reads it
    when building the voice configuration and, if it is truthy, sends it
    under ``__experimental_controls``.

    Args:
        emotion: The new list of emotion strings. The list object is stored
            as given, without copying or validation.
    """
    notice = f"Switching TTS emotion to: [{emotion}]"
    logger.debug(notice)
    self._emotion = emotion

async def set_language(self, language: Language):
logger.debug(f"Switching TTS language to: [{language}]")
self._language = language_to_cartesia_language(language)
Expand All @@ -132,7 +148,8 @@ async def cancel(self, frame: CancelFrame):
async def _connect(self):
try:
self._websocket = await websockets.connect(
f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
f"{self._url}?api_key={self._api_key}&cartesia_version={
self._cartesia_version}"
)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
except Exception as e:
Expand Down Expand Up @@ -197,7 +214,8 @@ async def _receive_task_handler(self):
await self.add_word_timestamps([("LLMFullResponseEndFrame", 0)])
elif msg["type"] == "timestamps":
await self.add_word_timestamps(
list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]))
list(zip(msg["word_timestamps"]["words"],
msg["word_timestamps"]["start"]))
)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
Expand All @@ -214,7 +232,8 @@ async def _receive_task_handler(self):
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
else:
logger.error(f"Cartesia error, unknown message type: {msg}")
logger.error(
f"Cartesia error, unknown message type: {msg}")
except asyncio.CancelledError:
pass
except Exception as e:
Expand All @@ -232,15 +251,24 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
await self.start_ttfb_metrics()
self._context_id = str(uuid.uuid4())

voice_config = {
"mode": "id",
"id": self._voice_id
}

if self._speed or self._emotion:
voice_config["__experimental_controls"] = {}
if self._speed:
voice_config["__experimental_controls"]["speed"] = self._speed
if self._emotion:
voice_config["__experimental_controls"]["emotion"] = self._emotion

msg = {
"transcript": text + " ",
"continue": True,
"context_id": self._context_id,
"model_id": self._model_id,
"voice": {
"mode": "id",
"id": self._voice_id
},
"voice": voice_config,
"output_format": self._output_format,
"language": self._language,
"add_timestamps": True,
Expand Down