Skip to content

Commit

Permalink
Merge pull request #527 from pipecat-ai/mb/google-tts-inputs
Browse files Browse the repository at this point in the history
Further consolidate service update settings into a single ServiceUpdateSettingsFrame class
  • Loading branch information
markbackman authored Oct 2, 2024
2 parents 8aae4e9 + 3d642df commit 096a15e
Show file tree
Hide file tree
Showing 19 changed files with 867 additions and 725 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ async def on_connected(processor):

### Changed

- Updated individual update settings frame classes into a single UpdateSettingsFrame
class for STT, LLM, and TTS.
- Updated individual update settings frame classes into a single
ServiceUpdateSettingsFrame class.

- We now distinguish between input and output audio and image frames. We
introduce `InputAudioRawFrame`, `OutputAudioRawFrame`, `InputImageRawFrame`
Expand Down
14 changes: 6 additions & 8 deletions examples/foundational/07e-interruptible-playht.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
# SPDX-License-Identifier: BSD 2-Clause License
#

import aiohttp
import asyncio
import os
import sys

import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure

from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
Expand All @@ -17,17 +21,11 @@
LLMAssistantResponseAggregator,
LLMUserResponseAggregator,
)
from pipecat.services.playht import PlayHTTTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.services.playht import PlayHTTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer

from runner import configure

from loguru import logger

from dotenv import load_dotenv

load_dotenv(override=True)

logger.remove(0)
Expand Down
15 changes: 6 additions & 9 deletions examples/foundational/07g-interruptible-openai-tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
# SPDX-License-Identifier: BSD 2-Clause License
#

import aiohttp
import asyncio
import os
import sys

import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure

from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
Expand All @@ -17,17 +21,10 @@
LLMAssistantResponseAggregator,
LLMUserResponseAggregator,
)
from pipecat.services.openai import OpenAITTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.services.openai import OpenAILLMService, OpenAITTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer

from runner import configure

from loguru import logger

from dotenv import load_dotenv

load_dotenv(override=True)

logger.remove(0)
Expand Down
12 changes: 5 additions & 7 deletions examples/foundational/16-gpu-container-local-bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
#

import asyncio
import aiohttp
import os
import sys

import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure

from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
Expand All @@ -26,12 +30,6 @@
)
from pipecat.vad.silero import SileroVADAnalyzer

from runner import configure

from loguru import logger

from dotenv import load_dotenv

load_dotenv(override=True)

logger.remove(0)
Expand Down
44 changes: 12 additions & 32 deletions src/pipecat/frames/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#

from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple

from pipecat.clocks.base_clock import BaseClock
from pipecat.metrics.metrics import MetricsData
Expand Down Expand Up @@ -527,45 +527,25 @@ def __str__(self):


@dataclass
class LLMUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update LLM settings."""
class ServiceUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update service settings."""

model: Optional[str] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
max_tokens: Optional[int] = None
seed: Optional[int] = None
extra: dict = field(default_factory=dict)
settings: Dict[str, Any]


