Skip to content

Commit

Permalink
Consolidate service UpdateSettingsFrame into a single ServiceUpdateSe…
Browse files Browse the repository at this point in the history
…ttingsFrame
  • Loading branch information
markbackman committed Oct 1, 2024
1 parent a397b85 commit 88cca7b
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 150 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ async def on_connected(processor):

### Changed

- Updated individual update settings frame classes into a single UpdateSettingsFrame
class for STT, LLM, and TTS.
- Updated individual update settings frame classes into a single
ServiceUpdateSettingsFrame class.

- We now distinguish between input and output audio and image frames. We
introduce `InputAudioRawFrame`, `OutputAudioRawFrame`, `InputImageRawFrame`
Expand Down
46 changes: 5 additions & 41 deletions src/pipecat/frames/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#

from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple

from pipecat.clocks.base_clock import BaseClock
from pipecat.metrics.metrics import MetricsData
Expand Down Expand Up @@ -527,47 +527,11 @@ def __str__(self):


@dataclass
class LLMUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update LLM settings."""
class ServiceUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update service settings."""

model: Optional[str] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
max_tokens: Optional[int] = None
seed: Optional[int] = None
extra: dict = field(default_factory=dict)


@dataclass
class TTSUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update TTS settings."""

model: Optional[str] = None
voice: Optional[str] = None
language: Optional[Language] = None
speed: Optional[Union[str, float]] = None
emotion: Optional[List[str]] = None
engine: Optional[str] = None
pitch: Optional[str] = None
rate: Optional[str] = None
volume: Optional[str] = None
emphasis: Optional[str] = None
style: Optional[str] = None
style_degree: Optional[str] = None
role: Optional[str] = None
gender: Optional[str] = None
google_style: Optional[str] = None


@dataclass
class STTUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update STT settings."""

model: Optional[str] = None
language: Optional[Language] = None
service_type: str
settings: Dict[str, Any]


@dataclass
Expand Down
85 changes: 42 additions & 43 deletions src/pipecat/services/ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io
import wave
from abc import abstractmethod
from typing import AsyncGenerator, List, Optional, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

from loguru import logger

Expand All @@ -19,15 +19,14 @@
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
ServiceUpdateSettingsFrame,
StartFrame,
StartInterruptionFrame,
STTUpdateSettingsFrame,
TextFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSUpdateSettingsFrame,
UserImageRequestFrame,
VisionImageRawFrame,
)
Expand Down Expand Up @@ -169,6 +168,7 @@ def __init__(
self._push_stop_frames: bool = push_stop_frames
self._stop_frame_timeout_s: float = stop_frame_timeout_s
self._sample_rate: int = sample_rate
self._settings: Dict[str, Any] = {}

self._stop_frame_task: Optional[asyncio.Task] = None
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
Expand Down Expand Up @@ -232,15 +232,15 @@ async def set_role(self, role: str):
pass

@abstractmethod
async def flush_audio(self):
async def set_gender(self, gender: str):
pass

@abstractmethod
async def set_gender(self, gender: str):
async def set_google_style(self, google_style: str):
pass

@abstractmethod
async def set_google_style(self, google_style: str):
async def flush_audio(self):
pass

# Converts the text to audio.
Expand All @@ -267,6 +267,22 @@ async def cancel(self, frame: CancelFrame):
await self._stop_frame_task
self._stop_frame_task = None

async def _update_settings(self, settings: Dict[str, Any]):
for key, value in settings.items():
setter = getattr(self, f"set_{key}", None)
if setter and callable(setter):
try:
if key == "language":
await setter(Language(value))
else:
await setter(value)
except Exception as e:
logger.warning(f"Error setting {key}: {e}")
else:
logger.warning(f"Unknown setting for TTS service: {key}")

self._settings.update(settings)

async def say(self, text: str):
aggregate_sentences = self._aggregate_sentences
self._aggregate_sentences = False
Expand All @@ -293,8 +309,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
elif isinstance(frame, TTSSpeakFrame):
await self._push_tts_frames(frame.text)
await self.flush_audio()
elif isinstance(frame, TTSUpdateSettingsFrame):
await self._update_tts_settings(frame)
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "tts":
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)

