Skip to content

Commit

Permalink
Add setter functions
Browse files Browse the repository at this point in the history
  • Loading branch information
markbackman committed Sep 23, 2024
1 parent b2e1381 commit 667fcea
Showing 1 changed file with 98 additions and 48 deletions.
146 changes: 98 additions & 48 deletions src/pipecat/services/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
# SPDX-License-Identifier: BSD 2-Clause License
#

import aiohttp
import asyncio
import io
from typing import AsyncGenerator, Optional

import aiohttp
from loguru import logger
from PIL import Image
from typing import AsyncGenerator, Optional
from pydantic import BaseModel

from pipecat.frames.frames import (
Expand All @@ -18,46 +19,43 @@
ErrorFrame,
Frame,
StartFrame,
TranscriptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
TranscriptionFrame,
URLImageRawFrame)
from pipecat.metrics.metrics import TTSUsageMetricsData
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import STTService, TTSService, ImageGenService
URLImageRawFrame,
)
from pipecat.services.ai_services import ImageGenService, STTService, TTSService
from pipecat.services.openai import BaseOpenAILLMService
from pipecat.utils.time import time_now_iso8601

from loguru import logger

# See .env.example for Azure configuration needed
try:
from openai import AsyncAzureOpenAI
from azure.cognitiveservices.speech import (
CancellationReason,
ResultReason,
SpeechConfig,
SpeechRecognizer,
SpeechSynthesizer,
ResultReason,
CancellationReason,
)
from azure.cognitiveservices.speech.audio import AudioStreamFormat, PushAudioInputStream
from azure.cognitiveservices.speech.audio import (
AudioStreamFormat,
PushAudioInputStream,
)
from azure.cognitiveservices.speech.dialog import AudioConfig
from openai import AsyncAzureOpenAI
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables.")
"In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables."
)
raise Exception(f"Missing module: {e}")


class AzureLLMService(BaseOpenAILLMService):
def __init__(
self,
*,
api_key: str,
endpoint: str,
model: str,
api_version: str = "2023-12-01-preview"):
self, *, api_key: str, endpoint: str, model: str, api_version: str = "2023-12-01-preview"
):
# Initialize variables before calling parent __init__() because that
# will call create_client() and we need those values there.
self._endpoint = endpoint
Expand All @@ -83,16 +81,16 @@ class InputParams(BaseModel):
style_degree: Optional[str] = None
volume: Optional[str] = None


def __init__(
self,
*,
api_key: str,
region: str,
voice="en-US-SaraNeural",
sample_rate: int = 16000,
params: InputParams = InputParams(),
**kwargs):
self,
*,
api_key: str,
region: str,
voice="en-US-SaraNeural",
sample_rate: int = 16000,
params: InputParams = InputParams(),
**kwargs,
):
super().__init__(sample_rate=sample_rate, **kwargs)

speech_config = SpeechConfig(subscription=api_key, region=region)
Expand Down Expand Up @@ -129,7 +127,7 @@ def _construct_ssml(self, text: str) -> str:
prosody_attrs.append(f"pitch='{self._params.pitch}'")
if self._params.volume:
prosody_attrs.append(f"volume='{self._params.volume}'")

ssml += f"<prosody {' '.join(prosody_attrs)}>"

if self._params.emphasis:
Expand All @@ -153,6 +151,59 @@ async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice = voice

async def set_emphasis(self, emphasis: str):
    """Update the SSML emphasis level applied to synthesized speech."""
    message = f"Setting TTS emphasis to: [{emphasis}]"
    logger.debug(message)
    self._params.emphasis = emphasis

async def set_language_code(self, language_code: str):
    """Update the language code used for TTS synthesis."""
    message = f"Setting TTS language code to: [{language_code}]"
    logger.debug(message)
    self._params.language_code = language_code

async def set_pitch(self, pitch: str):
    """Update the prosody pitch applied to synthesized speech."""
    message = f"Setting TTS pitch to: [{pitch}]"
    logger.debug(message)
    self._params.pitch = pitch

async def set_rate(self, rate: str):
    """Update the prosody rate applied to synthesized speech."""
    message = f"Setting TTS rate to: [{rate}]"
    logger.debug(message)
    self._params.rate = rate

