Skip to content

Commit

Permalink
Merge pull request #279 from pipecat-ai/aleix/gladia-stt-support
Browse files Browse the repository at this point in the history
add Gladia STT support
  • Loading branch information
aconchillo authored Jul 2, 2024
2 parents 974d9c3 + 82e93a0 commit 8f6db5e
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 8 deletions.
9 changes: 6 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `GladiaSTTService`.
See https://docs.gladia.io/chapters/speech-to-text-api/pages/live-speech-recognition

- Added `XTTSService`. This is a local Text-To-Speech service.
See https://github.com/coqui-ai/TTS

- Added `UserIdleProcessor`. This processor can be used to wait for any
interaction with the user. If the user doesn't say anything within a given
timeout a provided callback is called.
Expand All @@ -20,9 +26,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new frame `BotSpeakingFrame`. This frame will be continuously pushed
upstream while the bot is talking.

- Added `XTTSService`. This is a local Text-To-Speech service.
See https://github.com/coqui-ai/TTS

- It is now possible to specify a Silero VAD version when using `SileroVADAnalyzer`
or `SileroVAD`.

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pip install "pipecat-ai[option,...]"

Your project may or may not need these, so they're made available as optional requirements. Here is a list:

- **AI services**: `anthropic`, `azure`, `deepgram`, `google`, `fal`, `moondream`, `openai`, `openpipe`, `playht`, `silero`, `whisper`, `xtts`
- **AI services**: `anthropic`, `azure`, `deepgram`, `gladia`, `google`, `fal`, `moondream`, `openai`, `openpipe`, `playht`, `silero`, `whisper`, `xtts`
- **Transports**: `local`, `websocket`, `daily`

## Code examples
Expand Down
3 changes: 3 additions & 0 deletions dot-env.template
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ FAL_KEY=...
# Fireworks
FIREWORKS_API_KEY=...

# Gladia
GLADIA_API_KEY=...

# PlayHT
PLAY_HT_USER_ID=...
PLAY_HT_API_KEY=...
Expand Down
101 changes: 101 additions & 0 deletions examples/foundational/07j-interruptible-gladia.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import aiohttp
import os
import sys

from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_response import (
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
from pipecat.services.deepgram import DeepgramSTTService, DeepgramTTSService
from pipecat.services.gladia import GladiaSTTService
from pipecat.services.openai import OpenAILLMService
from pipecat.services.xtts import XTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer

from runner import configure

from loguru import logger

from dotenv import load_dotenv
load_dotenv(override=True)

logger.remove(0)
logger.add(sys.stderr, level="DEBUG")


async def main(room_url: str, token):
async with aiohttp.ClientSession() as session:
transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
vad_audio_passthrough=True,
)
)

stt = GladiaSTTService(
api_key=os.getenv("GLADIA_API_KEY"),
)

tts = DeepgramTTSService(
aiohttp_session=session,
api_key=os.getenv("DEEPGRAM_API_KEY"),
voice="aura-helios-en"
)

llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
model="gpt-4o")

messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]

tma_in = LLMUserResponseAggregator(messages)
tma_out = LLMAssistantResponseAggregator(messages)

pipeline = Pipeline([
transport.input(), # Transport user input
stt, # STT
tma_in, # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
tma_out # Assistant spoken responses
])

task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))