Expand Down Expand Up @@ -341,34 +357,6 @@ async def _push_tts_frames(self, text: str):
# interrupted, the text is not added to the assistant context.
await self.push_frame(TextFrame(text))

async def _update_tts_settings(self, frame: TTSUpdateSettingsFrame):
if frame.model is not None:
await self.set_model(frame.model)
if frame.voice is not None:
await self.set_voice(frame.voice)
if frame.language is not None:
await self.set_language(frame.language)
if frame.speed is not None:
await self.set_speed(frame.speed)
if frame.emotion is not None:
await self.set_emotion(frame.emotion)
if frame.engine is not None:
await self.set_engine(frame.engine)
if frame.pitch is not None:
await self.set_pitch(frame.pitch)
if frame.rate is not None:
await self.set_rate(frame.rate)
if frame.volume is not None:
await self.set_volume(frame.volume)
if frame.emphasis is not None:
await self.set_emphasis(frame.emphasis)
if frame.style is not None:
await self.set_style(frame.style)
if frame.style_degree is not None:
await self.set_style_degree(frame.style_degree)
if frame.role is not None:
await self.set_role(frame.role)

async def _stop_frame_handler(self):
try:
has_started = False
Expand Down Expand Up @@ -454,6 +442,7 @@ class STTService(AIService):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._settings: Dict[str, Any] = {}