async def set_role(self, role: str):
    """Update the speaking role used for TTS synthesis."""
    message = f"Setting TTS role to: [{role}]"
    logger.debug(message)
    self._params.role = role

async def set_style(self, style: str):
    """Update the speaking style used for TTS synthesis."""
    message = f"Setting TTS style to: [{style}]"
    logger.debug(message)
    self._params.style = style

async def set_style_degree(self, style_degree: str):
    """Update the intensity of the speaking style used for TTS synthesis."""
    message = f"Setting TTS style degree to: [{style_degree}]"
    logger.debug(message)
    self._params.style_degree = style_degree

async def set_volume(self, volume: str):
    """Update the prosody volume applied to synthesized speech."""
    message = f"Setting TTS volume to: [{volume}]"
    logger.debug(message)
    self._params.volume = volume

async def set_params(self, **kwargs):
    """Update multiple TTS parameters in one call.

    Recognized keyword arguments: voice, emphasis, language_code, pitch,
    rate, role, style, style_degree, volume — each is delegated to its
    corresponding ``set_*`` method. Unknown keys are ignored with a warning.
    """
    valid_params = {
        "voice": self.set_voice,
        "emphasis": self.set_emphasis,
        "language_code": self.set_language_code,
        "pitch": self.set_pitch,
        "rate": self.set_rate,
        "role": self.set_role,
        "style": self.set_style,
        "style_degree": self.set_style_degree,
        "volume": self.set_volume,
    }

    # Track which parameters were actually applied so the summary log
    # below doesn't claim unknown (ignored) keys were updated.
    applied = []
    for param, value in kwargs.items():
        setter = valid_params.get(param)
        if setter is None:
            logger.warning(f"Ignoring unknown parameter: {param}")
        else:
            await setter(value)
            applied.append(param)

    logger.debug(f"Updated TTS parameters: {', '.join(applied)}")

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")

Expand All @@ -167,7 +218,9 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
await self.stop_ttfb_metrics()
await self.push_frame(TTSStartedFrame())
# Azure always sends a 44-byte header. Strip it off.
yield TTSAudioRawFrame(audio=result.audio_data[44:], sample_rate=self._sample_rate, num_channels=1)
yield TTSAudioRawFrame(
audio=result.audio_data[44:], sample_rate=self._sample_rate, num_channels=1
)
await self.push_frame(TTSStoppedFrame())
elif result.reason == ResultReason.Canceled:
cancellation_details = result.cancellation_details
Expand All @@ -178,14 +231,15 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:

class AzureSTTService(STTService):
def __init__(
self,
*,
api_key: str,
region: str,
language="en-US",
sample_rate=16000,
channels=1,
**kwargs):
self,
*,
api_key: str,
region: str,
language="en-US",
sample_rate=16000,
channels=1,
**kwargs,
):
super().__init__(**kwargs)

speech_config = SpeechConfig(subscription=api_key, region=region)
Expand All @@ -196,7 +250,8 @@ def __init__(

audio_config = AudioConfig(stream=self._audio_stream)
self._speech_recognizer = SpeechRecognizer(
speech_config=speech_config, audio_config=audio_config)
speech_config=speech_config, audio_config=audio_config
)
self._speech_recognizer.recognized.connect(self._on_handle_recognized)

async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
Expand Down Expand Up @@ -226,7 +281,6 @@ def _on_handle_recognized(self, event):


class AzureImageGenServiceREST(ImageGenService):

def __init__(
self,
*,
Expand All @@ -249,9 +303,7 @@ def __init__(
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}"

headers = {
"api-key": self._api_key,
"Content-Type": "application/json"}
headers = {"api-key": self._api_key, "Content-Type": "application/json"}

body = {
# Enter your prompt text here
Expand Down Expand Up @@ -293,8 +345,6 @@ async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
image_stream = io.BytesIO(await response.content.read())
image = Image.open(image_stream)
frame = URLImageRawFrame(
url=image_url,
image=image.tobytes(),
size=image.size,
format=image.format)
url=image_url, image=image.tobytes(), size=image.size, format=image.format
)
yield frame

0 comments on commit 667fcea

Please sign in to comment.