@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
messages.append(
{"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMMessagesFrame(messages)])

runner = PipelineRunner()

await runner.run(task)


if __name__ == "__main__":
(url, token) = configure()
asyncio.run(main(url, token))
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ daily = [ "daily-python~=0.10.1" ]
deepgram = [ "deepgram-sdk~=3.2.7" ]
examples = [ "python-dotenv~=1.0.0", "flask~=3.0.3", "flask_cors~=4.0.1" ]
fal = [ "fal-client~=0.4.0" ]
gladia = [ "websockets~=12.0" ]
google = [ "google-generativeai~=0.5.3" ]
fireworks = [ "openai~=1.26.0" ]
langchain = [ "langchain~=0.2.1", "langchain-community~=0.2.1", "langchain-openai~=0.1.8" ]
Expand Down
2 changes: 1 addition & 1 deletion src/pipecat/services/fal.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:

response = await fal_client.run_async(
self._model,
arguments={"prompt": prompt, **self._params.model_dump()}
arguments={"prompt": prompt, **self._params.model_dump(exclude_none=True)}
)

image_url = response["images"][0]["url"] if response else None
Expand Down
115 changes: 115 additions & 0 deletions src/pipecat/services/gladia.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import base64
import json
import time

from typing import Optional
from pydantic.main import BaseModel

from pipecat.frames.frames import (
AudioRawFrame,
CancelFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
SystemFrame,
TranscriptionFrame)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import AsyncAIService

from loguru import logger

# See .env.example for Gladia configuration needed
try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Gladia, you need to `pip install pipecat-ai[gladia]`. Also, set `GLADIA_API_KEY` environment variable.")
raise Exception(f"Missing module: {e}")


class GladiaSTTService(AsyncAIService):
class InputParams(BaseModel):
sample_rate: Optional[int] = 16000
language: Optional[str] = "english"
transcription_hint: Optional[str] = None
endpointing: Optional[int] = 200
prosody: Optional[bool] = None

def __init__(self,
*,
api_key: str,
url: str = "wss://api.gladia.io/audio/text/audio-transcription",
confidence: float = 0.5,
params: InputParams = InputParams(),
**kwargs):
super().__init__(**kwargs)

self._api_key = api_key
self._url = url
self._params = params
self._confidence = confidence

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
elif isinstance(frame, AudioRawFrame):
await self._send_audio(frame)
else:
await self.queue_frame(frame, direction)

async def start(self, frame: StartFrame):
self._websocket = await websockets.connect(self._url)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
await self._setup_gladia()

async def stop(self, frame: EndFrame):
await self._websocket.close()

async def cancel(self, frame: CancelFrame):
await self._websocket.close()

async def _setup_gladia(self):
configuration = {
"x_gladia_key": self._api_key,
"encoding": "WAV/PCM",
"model_type": "fast",
"language_behaviour": "manual",
**self._params.model_dump(exclude_none=True)
}

await self._websocket.send(json.dumps(configuration))

async def _send_audio(self, frame: AudioRawFrame):
message = {
'frames': base64.b64encode(frame.audio).decode("utf-8")
}
await self._websocket.send(json.dumps(message))

async def _receive_task_handler(self):
async for message in self._websocket:
utterance = json.loads(message)
if not utterance:
continue

if "error" in utterance:
message = utterance["message"]
logger.error(f"Gladia error: {message}")
elif "confidence" in utterance:
type = utterance["type"]
confidence = utterance["confidence"]
transcript = utterance["transcription"]
if confidence >= self._confidence:
if type == "final":
await self.queue_frame(TranscriptionFrame(transcript, "", int(time.time_ns() / 1000000)))
else:
await self.queue_frame(InterimTranscriptionFrame(transcript, "", int(time.time_ns() / 1000000)))
6 changes: 3 additions & 3 deletions src/pipecat/transports/services/daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time

from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Mapping
from typing import Any, Awaitable, Callable, Mapping, Optional
from concurrent.futures import ThreadPoolExecutor

from daily import (
Expand Down Expand Up @@ -101,7 +101,7 @@ class DailyTranscriptionSettings(BaseModel):
class DailyParams(TransportParams):
api_url: str = "https://api.daily.co/v1"
api_key: str = ""
dialin_settings: DailyDialinSettings | None = None
dialin_settings: Optional[DailyDialinSettings] = None
transcription_enabled: bool = False
transcription_settings: DailyTranscriptionSettings = DailyTranscriptionSettings()

Expand Down Expand Up @@ -268,7 +268,7 @@ async def join(self):
logger.info(
f"Enabling transcription with settings {self._params.transcription_settings}")
self._client.start_transcription(
self._params.transcription_settings.model_dump())
self._params.transcription_settings.model_dump(exclude_none=True))

await self._callbacks.on_joined(data["participants"]["local"])
else:
Expand Down

0 comments on commit 8f6db5e

Please sign in to comment.