Skip to content

Commit

Permalink
Apply and Fix upstream changes for Cartesia
Browse files Browse the repository at this point in the history
  • Loading branch information
golbin committed Sep 23, 2024
1 parent cf72129 commit 49f2123
Showing 1 changed file with 69 additions and 44 deletions.
113 changes: 69 additions & 44 deletions src/pipecat/services/cartesia.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import base64
import asyncio

from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, Optional, Union, List
from pydantic.main import BaseModel

from pipecat.frames.frames import (
Expand Down Expand Up @@ -67,8 +67,8 @@ class InputParams(BaseModel):
sample_rate: Optional[int] = 16000
container: Optional[str] = "raw"
language: Optional[str] = "en"
speed: Optional[str] = None
emotion: Optional[list[str]] = []
speed: Optional[Union[str, float]] = ""
emotion: Optional[List[str]] = []

def __init__(
self,
Expand All @@ -91,13 +91,14 @@ def __init__(
# can use those to generate text frames ourselves aligned with the
# playout timing of the audio!
super().__init__(
aggregate_sentences=True, push_text_frames=False, sample_rate=sample_rate, **kwargs
aggregate_sentences=True, push_text_frames=False, sample_rate=params.sample_rate, **kwargs
)

self._api_key = api_key
self._cartesia_version = cartesia_version
self._url = url
self._voice_id = voice_id
self._model_id = model_id
self.set_model_name(model_id)
self._output_format = {
"container": params.container,
Expand All @@ -116,6 +117,7 @@ def can_generate_metrics(self) -> bool:
return True

async def set_model(self, model: str):
self._model_id = model
await super().set_model(model)
logger.debug(f"Switching TTS model to: [{model}]")

Expand All @@ -135,6 +137,31 @@ async def set_language(self, language: Language):
logger.debug(f"Switching TTS language to: [{language}]")
self._language = language_to_cartesia_language(language)

def _build_msg(self, text: str = "", continue_transcript: bool = True, add_timestamps: bool = True):
voice_config = {
"mode": "id",
"id": self._voice_id
}

if self._speed or self._emotion:
voice_config["__experimental_controls"] = {}
if self._speed:
voice_config["__experimental_controls"]["speed"] = self._speed
if self._emotion:
voice_config["__experimental_controls"]["emotion"] = self._emotion

msg = {
"transcript": text,
"continue": continue_transcript,
"context_id": self._context_id,
"model_id": self._model_name,
"voice": voice_config,
"output_format": self._output_format,
"language": self._language,
"add_timestamps": add_timestamps,
}
return json.dumps(msg)

async def start(self, frame: StartFrame):
await super().start(frame)
await self._connect()
Expand Down Expand Up @@ -190,17 +217,8 @@ async def flush_audio(self):
if not self._context_id or not self._websocket:
return
logger.trace("Flushing audio")
msg = {
"transcript": "",
"continue": False,
"context_id": self._context_id,
"model_id": self.model_name,
"voice": {"mode": "id", "id": self._voice_id},
"output_format": self._output_format,
"language": self._language,
"add_timestamps": True,
}
await self._websocket.send(json.dumps(msg))
msg = self._build_msg(text="", continue_transcript=False)
await self._websocket.send(msg)

async def _receive_task_handler(self):
try:
Expand Down Expand Up @@ -255,30 +273,10 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
await self.start_ttfb_metrics()
self._context_id = str(uuid.uuid4())

voice_config = {
"mode": "id",
"id": self._voice_id
}
msg = self._build_msg(text=text)

if self._speed or self._emotion:
voice_config["__experimental_controls"] = {}
if self._speed:
voice_config["__experimental_controls"]["speed"] = self._speed
if self._emotion:
voice_config["__experimental_controls"]["emotion"] = self._emotion

msg = {
"transcript": text + " ",
"continue": True,
"context_id": self._context_id,
"model_id": self._model_id,
"voice": voice_config,
"output_format": self._output_format,
"language": self._language,
"add_timestamps": True,
}
try:
await self._get_websocket().send(json.dumps(msg))
await self._get_websocket().send(msg)
await self.start_tts_usage_metrics(text)
except Exception as e:
logger.error(f"{self} error sending message: {e}")
Expand All @@ -292,29 +290,38 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:


class CartesiaHttpTTSService(TTSService):
class InputParams(BaseModel):
encoding: Optional[str] = "pcm_s16le"
sample_rate: Optional[int] = 16000
container: Optional[str] = "raw"
language: Optional[str] = "en"
speed: Optional[Union[str, float]] = ""
emotion: Optional[List[str]] = []

def __init__(
self,
*,
api_key: str,
voice_id: str,
model_id: str = "sonic-english",
base_url: str = "https://api.cartesia.ai",
encoding: str = "pcm_s16le",
sample_rate: int = 16000,
language: str = "en",
params: InputParams = InputParams(),
**kwargs,
):
super().__init__(**kwargs)

self._api_key = api_key
self._voice_id = voice_id
self._model_id = model_id
self.set_model_name(model_id)
self._output_format = {
"container": "raw",
"encoding": encoding,
"sample_rate": sample_rate,
"container": params.container,
"encoding": params.encoding,
"sample_rate": params.sample_rate,
}
self._language = language
self._language = params.language
self._speed = params.speed
self._emotion = params.emotion

self._client = AsyncCartesia(api_key=api_key, base_url=base_url)

Expand All @@ -324,11 +331,20 @@ def can_generate_metrics(self) -> bool:
async def set_model(self, model: str):
logger.debug(f"Switching TTS model to: [{model}]")
self._model_id = model
await super().set_model(model)

async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice_id = voice

async def set_speed(self, speed: str):
logger.debug(f"Switching TTS speed to: [{speed}]")
self._speed = speed

async def set_emotion(self, emotion: list[str]):
logger.debug(f"Switching TTS emotion to: [{emotion}]")
self._emotion = emotion

async def set_language(self, language: Language):
logger.debug(f"Switching TTS language to: [{language}]")
self._language = language_to_cartesia_language(language)
Expand All @@ -348,13 +364,22 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
await self.start_ttfb_metrics()

try:
voice_controls = None
if self._speed or self._emotion:
voice_controls = {}
if self._speed:
voice_controls["speed"] = self._speed
if self._emotion:
voice_controls["emotion"] = self._emotion

output = await self._client.tts.sse(
model_id=self._model_id,
transcript=text,
voice_id=self._voice_id,
output_format=self._output_format,
language=self._language,
stream=False,
_experimental_voice_controls=voice_controls
)

await self.stop_ttfb_metrics()
Expand Down

0 comments on commit 49f2123

Please sign in to comment.