@dataclass
class TTSUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update TTS settings."""

model: Optional[str] = None
voice: Optional[str] = None
language: Optional[Language] = None
speed: Optional[Union[str, float]] = None
emotion: Optional[List[str]] = None
engine: Optional[str] = None
pitch: Optional[str] = None
rate: Optional[str] = None
volume: Optional[str] = None
emphasis: Optional[str] = None
style: Optional[str] = None
style_degree: Optional[str] = None
role: Optional[str] = None
class LLMUpdateSettingsFrame(ServiceUpdateSettingsFrame):
pass


@dataclass
class STTUpdateSettingsFrame(ControlFrame):
"""A control frame containing a request to update STT settings."""
class TTSUpdateSettingsFrame(ServiceUpdateSettingsFrame):
pass

model: Optional[str] = None
language: Optional[Language] = None

@dataclass
class STTUpdateSettingsFrame(ServiceUpdateSettingsFrame):
pass


@dataclass
Expand Down
131 changes: 45 additions & 86 deletions src/pipecat/services/ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import io
import wave
from abc import abstractmethod
from typing import AsyncGenerator, List, Optional, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

from loguru import logger

Expand Down Expand Up @@ -45,6 +45,7 @@ class AIService(FrameProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._model_name: str = ""
self._settings: Dict[str, Any] = {}

@property
def model_name(self) -> str:
Expand All @@ -63,6 +64,16 @@ async def stop(self, frame: EndFrame):
async def cancel(self, frame: CancelFrame):
pass

async def _update_settings(self, settings: Dict[str, Any]):
for key, value in settings.items():
if key in self._settings:
logger.debug(f"Updating setting {key} to: [{value}] for {self.name}")
self._settings[key] = value
elif key == "model":
self.set_model_name(value)
else:
logger.warning(f"Unknown setting for {self.name} service: {key}")

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

Expand Down Expand Up @@ -169,6 +180,8 @@ def __init__(
self._push_stop_frames: bool = push_stop_frames
self._stop_frame_timeout_s: float = stop_frame_timeout_s
self._sample_rate: int = sample_rate
self._voice_id: str = ""
self._settings: Dict[str, Any] = {}

self._stop_frame_task: Optional[asyncio.Task] = None
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
Expand All @@ -184,52 +197,8 @@ async def set_model(self, model: str):
self.set_model_name(model)

@abstractmethod
async def set_voice(self, voice: str):
pass

@abstractmethod
async def set_language(self, language: Language):
pass

@abstractmethod
async def set_speed(self, speed: Union[str, float]):
pass

@abstractmethod
async def set_emotion(self, emotion: List[str]):
pass

@abstractmethod
async def set_engine(self, engine: str):
pass

@abstractmethod
async def set_pitch(self, pitch: str):
pass

@abstractmethod
async def set_rate(self, rate: str):
pass

@abstractmethod
async def set_volume(self, volume: str):
pass

@abstractmethod
async def set_emphasis(self, emphasis: str):
pass

@abstractmethod
async def set_style(self, style: str):
pass

@abstractmethod
async def set_style_degree(self, style_degree: str):
pass

@abstractmethod
async def set_role(self, role: str):
pass
def set_voice(self, voice: str):
self._voice_id = voice

@abstractmethod
async def flush_audio(self):
Expand Down Expand Up @@ -259,6 +228,20 @@ async def cancel(self, frame: CancelFrame):
await self._stop_frame_task
self._stop_frame_task = None

async def _update_settings(self, settings: Dict[str, Any]):
for key, value in settings.items():
if key in self._settings:
logger.debug(f"Updating TTS setting {key} to: [{value}]")
self._settings[key] = value
if key == "language":
self._settings[key] = Language(value)
elif key == "model":
self.set_model_name(value)
elif key == "voice":
self.set_voice(value)
else:
logger.warning(f"Unknown setting for TTS service: {key}")

async def say(self, text: str):
aggregate_sentences = self._aggregate_sentences
self._aggregate_sentences = False
Expand Down Expand Up @@ -286,7 +269,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
await self._push_tts_frames(frame.text)
await self.flush_audio()
elif isinstance(frame, TTSUpdateSettingsFrame):
await self._update_tts_settings(frame)
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)

Expand Down Expand Up @@ -333,34 +316,6 @@ async def _push_tts_frames(self, text: str):
# interrupted, the text is not added to the assistant context.
await self.push_frame(TextFrame(text))

async def _update_tts_settings(self, frame: TTSUpdateSettingsFrame):
if frame.model is not None:
await self.set_model(frame.model)
if frame.voice is not None:
await self.set_voice(frame.voice)
if frame.language is not None:
await self.set_language(frame.language)
if frame.speed is not None:
await self.set_speed(frame.speed)
if frame.emotion is not None:
await self.set_emotion(frame.emotion)
if frame.engine is not None:
await self.set_engine(frame.engine)
if frame.pitch is not None:
await self.set_pitch(frame.pitch)
if frame.rate is not None:
await self.set_rate(frame.rate)
if frame.volume is not None:
await self.set_volume(frame.volume)
if frame.emphasis is not None:
await self.set_emphasis(frame.emphasis)
if frame.style is not None:
await self.set_style(frame.style)
if frame.style_degree is not None:
await self.set_style_degree(frame.style_degree)
if frame.role is not None:
await self.set_role(frame.role)

async def _stop_frame_handler(self):
try:
has_started = False
Expand Down Expand Up @@ -446,25 +401,29 @@ class STTService(AIService):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._settings: Dict[str, Any] = {}

@abstractmethod
async def set_model(self, model: str):
self.set_model_name(model)

@abstractmethod
async def set_language(self, language: Language):
pass

@abstractmethod
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Returns transcript as a string"""
pass

async def _update_stt_settings(self, frame: STTUpdateSettingsFrame):
if frame.model is not None:
await self.set_model(frame.model)
if frame.language is not None:
await self.set_language(frame.language)
async def _update_settings(self, settings: Dict[str, Any]):
logger.debug(f"Updating STT settings: {self._settings}")
for key, value in settings.items():
if key in self._settings:
logger.debug(f"Updating STT setting {key} to: [{value}]")
self._settings[key] = value
if key == "language":
self._settings[key] = Language(value)
elif key == "model":
self.set_model_name(value)
else:
logger.warning(f"Unknown setting for STT service: {key}")

async def process_audio_frame(self, frame: AudioRawFrame):
await self.process_generator(self.run_stt(frame.audio))
Expand All @@ -478,7 +437,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
# push a TextFrame. We don't really want to push audio frames down.
await self.process_audio_frame(frame)
elif isinstance(frame, STTUpdateSettingsFrame):
await self._update_stt_settings(frame)
await self._update_settings(frame.settings)
else:
await self.push_frame(frame, direction)

Expand Down
Loading

0 comments on commit 096a15e

Please sign in to comment.