@abstractmethod
async def set_model(self, model: str):
Expand All @@ -468,11 +457,21 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Returns transcript as a string"""
pass

async def _update_stt_settings(self, frame: STTUpdateSettingsFrame):
if frame.model is not None:
await self.set_model(frame.model)
if frame.language is not None:
await self.set_language(frame.language)
async def _update_settings(self, settings: Dict[str, Any]):
for key, value in settings.items():
setter = getattr(self, f"set_{key}", None)
if setter and callable(setter):
try:
if key == "language":
await setter(Language(value))
else:
await setter(value)
except Exception as e:
logger.warning(f"Error setting {key}: {e}")
else:
logger.warning(f"Unknown setting for STT service: {key}")

self._settings.update(settings)

async def process_audio_frame(self, frame: AudioRawFrame):
await self.process_generator(self.run_stt(frame.audio))
Expand All @@ -485,8 +484,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
# In this service we accumulate audio internally and at the end we
# push a TextFrame. We don't really want to push audio frames down.
await self.process_audio_frame(frame)
elif isinstance(frame, STTUpdateSettingsFrame):
await self._update_stt_settings(frame)
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "stt":
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)

Expand Down
30 changes: 13 additions & 17 deletions src/pipecat/services/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMUpdateSettingsFrame,
ServiceUpdateSettingsFrame,
StartInterruptionFrame,
TextFrame,
UserImageRawFrame,
Expand Down Expand Up @@ -284,20 +284,16 @@ async def _process_context(self, context: OpenAILLMContext):
cache_read_input_tokens=cache_read_input_tokens,
)

async def _update_settings(self, frame: LLMUpdateSettingsFrame):
if frame.model is not None:
logger.debug(f"Switching LLM model to: [{frame.model}]")
self.set_model_name(frame.model)
if frame.max_tokens is not None:
await self.set_max_tokens(frame.max_tokens)
if frame.temperature is not None:
await self.set_temperature(frame.temperature)
if frame.top_k is not None:
await self.set_top_k(frame.top_k)
if frame.top_p is not None:
await self.set_top_p(frame.top_p)
if frame.extra:
await self.set_extra(frame.extra)
async def _update_settings(self, settings: Dict[str, Any]):
for key, value in settings.items():
setter = getattr(self, f"set_{key}", None)
if setter and callable(setter):
try:
await setter(value)
except Exception as e:
logger.warning(f"Error setting {key}: {e}")
else:
logger.warning(f"Unknown setting for Anthropic LLM service: {key}")

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
Expand All @@ -313,8 +309,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
# UserImageRawFrames coming through the pipeline and add them
# to the context.
context = AnthropicLLMContext.from_image_frame(frame)
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame)
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "llm":
await self._update_settings(frame.settings)
elif isinstance(frame, LLMEnablePromptCachingFrame):
logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
self._enable_prompt_caching_beta = frame.enable
Expand Down
25 changes: 19 additions & 6 deletions src/pipecat/services/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import asyncio
import json
from typing import AsyncGenerator, List, Literal, Optional
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional

from loguru import logger
from pydantic import BaseModel
Expand All @@ -17,7 +17,7 @@
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMUpdateSettingsFrame,
ServiceUpdateSettingsFrame,
TextFrame,
TTSAudioRawFrame,
TTSStartedFrame,
Expand Down Expand Up @@ -64,6 +64,21 @@ def _create_client(self, model: str):
self.set_model_name(model)
self._client = gai.GenerativeModel(model)

async def set_model(self, model: str):
logger.debug(f"Switching LLM model to: [{model}]")
self._create_client(model)

async def _update_settings(self, settings: Dict[str, Any]):
for key, value in settings.items():
setter = getattr(self, f"set_{key}", None)
if setter and callable(setter):
try:
await setter(value)
except Exception as e:
logger.warning(f"Error setting {key}: {e}")
else:
logger.warning(f"Unknown setting for Google LLM service: {key}")

def _get_messages_from_openai_context(self, context: OpenAILLMContext) -> List[glm.Content]:
openai_messages = context.get_messages()
google_messages = []
Expand Down Expand Up @@ -136,10 +151,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
context = OpenAILLMContext.from_messages(frame.messages)
elif isinstance(frame, VisionImageRawFrame):
context = OpenAILLMContext.from_image_frame(frame)
elif isinstance(frame, LLMUpdateSettingsFrame):
if frame.model is not None:
logger.debug(f"Switching LLM model to: [{frame.model}]")
self.set_model_name(frame.model)
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "llm":
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)

Expand Down
32 changes: 13 additions & 19 deletions src/pipecat/services/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMUpdateSettingsFrame,
ServiceUpdateSettingsFrame,
StartInterruptionFrame,
TextFrame,
TTSAudioRawFrame,
Expand Down Expand Up @@ -295,22 +295,16 @@ async def _process_context(self, context: OpenAILLMContext):
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
)

async def _update_settings(self, frame: LLMUpdateSettingsFrame):
if frame.model is not None:
logger.debug(f"Switching LLM model to: [{frame.model}]")
self.set_model_name(frame.model)
if frame.frequency_penalty is not None:
await self.set_frequency_penalty(frame.frequency_penalty)
if frame.presence_penalty is not None:
await self.set_presence_penalty(frame.presence_penalty)
if frame.seed is not None:
await self.set_seed(frame.seed)
if frame.temperature is not None:
await self.set_temperature(frame.temperature)
if frame.top_p is not None:
await self.set_top_p(frame.top_p)
if frame.extra:
await self.set_extra(frame.extra)
async def _update_settings(self, settings: Dict[str, Any]):
for key, value in settings.items():
setter = getattr(self, f"set_{key}", None)
if setter and callable(setter):
try:
await setter(value)
except Exception as e:
logger.warning(f"Error setting {key}: {e}")
else:
logger.warning(f"Unknown setting for OpenAI LLM service: {key}")

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
Expand All @@ -322,8 +316,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
context = OpenAILLMContext.from_messages(frame.messages)
elif isinstance(frame, VisionImageRawFrame):
context = OpenAILLMContext.from_image_frame(frame)
elif isinstance(frame, LLMUpdateSettingsFrame):
await self._update_settings(frame)
elif isinstance(frame, ServiceUpdateSettingsFrame) and frame.service_type == "llm":
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)

Expand Down
Loading

0 comments on commit 88cca7b

Please sign in to comment.