From 667fcea5685b7a8ab9e164fe8437dfa30da2f6e5 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 23 Sep 2024 10:49:50 -0400 Subject: [PATCH] Add setter functions --- src/pipecat/services/azure.py | 146 +++++++++++++++++++++++----------- 1 file changed, 98 insertions(+), 48 deletions(-) diff --git a/src/pipecat/services/azure.py b/src/pipecat/services/azure.py index dc574dc87..57c6387c9 100644 --- a/src/pipecat/services/azure.py +++ b/src/pipecat/services/azure.py @@ -4,12 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import aiohttp import asyncio import io +from typing import AsyncGenerator, Optional +import aiohttp +from loguru import logger from PIL import Image -from typing import AsyncGenerator, Optional from pydantic import BaseModel from pipecat.frames.frames import ( @@ -18,46 +19,43 @@ ErrorFrame, Frame, StartFrame, + TranscriptionFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, - TranscriptionFrame, - URLImageRawFrame) -from pipecat.metrics.metrics import TTSUsageMetricsData -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services import STTService, TTSService, ImageGenService + URLImageRawFrame, +) +from pipecat.services.ai_services import ImageGenService, STTService, TTSService from pipecat.services.openai import BaseOpenAILLMService from pipecat.utils.time import time_now_iso8601 -from loguru import logger - # See .env.example for Azure configuration needed try: - from openai import AsyncAzureOpenAI from azure.cognitiveservices.speech import ( + CancellationReason, + ResultReason, SpeechConfig, SpeechRecognizer, SpeechSynthesizer, - ResultReason, - CancellationReason, ) - from azure.cognitiveservices.speech.audio import AudioStreamFormat, PushAudioInputStream + from azure.cognitiveservices.speech.audio import ( + AudioStreamFormat, + PushAudioInputStream, + ) from azure.cognitiveservices.speech.dialog import AudioConfig + from openai import AsyncAzureOpenAI except ModuleNotFoundError as e: logger.error(f"Exception: {e}") logger.error( - "In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables.") + "In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables." + ) raise Exception(f"Missing module: {e}") class AzureLLMService(BaseOpenAILLMService): def __init__( - self, - *, - api_key: str, - endpoint: str, - model: str, - api_version: str = "2023-12-01-preview"): + self, *, api_key: str, endpoint: str, model: str, api_version: str = "2023-12-01-preview" + ): # Initialize variables before calling parent __init__() because that # will call create_client() and we need those values there. self._endpoint = endpoint @@ -83,16 +81,16 @@ class InputParams(BaseModel): style_degree: Optional[str] = None volume: Optional[str] = None - def __init__( - self, - *, - api_key: str, - region: str, - voice="en-US-SaraNeural", - sample_rate: int = 16000, - params: InputParams = InputParams(), - **kwargs): + self, + *, + api_key: str, + region: str, + voice="en-US-SaraNeural", + sample_rate: int = 16000, + params: InputParams = InputParams(), + **kwargs, + ): super().__init__(sample_rate=sample_rate, **kwargs) speech_config = SpeechConfig(subscription=api_key, region=region) @@ -129,7 +127,7 @@ def _construct_ssml(self, text: str) -> str: prosody_attrs.append(f"pitch='{self._params.pitch}'") if self._params.volume: prosody_attrs.append(f"volume='{self._params.volume}'") - + ssml += f"" if self._params.emphasis: @@ -153,6 +151,59 @@ async def set_voice(self, voice: str): logger.debug(f"Switching TTS voice to: [{voice}]") self._voice = voice + async def set_emphasis(self, emphasis: str): + logger.debug(f"Setting TTS emphasis to: [{emphasis}]") + self._params.emphasis = emphasis + + async def set_language_code(self, language_code: str): + logger.debug(f"Setting TTS language code to: [{language_code}]") + self._params.language_code = language_code + + async def set_pitch(self, pitch: str): + logger.debug(f"Setting TTS pitch to: [{pitch}]") + self._params.pitch = pitch + + async def set_rate(self, rate: str): + logger.debug(f"Setting TTS rate to: [{rate}]") + self._params.rate = rate + + async def set_role(self, role: str): + logger.debug(f"Setting TTS role to: [{role}]") + self._params.role = role + + async def set_style(self, style: str): + logger.debug(f"Setting TTS style to: [{style}]") + self._params.style = style + + async def set_style_degree(self, style_degree: str): + logger.debug(f"Setting TTS style degree to: [{style_degree}]") + self._params.style_degree = style_degree + + async def set_volume(self, volume: str): + logger.debug(f"Setting TTS volume to: [{volume}]") + self._params.volume = volume + + async def set_params(self, **kwargs): + valid_params = { + "voice": self.set_voice, + "emphasis": self.set_emphasis, + "language_code": self.set_language_code, + "pitch": self.set_pitch, + "rate": self.set_rate, + "role": self.set_role, + "style": self.set_style, + "style_degree": self.set_style_degree, + "volume": self.set_volume, + } + + for param, value in kwargs.items(): + if param in valid_params: + await valid_params[param](value) + else: + logger.warning(f"Ignoring unknown parameter: {param}") + + logger.debug(f"Updated TTS parameters: {', '.join(kwargs.keys())}") + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating TTS: [{text}]") @@ -167,7 +218,9 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: await self.stop_ttfb_metrics() await self.push_frame(TTSStartedFrame()) # Azure always sends a 44-byte header. Strip it off. - yield TTSAudioRawFrame(audio=result.audio_data[44:], sample_rate=self._sample_rate, num_channels=1) + yield TTSAudioRawFrame( + audio=result.audio_data[44:], sample_rate=self._sample_rate, num_channels=1 + ) await self.push_frame(TTSStoppedFrame()) elif result.reason == ResultReason.Canceled: cancellation_details = result.cancellation_details @@ -178,14 +231,15 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: class AzureSTTService(STTService): def __init__( - self, - *, - api_key: str, - region: str, - language="en-US", - sample_rate=16000, - channels=1, - **kwargs): + self, + *, + api_key: str, + region: str, + language="en-US", + sample_rate=16000, + channels=1, + **kwargs, + ): super().__init__(**kwargs) speech_config = SpeechConfig(subscription=api_key, region=region) @@ -196,7 +250,8 @@ def __init__( audio_config = AudioConfig(stream=self._audio_stream) self._speech_recognizer = SpeechRecognizer( - speech_config=speech_config, audio_config=audio_config) + speech_config=speech_config, audio_config=audio_config + ) self._speech_recognizer.recognized.connect(self._on_handle_recognized) async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: @@ -226,7 +281,6 @@ def _on_handle_recognized(self, event): class AzureImageGenServiceREST(ImageGenService): - def __init__( self, *, @@ -249,9 +303,7 @@ def __init__( async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}" - headers = { - "api-key": self._api_key, - "Content-Type": "application/json"} + headers = {"api-key": self._api_key, "Content-Type": "application/json"} body = { # Enter your prompt text here @@ -293,8 +345,6 @@ async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: image_stream = io.BytesIO(await response.content.read()) image = Image.open(image_stream) frame = URLImageRawFrame( - url=image_url, - image=image.tobytes(), - size=image.size, - format=image.format) + url=image_url, image=image.tobytes(), size=image.size, format=image.format + ) yield frame