Skip to content

Commit

Permalink
Merge pull request #201 from TomTom101/TomTom101/openai_tts
Browse files Browse the repository at this point in the history
Added OpenAI TTS (#196)
  • Loading branch information
aconchillo authored Jun 4, 2024
2 parents fe71825 + d462c03 commit 20a5256
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 11 deletions.
74 changes: 63 additions & 11 deletions src/pipecat/services/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import base64
import io
import json
import time
import aiohttp
import base64
from typing import AsyncGenerator, List, Literal

import aiohttp
from loguru import logger
from PIL import Image

from typing import AsyncGenerator, List, Literal

from pipecat.frames.frames import (
AudioRawFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
Expand All @@ -26,15 +26,19 @@
URLImageRawFrame,
VisionImageRawFrame
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService, ImageGenService

from loguru import logger
from pipecat.services.ai_services import (
ImageGenService,
LLMService,
TTSService
)

try:
from openai import AsyncOpenAI, AsyncStream

from openai import AsyncOpenAI, AsyncStream, BadRequestError
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
Expand Down Expand Up @@ -272,3 +276,51 @@ async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
image = Image.open(image_stream)
frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format)
yield frame


class OpenAITTSService(TTSService):
"""This service uses the OpenAI TTS API to generate audio from text.
The returned audio is PCM encoded at 24kHz. When using the DailyTransport, set the sample rate in the DailyParams accordingly:
```
DailyParams(
audio_out_enabled=True,
audio_out_sample_rate=24_000,
)
```
"""

def __init__(
self,
*,
api_key: str | None = None,
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = "alloy",
model: Literal["tts-1", "tts-1-hd"] = "tts-1",
**kwargs):
super().__init__(**kwargs)

self._voice = voice
self._model = model

self._client = AsyncOpenAI(api_key=api_key)

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")

try:
async with self._client.audio.speech.with_streaming_response.create(
input=text,
model=self._model,
voice=self._voice,
response_format="pcm",
) as r:
if r.status_code != 200:
error = await r.text()
logger.error(f"Error getting audio (status: {r.status_code}, error: {error})")
yield ErrorFrame(f"Error getting audio (status: {r.status_code}, error: {error})")
return
async for chunk in r.iter_bytes(8192):
if len(chunk) > 0:
frame = AudioRawFrame(chunk, 24_000, 1)
yield frame
except BadRequestError as e:
logger.error(f"Error generating TTS: {e}")
39 changes: 39 additions & 0 deletions tests/test_openai_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import asyncio
import unittest

import openai
import pyaudio
from dotenv import load_dotenv

from pipecat.frames.frames import AudioRawFrame, ErrorFrame
from pipecat.services.openai import OpenAITTSService

load_dotenv()


class TestWhisperOpenAIService(unittest.IsolatedAsyncioTestCase):
async def test_whisper_tts(self):
pa = pyaudio.PyAudio()
stream = pa.open(format=pyaudio.paInt16,
channels=1,
rate=24_000,
output=True)

tts = OpenAITTSService(voice="nova")

async for frame in tts.run_tts("Hello, there. Nice to meet you, seems to work well"):
self.assertIsInstance(frame, AudioRawFrame)
stream.write(frame.audio)

await asyncio.sleep(.5)
stream.stop_stream()
pa.terminate()

tts = OpenAITTSService(voice="invalid_voice")
with self.assertRaises(openai.BadRequestError):
async for frame in tts.run_tts("wont work"):
self.assertIsInstance(frame, ErrorFrame)


if __name__ == "__main__":
unittest.main()

0 comments on commit 20a5256

Please sign in to comment.