Skip to content

Commit

Permalink
Update to use LLM, STT, TTS subclasses and remove setter methods
Browse files Browse the repository at this point in the history
  • Loading branch information
markbackman committed Oct 1, 2024
1 parent 88cca7b commit 08d0738
Show file tree
Hide file tree
Showing 15 changed files with 380 additions and 645 deletions.
16 changes: 15 additions & 1 deletion src/pipecat/frames/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,10 +530,24 @@ def __str__(self):
class ServiceUpdateSettingsFrame(ControlFrame):
    """A control frame containing a request to update service settings."""

    # NOTE(review): appears superseded by the typed subclasses
    # (LLM/TTS/STTUpdateSettingsFrame) used for isinstance dispatch — confirm
    # whether this field is still read anywhere.
    service_type: str
    # Mapping of setting name -> new value; interpreted by the receiving
    # service's _update_settings().
    settings: Dict[str, Any]


@dataclass
class LLMUpdateSettingsFrame(ServiceUpdateSettingsFrame):
    """A control frame requesting an LLM service to update its settings."""

    pass


@dataclass
class TTSUpdateSettingsFrame(ServiceUpdateSettingsFrame):
    """A control frame requesting a TTS service to update its settings."""

    pass


@dataclass
class STTUpdateSettingsFrame(ServiceUpdateSettingsFrame):
    """A control frame requesting an STT service to update its settings."""

    pass


@dataclass
class FunctionCallInProgressFrame(SystemFrame):
"""A frame signaling that a function call is in progress."""
Expand Down
109 changes: 19 additions & 90 deletions src/pipecat/services/ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io
import wave
from abc import abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

from loguru import logger

Expand All @@ -19,14 +19,15 @@
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
ServiceUpdateSettingsFrame,
StartFrame,
StartInterruptionFrame,
STTUpdateSettingsFrame,
TextFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSUpdateSettingsFrame,
UserImageRequestFrame,
VisionImageRawFrame,
)
Expand Down Expand Up @@ -175,70 +176,10 @@ def __init__(

self._current_sentence: str = ""

@property
def sample_rate(self) -> int:
    # Audio sample rate of this service's output; presumably assigned in
    # __init__ (not visible here) — confirm.
    return self._sample_rate

# NOTE(review): per-property abstract setters — this commit's diff shows them
# on the deleted side, replaced by the dict-based _update_settings() flow.
@abstractmethod
async def set_model(self, model: str):
    # Records the model name on the service before subclass-specific handling.
    self.set_model_name(model)

@abstractmethod
async def set_voice(self, voice: str):
    pass

@abstractmethod
async def set_language(self, language: Language):
    pass

@abstractmethod
async def set_speed(self, speed: Union[str, float]):
    pass

@abstractmethod
async def set_emotion(self, emotion: List[str]):
    pass

@abstractmethod
async def set_engine(self, engine: str):
    pass

@abstractmethod
async def set_pitch(self, pitch: str):
    pass

@abstractmethod
async def set_rate(self, rate: str):
    pass

@abstractmethod
async def set_volume(self, volume: str):
    pass

@abstractmethod
async def set_emphasis(self, emphasis: str):
    pass

@abstractmethod
async def set_style(self, style: str):
    pass

@abstractmethod
async def set_style_degree(self, style_degree: str):
    pass

@abstractmethod
async def set_role(self, role: str):
    pass

@abstractmethod
async def set_gender(self, gender: str):
    pass

@abstractmethod
async def set_google_style(self, google_style: str):
    pass

@abstractmethod
async def flush_audio(self):
    """Flush any pending synthesized audio.

    Contract inferred from the TTSSpeakFrame handling, which calls this
    after pushing TTS frames — confirm against concrete subclasses.
    """
    pass
Expand Down Expand Up @@ -269,20 +210,16 @@ async def cancel(self, frame: CancelFrame):

async def _update_settings(self, settings: Dict[str, Any]):
    """Apply a batch of TTS setting updates.

    Args:
        settings: Mapping of setting name to new value. Only keys already
            present in ``self._settings`` are accepted; unknown keys are
            logged and skipped.
    """
    for key, value in settings.items():
        if key not in self._settings:
            # Surface, but otherwise ignore, settings this service lacks.
            logger.warning(f"Unknown setting for TTS service: {key}")
            continue
        logger.debug(f"Updating TTS setting {key} to: [{value}]")
        # "language" arrives as a plain string; normalize it to the Language
        # enum in a single write (avoids briefly storing the raw string).
        self._settings[key] = Language(value) if key == "language" else value
        if key == "model":
            # Mirror the new model into the service-level model name.
            self.set_model_name(value)

async def say(self, text: str):
aggregate_sentences = self._aggregate_sentences
self._aggregate_sentences = False
Expand All @@ -309,7 +246,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
elif isinstance(frame, TTSSpeakFrame):
await self._push_tts_frames(frame.text)
await self.flush_audio()
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "tts":
elif isinstance(frame, TTSUpdateSettingsFrame):
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)
Expand Down Expand Up @@ -448,31 +385,23 @@ def __init__(self, **kwargs):
async def set_model(self, model: str):
    """Update the STT model, recording the new model name on the service."""
    self.set_model_name(model)

