Consolidate service UpdateSettingsFrame into a single ServiceUpdateSettingsFrame
markbackman committed Oct 1, 2024
1 parent c447115 commit 226ad77
Showing 5 changed files with 95 additions and 143 deletions.
46 changes: 5 additions & 41 deletions src/pipecat/frames/frames.py
@@ -5,7 +5,7 @@
#

from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple

from pipecat.clocks.base_clock import BaseClock
from pipecat.metrics.metrics import MetricsData
@@ -527,47 +527,11 @@ def __str__(self):


@dataclass
class LLMUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update LLM settings."""
class ServiceUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update service settings."""

model: Optional[str] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
max_tokens: Optional[int] = None
seed: Optional[int] = None
extra: dict = field(default_factory=dict)


@dataclass
class TTSUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update TTS settings."""

model: Optional[str] = None
voice: Optional[str] = None
language: Optional[Language] = None
speed: Optional[Union[str, float]] = None
emotion: Optional[List[str]] = None
engine: Optional[str] = None
pitch: Optional[str] = None
rate: Optional[str] = None
volume: Optional[str] = None
emphasis: Optional[str] = None
style: Optional[str] = None
style_degree: Optional[str] = None
role: Optional[str] = None
gender: Optional[str] = None
google_style: Optional[str] = None


@dataclass
class STTUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update STT settings."""

model: Optional[str] = None
language: Optional[Language] = None
service_type: str
settings: Dict[str, Any]


@dataclass
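The three service-specific frames above collapse into this single dataclass. As a minimal sketch (not part of the commit), here is how a caller might construct the consolidated frame; the `service_type` strings match the checks added elsewhere in this diff, while the individual setting keys map onto `set_<key>` methods on the receiving service and the values are illustrative:

```python
from pipecat.frames.frames import ServiceUpdateSettingsFrame

# One frame type now covers all three services; service_type tells the
# receiving service whether the frame is meant for it.
llm_update = ServiceUpdateSettingsFrame(
    service_type="llm",
    settings={"temperature": 0.2, "top_p": 0.9},  # dispatched to set_temperature / set_top_p
)

tts_update = ServiceUpdateSettingsFrame(
    service_type="tts",
    settings={"voice": "example-voice-id", "speed": 1.2},  # illustrative values
)

stt_update = ServiceUpdateSettingsFrame(
    service_type="stt",
    settings={"model": "example-stt-model"},  # dispatched to set_model
)
```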
79 changes: 39 additions & 40 deletions src/pipecat/services/ai_services.py
@@ -8,7 +8,7 @@
import io
import wave
from abc import abstractmethod
from typing import AsyncGenerator, List, Optional, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

from loguru import logger

@@ -19,15 +19,14 @@
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
ServiceUpdateSettingsFrame,
StartFrame,
StartInterruptionFrame,
STTUpdateSettingsFrame,
TextFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSUpdateSettingsFrame,
UserImageRequestFrame,
VisionImageRawFrame,
)
@@ -153,6 +152,7 @@ def __init__(
self._push_text_frames: bool = push_text_frames
self._current_sentence: str = ""
self._sample_rate: int = sample_rate
self._settings: Dict[str, Any] = {}

@property
def sample_rate(self) -> int:
@@ -223,6 +223,22 @@ async def set_google_style(self, google_style: str):
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
pass

async def _update_settings(self, settings: Dict[str, Any]):
for key, value in settings.items():
setter = getattr(self, f"set_{key}", None)
if setter and callable(setter):
try:
if key == "language":
await setter(Language(value))
else:
await setter(value)
except Exception as e:
logger.warning(f"Error setting {key}: {e}")
else:
logger.warning(f"Unknown setting for TTS service: {key}")

self._settings.update(settings)

async def say(self, text: str):
await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM)

@@ -256,34 +272,6 @@ async def _push_tts_frames(self, text: str):
# interrupted, the text is not added to the assistant context.
await self.push_frame(TextFrame(text))

async def _update_tts_settings(self, frame: TTSUpdateSettingsFrame):
if frame.model is not None:
await self.set_model(frame.model)
if frame.voice is not None:
await self.set_voice(frame.voice)
if frame.language is not None:
await self.set_language(frame.language)
if frame.speed is not None:
await self.set_speed(frame.speed)
if frame.emotion is not None:
await self.set_emotion(frame.emotion)
if frame.engine is not None:
await self.set_engine(frame.engine)
if frame.pitch is not None:
await self.set_pitch(frame.pitch)
if frame.rate is not None:
await self.set_rate(frame.rate)
if frame.volume is not None:
await self.set_volume(frame.volume)
if frame.emphasis is not None:
await self.set_emphasis(frame.emphasis)
if frame.style is not None:
await self.set_style(frame.style)
if frame.style_degree is not None:
await self.set_style_degree(frame.style_degree)
if frame.role is not None:
await self.set_role(frame.role)

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

@@ -302,8 +290,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
await self.push_frame(frame, direction)
elif isinstance(frame, TTSSpeakFrame):
await self._push_tts_frames(frame.text)
elif isinstance(frame, TTSUpdateSettingsFrame):
await self._update_tts_settings(frame)
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "tts":
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)

@@ -451,6 +439,7 @@ class STTService(AIService):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._settings: Dict[str, Any] = {}

@abstractmethod
async def set_model(self, model: str):
@@ -465,11 +454,21 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Returns transcript as a string"""
pass

async def _update_stt_settings(self, frame: STTUpdateSettingsFrame):
if frame.model is not None:
await self.set_model(frame.model)
if frame.language is not None:
await self.set_language(frame.language)
async def _update_settings(self, settings: Dict[str, Any]):
for key, value in settings.items():
setter = getattr(self, f"set_{key}", None)
if setter and callable(setter):
try:
if key == "language":
await setter(Language(value))
else:
await setter(value)
except Exception as e:
logger.warning(f"Error setting {key}: {e}")
else:
logger.warning(f"Unknown setting for STT service: {key}")

self._settings.update(settings)

async def process_audio_frame(self, frame: AudioRawFrame):
await self.process_generator(self.run_stt(frame.audio))
@@ -482,8 +481,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
# In this service we accumulate audio internally and at the end we
# push a TextFrame. We don't really want to push audio frames down.
await self.process_audio_frame(frame)
elif isinstance(frame, STTUpdateSettingsFrame):
await self._update_stt_settings(frame)
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "stt":
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)

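`TTSService` and `STTService` now share the same dynamic dispatch: each settings key resolves to a `set_<key>` coroutine if the service defines one, otherwise a warning is logged and the key is skipped. A self-contained sketch of that pattern outside Pipecat (the `DemoTTS` class and its single setter are illustrative, not part of this commit):

```python
import asyncio
import logging
from typing import Any, Dict

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class DemoTTS:
    """Minimal stand-in for a Pipecat TTS service; only the dispatch logic matters."""

    def __init__(self):
        self._settings: Dict[str, Any] = {}

    async def set_voice(self, voice: str):
        print(f"voice -> {voice}")

    async def _update_settings(self, settings: Dict[str, Any]):
        for key, value in settings.items():
            # Same lookup the commit adds: resolve set_<key> on the service.
            setter = getattr(self, f"set_{key}", None)
            if setter and callable(setter):
                try:
                    await setter(value)
                except Exception as e:
                    logger.warning(f"Error setting {key}: {e}")
            else:
                logger.warning(f"Unknown setting for TTS service: {key}")
        self._settings.update(settings)


# "voice" hits set_voice; "pitch" has no setter here, so it only logs a warning.
asyncio.run(DemoTTS()._update_settings({"voice": "example-voice", "pitch": "+2st"}))
```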
47 changes: 25 additions & 22 deletions src/pipecat/services/anthropic.py
@@ -25,7 +25,7 @@
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMUpdateSettingsFrame,
ServiceUpdateSettingsFrame,
StartInterruptionFrame,
TextFrame,
UserImageRawFrame,
@@ -110,9 +110,13 @@ def enable_prompt_caching_beta(self) -> bool:
return self._enable_prompt_caching_beta

@staticmethod
def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair:
def create_context_aggregator(
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
) -> AnthropicContextAggregatorPair:
user = AnthropicUserContextAggregator(context)
assistant = AnthropicAssistantContextAggregator(user)
assistant = AnthropicAssistantContextAggregator(
user, expect_stripped_words=assistant_expect_stripped_words
)
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)

async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool):
@@ -279,20 +283,16 @@ async def _process_context(self, context: OpenAILLMContext):
cache_read_input_tokens=cache_read_input_tokens,
)

async def _update_settings(self, frame: LLMUpdateSettingsFrame):
if frame.model is not None:
logger.debug(f"Switching LLM model to: [{frame.model}]")
self.set_model_name(frame.model)
if frame.max_tokens is not None:
await self.set_max_tokens(frame.max_tokens)
if frame.temperature is not None:
await self.set_temperature(frame.temperature)
if frame.top_k is not None:
await self.set_top_k(frame.top_k)
if frame.top_p is not None:
await self.set_top_p(frame.top_p)
if frame.extra:
await self.set_extra(frame.extra)
async def _update_settings(self, settings: Dict[str, Any]):
for key, value in settings.items():
setter = getattr(self, f"set_{key}", None)
if setter and callable(setter):
try:
await setter(value)
except Exception as e:
logger.warning(f"Error setting {key}: {e}")
else:
logger.warning(f"Unknown setting for Anthropic LLM service: {key}")

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -308,8 +308,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
# UserImageRawFrames coming through the pipeline and add them
# to the context.
context = AnthropicLLMContext.from_image_frame(frame)
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame)
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "llm":
await self._update_settings(frame.settings)
elif isinstance(frame, LLMEnablePromptCachingFrame):
logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
self._enable_prompt_caching_beta = frame.enable
@@ -541,8 +541,8 @@ async def process_frame(self, frame, direction):


class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator):
super().__init__(context=user_context_aggregator._context)
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs):
super().__init__(context=user_context_aggregator._context, **kwargs)
self._user_context_aggregator = user_context_aggregator
self._function_call_in_progress = None
self._function_call_result = None
@@ -579,7 +579,7 @@ async def _push_aggregation(self):
run_llm = False

aggregation = self._aggregation
self._aggregation = ""
self._reset()

try:
if self._function_call_result:
@@ -630,5 +630,8 @@ async def _push_aggregation(self):
if run_llm:
await self._user_context_aggregator.push_context_frame()

frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)

except Exception as e:
logger.error(f"Error processing frame: {e}")
32 changes: 13 additions & 19 deletions src/pipecat/services/openai.py
@@ -24,7 +24,7 @@
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMUpdateSettingsFrame,
ServiceUpdateSettingsFrame,
StartInterruptionFrame,
TextFrame,
TTSAudioRawFrame,
@@ -273,22 +273,16 @@ async def _handle_function_call(self, context, tool_call_id, function_name, argu
arguments=arguments,
)

async def _update_settings(self, frame: LLMUpdateSettingsFrame):
if frame.model is not None:
logger.debug(f"Switching LLM model to: [{frame.model}]")
self.set_model_name(frame.model)
if frame.frequency_penalty is not None:
await self.set_frequency_penalty(frame.frequency_penalty)
if frame.presence_penalty is not None:
await self.set_presence_penalty(frame.presence_penalty)
if frame.seed is not None:
await self.set_seed(frame.seed)
if frame.temperature is not None:
await self.set_temperature(frame.temperature)
if frame.top_p is not None:
await self.set_top_p(frame.top_p)
if frame.extra:
await self.set_extra(frame.extra)
async def _update_settings(self, settings: Dict[str, Any]):
for key, value in settings.items():
setter = getattr(self, f"set_{key}", None)
if setter and callable(setter):
try:
await setter(value)
except Exception as e:
logger.warning(f"Error setting {key}: {e}")
else:
logger.warning(f"Unknown setting for OpenAI LLM service: {key}")

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -300,8 +294,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
context = OpenAILLMContext.from_messages(frame.messages)
elif isinstance(frame, VisionImageRawFrame):
context = OpenAILLMContext.from_image_frame(frame)
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame)
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "llm":
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)

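Since both LLM services now gate on `frame.service_type == "llm"`, model settings can be retuned mid-session by queueing a single frame into the pipeline. A hedged usage sketch; the `PipelineTask` import path and `queue_frames` call follow common Pipecat usage and are assumed rather than shown in this diff:

```python
from pipecat.frames.frames import ServiceUpdateSettingsFrame
from pipecat.pipeline.task import PipelineTask  # assumed import path


async def retune_llm(task: PipelineTask):
    # The OpenAI or Anthropic service downstream picks this up in
    # process_frame() and applies each key via its set_<key> method.
    await task.queue_frames(
        [ServiceUpdateSettingsFrame(service_type="llm", settings={"temperature": 0.2})]
    )
```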
