From 226ad7776a6b8f3458ff09ce2737642e277251a5 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Tue, 1 Oct 2024 10:35:59 -0400 Subject: [PATCH] Consolidate service UpdateSettingsFrame into a single ServiceUpdateSettingsFrame --- src/pipecat/frames/frames.py | 46 ++--------------- src/pipecat/services/ai_services.py | 79 ++++++++++++++--------------- src/pipecat/services/anthropic.py | 47 +++++++++-------- src/pipecat/services/openai.py | 32 +++++------- src/pipecat/services/together.py | 34 +++++-------- 5 files changed, 95 insertions(+), 143 deletions(-) diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index ab704824c..c07ed5045 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -5,7 +5,7 @@ # from dataclasses import dataclass, field -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple from pipecat.clocks.base_clock import BaseClock from pipecat.metrics.metrics import MetricsData @@ -527,47 +527,11 @@ def __str__(self): @dataclass -class LLMUpdateSettingsFrame(ControlFrame): - """A control frame containing a request to update LLM settings.""" +class ServiceUpdateSettingsFrame(ControlFrame): + """A control frame containing a request to update service settings.""" - model: Optional[str] = None - temperature: Optional[float] = None - top_k: Optional[int] = None - top_p: Optional[float] = None - frequency_penalty: Optional[float] = None - presence_penalty: Optional[float] = None - max_tokens: Optional[int] = None - seed: Optional[int] = None - extra: dict = field(default_factory=dict) - - -@dataclass -class TTSUpdateSettingsFrame(ControlFrame): - """A control frame containing a request to update TTS settings.""" - - model: Optional[str] = None - voice: Optional[str] = None - language: Optional[Language] = None - speed: Optional[Union[str, float]] = None - emotion: Optional[List[str]] = None - engine: Optional[str] = None - pitch: Optional[str] = None - rate: Optional[str] = None - volume: Optional[str] = None - emphasis: Optional[str] = None - style: Optional[str] = None - style_degree: Optional[str] = None - role: Optional[str] = None - gender: Optional[str] = None - google_style: Optional[str] = None - - -@dataclass -class STTUpdateSettingsFrame(ControlFrame): - """A control frame containing a request to update STT settings.""" - - model: Optional[str] = None - language: Optional[Language] = None + service_type: str + settings: Dict[str, Any] @dataclass diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 0b5b5f0a5..afe7c6866 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -8,7 +8,7 @@ import io import wave from abc import abstractmethod -from typing import AsyncGenerator, List, Optional, Tuple, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union from loguru import logger @@ -19,15 +19,14 @@ ErrorFrame, Frame, LLMFullResponseEndFrame, + ServiceUpdateSettingsFrame, StartFrame, StartInterruptionFrame, - STTUpdateSettingsFrame, TextFrame, TTSAudioRawFrame, TTSSpeakFrame, TTSStartedFrame, TTSStoppedFrame, - TTSUpdateSettingsFrame, UserImageRequestFrame, VisionImageRawFrame, ) @@ -153,6 +152,7 @@ def __init__( self._push_text_frames: bool = push_text_frames self._current_sentence: str = "" self._sample_rate: int = sample_rate + self._settings: Dict[str, Any] = {} @property def sample_rate(self) -> int: @@ -223,6 +223,22 @@ async def set_google_style(self, google_style: str): async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: pass + async def _update_settings(self, settings: Dict[str, Any]): + for key, value in settings.items(): + setter = getattr(self, f"set_{key}", None) + if setter and callable(setter): + try: + if key == "language": + await setter(Language(value)) + else: + await setter(value) + except Exception as e: + logger.warning(f"Error setting {key}: {e}") + else: + logger.warning(f"Unknown setting for TTS service: {key}") + + self._settings.update(settings) + async def say(self, text: str): await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM) @@ -256,34 +272,6 @@ async def _push_tts_frames(self, text: str): # interrupted, the text is not added to the assistant context. await self.push_frame(TextFrame(text)) - async def _update_tts_settings(self, frame: TTSUpdateSettingsFrame): - if frame.model is not None: - await self.set_model(frame.model) - if frame.voice is not None: - await self.set_voice(frame.voice) - if frame.language is not None: - await self.set_language(frame.language) - if frame.speed is not None: - await self.set_speed(frame.speed) - if frame.emotion is not None: - await self.set_emotion(frame.emotion) - if frame.engine is not None: - await self.set_engine(frame.engine) - if frame.pitch is not None: - await self.set_pitch(frame.pitch) - if frame.rate is not None: - await self.set_rate(frame.rate) - if frame.volume is not None: - await self.set_volume(frame.volume) - if frame.emphasis is not None: - await self.set_emphasis(frame.emphasis) - if frame.style is not None: - await self.set_style(frame.style) - if frame.style_degree is not None: - await self.set_style_degree(frame.style_degree) - if frame.role is not None: - await self.set_role(frame.role) - async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -302,8 +290,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self.push_frame(frame, direction) elif isinstance(frame, TTSSpeakFrame): await self._push_tts_frames(frame.text) - elif isinstance(frame, TTSUpdateSettingsFrame): - await self._update_tts_settings(frame) + elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "tts": + await self._update_settings(frame.settings) else: await self.push_frame(frame, direction) @@ -451,6 +439,7 @@ class STTService(AIService): def __init__(self, **kwargs): super().__init__(**kwargs) + self._settings: Dict[str, Any] = {} @abstractmethod async def set_model(self, model: str): @@ -465,11 +454,21 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: """Returns transcript as a string""" pass - async def _update_stt_settings(self, frame: STTUpdateSettingsFrame): - if frame.model is not None: - await self.set_model(frame.model) - if frame.language is not None: - await self.set_language(frame.language) + async def _update_settings(self, settings: Dict[str, Any]): + for key, value in settings.items(): + setter = getattr(self, f"set_{key}", None) + if setter and callable(setter): + try: + if key == "language": + await setter(Language(value)) + else: + await setter(value) + except Exception as e: + logger.warning(f"Error setting {key}: {e}") + else: + logger.warning(f"Unknown setting for STT service: {key}") + + self._settings.update(settings) async def process_audio_frame(self, frame: AudioRawFrame): await self.process_generator(self.run_stt(frame.audio)) @@ -482,8 +481,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # In this service we accumulate audio internally and at the end we # push a TextFrame. We don't really want to push audio frames down. await self.process_audio_frame(frame) - elif isinstance(frame, STTUpdateSettingsFrame): - await self._update_stt_settings(frame) + elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "stt": + await self._update_settings(frame.settings) else: await self.push_frame(frame, direction) diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py index bc91e4e16..bf12f50ac 100644 --- a/src/pipecat/services/anthropic.py +++ b/src/pipecat/services/anthropic.py @@ -25,7 +25,7 @@ LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesFrame, - LLMUpdateSettingsFrame, + ServiceUpdateSettingsFrame, StartInterruptionFrame, TextFrame, UserImageRawFrame, @@ -110,9 +110,13 @@ def enable_prompt_caching_beta(self) -> bool: return self._enable_prompt_caching_beta @staticmethod - def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair: + def create_context_aggregator( + context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True + ) -> AnthropicContextAggregatorPair: user = AnthropicUserContextAggregator(context) - assistant = AnthropicAssistantContextAggregator(user) + assistant = AnthropicAssistantContextAggregator( + user, expect_stripped_words=assistant_expect_stripped_words + ) return AnthropicContextAggregatorPair(_user=user, _assistant=assistant) async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool): @@ -279,20 +283,16 @@ async def _process_context(self, context: OpenAILLMContext): cache_read_input_tokens=cache_read_input_tokens, ) - async def _update_settings(self, frame: LLMUpdateSettingsFrame): - if frame.model is not None: - logger.debug(f"Switching LLM model to: [{frame.model}]") - self.set_model_name(frame.model) - if frame.max_tokens is not None: - await self.set_max_tokens(frame.max_tokens) - if frame.temperature is not None: - await self.set_temperature(frame.temperature) - if frame.top_k is not None: - await self.set_top_k(frame.top_k) - if frame.top_p is not None: - await self.set_top_p(frame.top_p) - if frame.extra: - await self.set_extra(frame.extra) + async def _update_settings(self, settings: Dict[str, Any]): + for key, value in settings.items(): + setter = getattr(self, f"set_{key}", None) + if setter and callable(setter): + try: + await setter(value) + except Exception as e: + logger.warning(f"Error setting {key}: {e}") + else: + logger.warning(f"Unknown setting for Anthropic LLM service: {key}") async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -308,8 +308,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # UserImageRawFrames coming through the pipeline and add them # to the context. context = AnthropicLLMContext.from_image_frame(frame) - elif isinstance(frame, LLMUpdateSettingsFrame): - await self._update_settings(frame) + elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "llm": + await self._update_settings(frame.settings) elif isinstance(frame, LLMEnablePromptCachingFrame): logger.debug(f"Setting enable prompt caching to: [{frame.enable}]") self._enable_prompt_caching_beta = frame.enable @@ -541,8 +541,8 @@ async def process_frame(self, frame, direction): class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator): - def __init__(self, user_context_aggregator: AnthropicUserContextAggregator): - super().__init__(context=user_context_aggregator._context) + def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs): + super().__init__(context=user_context_aggregator._context, **kwargs) self._user_context_aggregator = user_context_aggregator self._function_call_in_progress = None self._function_call_result = None @@ -579,7 +579,7 @@ async def _push_aggregation(self): run_llm = False aggregation = self._aggregation - self._aggregation = "" + self._reset() try: if self._function_call_result: @@ -630,5 +630,8 @@ async def _push_aggregation(self): if run_llm: await self._user_context_aggregator.push_context_frame() + frame = OpenAILLMContextFrame(self._context) + await self.push_frame(frame) + except Exception as e: logger.error(f"Error processing frame: {e}") diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index f0892b9ca..1adc346ce 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -24,7 +24,7 @@ LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesFrame, - LLMUpdateSettingsFrame, + ServiceUpdateSettingsFrame, StartInterruptionFrame, TextFrame, TTSAudioRawFrame, @@ -273,22 +273,16 @@ async def _handle_function_call(self, context, tool_call_id, function_name, argu arguments=arguments, ) - async def _update_settings(self, frame: LLMUpdateSettingsFrame): - if frame.model is not None: - logger.debug(f"Switching LLM model to: [{frame.model}]") - self.set_model_name(frame.model) - if frame.frequency_penalty is not None: - await self.set_frequency_penalty(frame.frequency_penalty) - if frame.presence_penalty is not None: - await self.set_presence_penalty(frame.presence_penalty) - if frame.seed is not None: - await self.set_seed(frame.seed) - if frame.temperature is not None: - await self.set_temperature(frame.temperature) - if frame.top_p is not None: - await self.set_top_p(frame.top_p) - if frame.extra: - await self.set_extra(frame.extra) + async def _update_settings(self, settings: Dict[str, Any]): + for key, value in settings.items(): + setter = getattr(self, f"set_{key}", None) + if setter and callable(setter): + try: + await setter(value) + except Exception as e: + logger.warning(f"Error setting {key}: {e}") + else: + logger.warning(f"Unknown setting for OpenAI LLM service: {key}") async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -300,8 +294,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): context = OpenAILLMContext.from_messages(frame.messages) elif isinstance(frame, VisionImageRawFrame): context = OpenAILLMContext.from_image_frame(frame) - elif isinstance(frame, LLMUpdateSettingsFrame): - await self._update_settings(frame) + elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "llm": + await self._update_settings(frame.settings) else: await self.push_frame(frame, direction) diff --git a/src/pipecat/services/together.py b/src/pipecat/services/together.py index e4068ecfc..f7a5993f9 100644 --- a/src/pipecat/services/together.py +++ b/src/pipecat/services/together.py @@ -21,7 +21,7 @@ LLMFullResponseEndFrame, LLMFullResponseStartFrame, LLMMessagesFrame, - LLMUpdateSettingsFrame, + ServiceUpdateSettingsFrame, StartInterruptionFrame, TextFrame, UserImageRequestFrame, @@ -128,24 +128,16 @@ async def set_extra(self, extra: Dict[str, Any]): logger.debug(f"Switching LLM extra to: [{extra}]") self._extra = extra - async def _update_settings(self, frame: LLMUpdateSettingsFrame): - if frame.model is not None: - logger.debug(f"Switching LLM model to: [{frame.model}]") - self.set_model_name(frame.model) - if frame.frequency_penalty is not None: - await self.set_frequency_penalty(frame.frequency_penalty) - if frame.max_tokens is not None: - await self.set_max_tokens(frame.max_tokens) - if frame.presence_penalty is not None: - await self.set_presence_penalty(frame.presence_penalty) - if frame.temperature is not None: - await self.set_temperature(frame.temperature) - if frame.top_k is not None: - await self.set_top_k(frame.top_k) - if frame.top_p is not None: - await self.set_top_p(frame.top_p) - if frame.extra: - await self.set_extra(frame.extra) + async def _update_settings(self, settings: Dict[str, Any]): + for key, value in settings.items(): + setter = getattr(self, f"set_{key}", None) + if setter and callable(setter): + try: + await setter(value) + except Exception as e: + logger.warning(f"Error setting {key}: {e}") + else: + logger.warning(f"Unknown setting for Together LLM service: {key}") async def _process_context(self, context: OpenAILLMContext): try: @@ -224,8 +216,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): context = frame.context elif isinstance(frame, LLMMessagesFrame): context = TogetherLLMContext.from_messages(frame.messages) - elif isinstance(frame, LLMUpdateSettingsFrame): - await self._update_settings(frame) + elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "llm": + await self._update_settings(frame.settings) else: await self.push_frame(frame, direction)