@abstractmethod
async def set_language(self, language: Language):
pass

@abstractmethod
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
    """Transcribe raw audio, yielding Frame objects with the results."""
    pass

async def _update_settings(self, settings: Dict[str, Any]):
    """Apply a batch of STT setting updates.

    Args:
        settings: Mapping of setting name to new value. Only keys already
            present in ``self._settings`` are accepted; unknown keys are
            logged and skipped.
    """
    for key, value in settings.items():
        if key not in self._settings:
            # Surface, but otherwise ignore, settings this service lacks.
            logger.warning(f"Unknown setting for STT service: {key}")
            continue
        logger.debug(f"Updating STT setting {key} to: [{value}]")
        # "language" arrives as a plain string; normalize it to the Language
        # enum in a single write (avoids briefly storing the raw string).
        self._settings[key] = Language(value) if key == "language" else value
        if key == "model":
            # Mirror the new model into the service-level model name.
            self.set_model_name(value)

async def process_audio_frame(self, frame: AudioRawFrame):
    """Feed the frame's raw audio through run_stt and push the yielded frames."""
    transcription_stream = self.run_stt(frame.audio)
    await self.process_generator(transcription_stream)

Expand All @@ -484,7 +413,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
# In this service we accumulate audio internally and at the end we
# push a TextFrame. We don't really want to push audio frames down.
await self.process_audio_frame(frame)
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "stt":
elif isinstance(frame, STTUpdateSettingsFrame):
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)
Expand Down
70 changes: 24 additions & 46 deletions src/pipecat/services/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
ServiceUpdateSettingsFrame,
LLMUpdateSettingsFrame,
StartInterruptionFrame,
TextFrame,
UserImageRawFrame,
Expand Down Expand Up @@ -96,12 +96,15 @@ def __init__(
super().__init__(**kwargs)
self._client = AsyncAnthropic(api_key=api_key)
self.set_model_name(model)
self._max_tokens = params.max_tokens
self._enable_prompt_caching_beta: bool = params.enable_prompt_caching_beta or False
self._temperature = params.temperature
self._top_k = params.top_k
self._top_p = params.top_p
self._extra = params.extra if isinstance(params.extra, dict) else {}
self._settings = {
"model": model,
"max_tokens": params.max_tokens,
"enable_prompt_caching_beta": params.enable_prompt_caching_beta or False,
"temperature": params.temperature,
"top_k": params.top_k,
"top_p": params.top_p,
"extra": params.extra if isinstance(params.extra, dict) else {},
}

def can_generate_metrics(self) -> bool:
    # This service supports metrics (it calls start_ttfb_metrics and tracks
    # prompt/completion token usage in _process_context).
    return True
Expand All @@ -120,30 +123,6 @@ def create_context_aggregator(
)
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)

