Commit

Merge pull request #435 from golbin/main
Add speed and emotion options for Cartesia.
markbackman authored Sep 26, 2024
2 parents d11daee + d05717a commit b1818cc
Showing 3 changed files with 102 additions and 41 deletions.
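
In short, both Cartesia services now take a single params argument (a Pydantic InputParams model) instead of separate encoding, sample_rate, and language keyword arguments, and the model adds speed and emotion fields plus set_speed() and set_emotion() methods for changing them at runtime. A minimal usage sketch based on the diff below; the voice ID mirrors the updated 12c example, while the speed and emotion values are illustrative assumptions about Cartesia's experimental voice controls, not values taken from this commit:

import os

from pipecat.services.cartesia import CartesiaTTSService

# Hypothetical configuration: the voice ID is the "British Lady" voice used in
# the updated example; speed/emotion values are illustrative, not from the diff.
tts = CartesiaTTSService(
    api_key=os.getenv("CARTESIA_API_KEY"),
    voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
    params=CartesiaTTSService.InputParams(
        sample_rate=16000,
        speed="fast",                 # str or float; defaults to "" (unset)
        emotion=["positivity:high"],  # list of emotion tags; defaults to [] (unset)
    ),
)

The same InputParams shape (encoding, sample_rate, container, language, speed, emotion) is added to both CartesiaTTSService and CartesiaHttpTTSService.
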
examples/foundational/12c-describe-video-anthropic.py (4 changes: 3 additions & 1 deletion)
@@ -78,7 +78,9 @@ async def main():
         tts = CartesiaTTSService(
             api_key=os.getenv("CARTESIA_API_KEY"),
             voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
-            sample_rate=16000,
+            params=CartesiaTTSService.InputParams(
+                sample_rate=16000,
+            ),
         )

         @transport.event_handler("on_first_participant_joined")
examples/studypal/studypal.py (4 changes: 3 additions & 1 deletion)
@@ -131,7 +131,9 @@ async def main():
             api_key=os.getenv("CARTESIA_API_KEY"),
             voice_id=os.getenv("CARTESIA_VOICE_ID", "4d2fd738-3b3d-4368-957a-bb4805275bd9"),
             # British Narration Lady: 4d2fd738-3b3d-4368-957a-bb4805275bd9
-            sample_rate=44100,
+            params=CartesiaTTSService.InputParams(
+                sample_rate=44100,
+            ),
         )

         llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o-mini")
src/pipecat/services/cartesia.py (135 changes: 96 additions & 39 deletions)
@@ -9,7 +9,8 @@
 import base64
 import asyncio

-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional, Union, List
+from pydantic.main import BaseModel

 from pipecat.frames.frames import (
     CancelFrame,
@@ -61,6 +62,14 @@ def language_to_cartesia_language(language: Language) -> str | None:


 class CartesiaTTSService(AsyncWordTTSService):
+    class InputParams(BaseModel):
+        encoding: Optional[str] = "pcm_s16le"
+        sample_rate: Optional[int] = 16000
+        container: Optional[str] = "raw"
+        language: Optional[str] = "en"
+        speed: Optional[Union[str, float]] = ""
+        emotion: Optional[List[str]] = []
+
     def __init__(
         self,
         *,
@@ -69,9 +78,7 @@ def __init__(
         cartesia_version: str = "2024-06-10",
         url: str = "wss://api.cartesia.ai/tts/websocket",
         model_id: str = "sonic-english",
-        encoding: str = "pcm_s16le",
-        sample_rate: int = 16000,
-        language: str = "en",
+        params: InputParams = InputParams(),
         **kwargs,
     ):
         # Aggregating sentences still gives cleaner-sounding results and fewer
@@ -85,20 +92,26 @@ def __init__(
         # can use those to generate text frames ourselves aligned with the
         # playout timing of the audio!
         super().__init__(
-            aggregate_sentences=True, push_text_frames=False, sample_rate=sample_rate, **kwargs
+            aggregate_sentences=True,
+            push_text_frames=False,
+            sample_rate=params.sample_rate,
+            **kwargs,
         )

         self._api_key = api_key
         self._cartesia_version = cartesia_version
         self._url = url
         self._voice_id = voice_id
         self._model_id = model_id
+        self.set_model_name(model_id)
         self._output_format = {
-            "container": "raw",
-            "encoding": encoding,
-            "sample_rate": sample_rate,
+            "container": params.container,
+            "encoding": params.encoding,
+            "sample_rate": params.sample_rate,
         }
-        self._language = language
+        self._language = params.language
+        self._speed = params.speed
+        self._emotion = params.emotion

         self._websocket = None
         self._context_id = None
Expand All @@ -108,17 +121,50 @@ def can_generate_metrics(self) -> bool:
return True

async def set_model(self, model: str):
self._model_id = model
await super().set_model(model)
logger.debug(f"Switching TTS model to: [{model}]")

async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice_id = voice

async def set_speed(self, speed: str):
logger.debug(f"Switching TTS speed to: [{speed}]")
self._speed = speed

async def set_emotion(self, emotion: list[str]):
logger.debug(f"Switching TTS emotion to: [{emotion}]")
self._emotion = emotion

async def set_language(self, language: Language):
logger.debug(f"Switching TTS language to: [{language}]")
self._language = language_to_cartesia_language(language)

def _build_msg(
self, text: str = "", continue_transcript: bool = True, add_timestamps: bool = True
):
voice_config = {"mode": "id", "id": self._voice_id}

if self._speed or self._emotion:
voice_config["__experimental_controls"] = {}
if self._speed:
voice_config["__experimental_controls"]["speed"] = self._speed
if self._emotion:
voice_config["__experimental_controls"]["emotion"] = self._emotion

msg = {
"transcript": text,
"continue": continue_transcript,
"context_id": self._context_id,
"model_id": self._model_name,
"voice": voice_config,
"output_format": self._output_format,
"language": self._language,
"add_timestamps": add_timestamps,
}
return json.dumps(msg)

async def start(self, frame: StartFrame):
await super().start(frame)
await self._connect()
@@ -134,7 +180,8 @@ async def cancel(self, frame: CancelFrame):
     async def _connect(self):
         try:
             self._websocket = await websockets.connect(
-                f"{self._url}?api_key={self._api_key}&cartesia_version={self._cartesia_version}"
+                f"{self._url}?api_key={self._api_key}&cartesia_version={
+                    self._cartesia_version}"
             )
             self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
         except Exception as e:
@@ -173,17 +220,8 @@ async def flush_audio(self):
         if not self._context_id or not self._websocket:
             return
         logger.trace("Flushing audio")
-        msg = {
-            "transcript": "",
-            "continue": False,
-            "context_id": self._context_id,
-            "model_id": self.model_name,
-            "voice": {"mode": "id", "id": self._voice_id},
-            "output_format": self._output_format,
-            "language": self._language,
-            "add_timestamps": True,
-        }
-        await self._websocket.send(json.dumps(msg))
+        msg = self._build_msg(text="", continue_transcript=False)
+        await self._websocket.send(msg)

     async def _receive_task_handler(self):
         try:
@@ -236,18 +274,10 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
                 await self.start_ttfb_metrics()
                 self._context_id = str(uuid.uuid4())

-            msg = {
-                "transcript": text + " ",
-                "continue": True,
-                "context_id": self._context_id,
-                "model_id": self.model_name,
-                "voice": {"mode": "id", "id": self._voice_id},
-                "output_format": self._output_format,
-                "language": self._language,
-                "add_timestamps": True,
-            }
+            msg = self._build_msg(text=text)
+
             try:
-                await self._get_websocket().send(json.dumps(msg))
+                await self._get_websocket().send(msg)
                 await self.start_tts_usage_metrics(text)
             except Exception as e:
                 logger.error(f"{self} error sending message: {e}")
@@ -261,29 +291,38 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:


 class CartesiaHttpTTSService(TTSService):
+    class InputParams(BaseModel):
+        encoding: Optional[str] = "pcm_s16le"
+        sample_rate: Optional[int] = 16000
+        container: Optional[str] = "raw"
+        language: Optional[str] = "en"
+        speed: Optional[Union[str, float]] = ""
+        emotion: Optional[List[str]] = []
+
     def __init__(
         self,
         *,
         api_key: str,
         voice_id: str,
         model_id: str = "sonic-english",
         base_url: str = "https://api.cartesia.ai",
-        encoding: str = "pcm_s16le",
-        sample_rate: int = 16000,
-        language: str = "en",
+        params: InputParams = InputParams(),
         **kwargs,
     ):
         super().__init__(**kwargs)

         self._api_key = api_key
         self._voice_id = voice_id
         self._model_id = model_id
+        self.set_model_name(model_id)
         self._output_format = {
-            "container": "raw",
-            "encoding": encoding,
-            "sample_rate": sample_rate,
+            "container": params.container,
+            "encoding": params.encoding,
+            "sample_rate": params.sample_rate,
         }
-        self._language = language
+        self._language = params.language
+        self._speed = params.speed
+        self._emotion = params.emotion

         self._client = AsyncCartesia(api_key=api_key, base_url=base_url)

@@ -293,11 +332,20 @@ def can_generate_metrics(self) -> bool:
     async def set_model(self, model: str):
         logger.debug(f"Switching TTS model to: [{model}]")
         self._model_id = model
+        await super().set_model(model)

     async def set_voice(self, voice: str):
         logger.debug(f"Switching TTS voice to: [{voice}]")
         self._voice_id = voice

+    async def set_speed(self, speed: str):
+        logger.debug(f"Switching TTS speed to: [{speed}]")
+        self._speed = speed
+
+    async def set_emotion(self, emotion: list[str]):
+        logger.debug(f"Switching TTS emotion to: [{emotion}]")
+        self._emotion = emotion
+
     async def set_language(self, language: Language):
         logger.debug(f"Switching TTS language to: [{language}]")
         self._language = language_to_cartesia_language(language)
@@ -317,13 +365,22 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
         await self.start_ttfb_metrics()

         try:
+            voice_controls = None
+            if self._speed or self._emotion:
+                voice_controls = {}
+                if self._speed:
+                    voice_controls["speed"] = self._speed
+                if self._emotion:
+                    voice_controls["emotion"] = self._emotion
+
             output = await self._client.tts.sse(
                 model_id=self._model_id,
                 transcript=text,
                 voice_id=self._voice_id,
                 output_format=self._output_format,
                 language=self._language,
                 stream=False,
+                _experimental_voice_controls=voice_controls,
            )

             await self.stop_ttfb_metrics()
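
For context on what the new options do on the wire: when speed or emotion is set, the websocket service's new _build_msg() folds them into the voice config under __experimental_controls, and the HTTP service passes the equivalent dict to AsyncCartesia's tts.sse() via _experimental_voice_controls. A sketch of the resulting websocket payload; the transcript, voice ID, speed, and emotion values below are illustrative assumptions, not taken from this commit:

# Approximate shape of the message CartesiaTTSService._build_msg() serializes
# when both experimental controls are set. Concrete values are illustrative;
# context_id is a per-utterance UUID generated in run_tts().
example_msg = {
    "transcript": "Hello there.",
    "continue": True,
    "context_id": "generated-uuid",
    "model_id": "sonic-english",
    "voice": {
        "mode": "id",
        "id": "79a125e8-cd45-4c13-8a67-188112f4dd22",
        "__experimental_controls": {
            "speed": "fast",
            "emotion": ["positivity:high"],
        },
    },
    "output_format": {"container": "raw", "encoding": "pcm_s16le", "sample_rate": 16000},
    "language": "en",
    "add_timestamps": True,
}

When neither speed nor emotion is set (the defaults are an empty string and an empty list), __experimental_controls is omitted entirely and the payload matches the previous behavior.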