# NOTE(review): per-parameter setters — this commit's diff shows them on the
# deleted side, replaced by the dict-based _settings/_update_settings flow.
async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool):
    logger.debug(f"Switching LLM enable_prompt_caching_beta to: [{enable_prompt_caching_beta}]")
    self._enable_prompt_caching_beta = enable_prompt_caching_beta

async def set_max_tokens(self, max_tokens: int):
    logger.debug(f"Switching LLM max_tokens to: [{max_tokens}]")
    self._max_tokens = max_tokens

async def set_temperature(self, temperature: float):
    logger.debug(f"Switching LLM temperature to: [{temperature}]")
    self._temperature = temperature

async def set_top_k(self, top_k: float):
    logger.debug(f"Switching LLM top_k to: [{top_k}]")
    self._top_k = top_k

async def set_top_p(self, top_p: float):
    logger.debug(f"Switching LLM top_p to: [{top_p}]")
    self._top_p = top_p

async def set_extra(self, extra: Dict[str, Any]):
    logger.debug(f"Switching LLM extra to: [{extra}]")
    self._extra = extra

async def _process_context(self, context: OpenAILLMContext):
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
# completion_tokens. We also estimate the completion tokens from output text
Expand All @@ -165,11 +144,11 @@ async def _process_context(self, context: OpenAILLMContext):
)

messages = context.messages
if self._enable_prompt_caching_beta:
if self._settings["enable_prompt_caching_beta"]:
messages = context.get_messages_with_cache_control_markers()

api_call = self._client.messages.create
if self._enable_prompt_caching_beta:
if self._settings["enable_prompt_caching_beta"]:
api_call = self._client.beta.prompt_caching.messages.create

await self.start_ttfb_metrics()
Expand All @@ -179,14 +158,14 @@ async def _process_context(self, context: OpenAILLMContext):
"system": context.system,
"messages": messages,
"model": self.model_name,
"max_tokens": self._max_tokens,
"max_tokens": self._settings["max_tokens"],
"stream": True,
"temperature": self._temperature,
"top_k": self._top_k,
"top_p": self._top_p,
"temperature": self._settings["temperature"],
"top_k": self._settings["top_k"],
"top_p": self._settings["top_p"],
}

params.update(self._extra)
params.update(self._settings["extra"])

response = await api_call(**params)

Expand Down Expand Up @@ -286,12 +265,11 @@ async def _process_context(self, context: OpenAILLMContext):

async def _update_settings(self, settings: Dict[str, Any]):
    """Apply a batch of Anthropic LLM setting updates.

    Args:
        settings: Mapping of setting name to new value. Only keys already
            present in ``self._settings`` are accepted; unknown keys are
            logged and skipped.
    """
    for key, value in settings.items():
        if key not in self._settings:
            # Surface, but otherwise ignore, settings this service lacks.
            logger.warning(f"Unknown setting for Anthropic LLM service: {key}")
            continue
        logger.debug(f"Updating LLM setting {key} to: [{value}]")
        self._settings[key] = value
        if key == "model":
            # Mirror the new model into the service-level model name.
            self.set_model_name(value)

Expand All @@ -309,11 +287,11 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
# UserImageRawFrames coming through the pipeline and add them
# to the context.
context = AnthropicLLMContext.from_image_frame(frame)
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "llm":
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame.settings)
elif isinstance(frame, LLMEnablePromptCachingFrame):
logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
self._enable_prompt_caching_beta = frame.enable
self._settings["enable_prompt_caching_beta"] = frame.enable
else:
await self.push_frame(frame, direction)

Expand Down
Loading

0 comments on commit 08d0738

Please sign in to comment.