diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0686d95b9..432eaf9bb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
+- Added `WebsocketServerTransport`. This will create a websocket server and
+  read messages coming from a client. The messages are serialized/deserialized
+  with protobufs. See `examples/websocket-server` for a detailed example.
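+
+  A minimal construction sketch (import path and parameters as in the
+  example; note the transport must be created inside a running asyncio
+  event loop):
+
+  ```python
+  from pipecat.transports.network.websocket_server import (
+      WebsocketServerParams, WebsocketServerTransport)
+
+  # Defaults to host="localhost", port=8765.
+  transport = WebsocketServerTransport(
+      params=WebsocketServerParams(audio_in_enabled=True, audio_out_enabled=True))
+  ```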
+
- Added function calling (LLMService.register_function()). This will allow the
LLM to call functions you have registered when needed. For example, if you
register a function to get the weather in Los Angeles and ask the LLM about
@@ -24,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed an issue where `camera_out_enabled` would cause high CPU usage if
  no image was provided.
+
+### Performance
+
+- Removed unnecessary audio input tasks.
## [0.0.24] - 2024-05-29
diff --git a/LICENSE b/LICENSE
index b60f5327e..cd6220df2 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
BSD 2-Clause License
-Copyright (c) 2024, Kwindla Hultman Kramer
+Copyright (c) 2024, Daily
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 9e0d93cbe..2d7da9ca6 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,5 +1,6 @@
autopep8~=2.1.0
build~=1.2.1
+grpcio-tools~=1.62.2
pip-tools~=7.4.1
pytest~=8.2.0
setuptools~=69.5.1
diff --git a/examples/foundational/02-llm-say-one-thing.py b/examples/foundational/02-llm-say-one-thing.py
index 7e12263dc..20756dcb6 100644
--- a/examples/foundational/02-llm-say-one-thing.py
+++ b/examples/foundational/02-llm-say-one-thing.py
@@ -44,7 +44,7 @@ async def main(room_url):
llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
- model="gpt-4-turbo-preview")
+ model="gpt-4o")
messages = [
{
diff --git a/examples/foundational/05-sync-speech-and-image.py b/examples/foundational/05-sync-speech-and-image.py
index 60dd50d07..f057c847c 100644
--- a/examples/foundational/05-sync-speech-and-image.py
+++ b/examples/foundational/05-sync-speech-and-image.py
@@ -93,7 +93,7 @@ async def main(room_url):
llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
- model="gpt-4-turbo-preview")
+ model="gpt-4o")
imagegen = FalImageGenService(
params=FalImageGenService.InputParams(
diff --git a/examples/foundational/05a-local-sync-speech-and-image.py b/examples/foundational/05a-local-sync-speech-and-image.py
index bfbd453e2..d476754fb 100644
--- a/examples/foundational/05a-local-sync-speech-and-image.py
+++ b/examples/foundational/05a-local-sync-speech-and-image.py
@@ -76,7 +76,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
- model="gpt-4-turbo-preview")
+ model="gpt-4o")
tts = ElevenLabsTTSService(
aiohttp_session=session,
diff --git a/examples/foundational/06a-image-sync.py b/examples/foundational/06a-image-sync.py
index 3ec2752b4..2f5528ee4 100644
--- a/examples/foundational/06a-image-sync.py
+++ b/examples/foundational/06a-image-sync.py
@@ -81,7 +81,7 @@ async def main(room_url: str, token):
llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
- model="gpt-4-turbo-preview")
+ model="gpt-4o")
messages = [
{
diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py
index ce37344f0..9ed146774 100644
--- a/examples/foundational/07-interruptible.py
+++ b/examples/foundational/07-interruptible.py
@@ -53,7 +53,7 @@ async def main(room_url: str, token):
llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
- model="gpt-4-turbo-preview")
+ model="gpt-4o")
messages = [
{
diff --git a/examples/foundational/07c-interruptible-deepgram.py b/examples/foundational/07c-interruptible-deepgram.py
index 27245b02b..818c8bc93 100644
--- a/examples/foundational/07c-interruptible-deepgram.py
+++ b/examples/foundational/07c-interruptible-deepgram.py
@@ -53,7 +53,7 @@ async def main(room_url: str, token):
llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
- model="gpt-4-turbo-preview")
+ model="gpt-4o")
messages = [
{
diff --git a/examples/foundational/11-sound-effects.py b/examples/foundational/11-sound-effects.py
index 1ca568bf0..2a3e8effc 100644
--- a/examples/foundational/11-sound-effects.py
+++ b/examples/foundational/11-sound-effects.py
@@ -95,7 +95,7 @@ async def main(room_url: str, token):
llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
- model="gpt-4-turbo-preview")
+ model="gpt-4o")
tts = ElevenLabsTTSService(
aiohttp_session=session,
diff --git a/examples/foundational/14-function-calling.py b/examples/foundational/14-function-calling.py
index 14f834fe1..4a3a8b515 100644
--- a/examples/foundational/14-function-calling.py
+++ b/examples/foundational/14-function-calling.py
@@ -66,7 +66,7 @@ async def main(room_url: str, token):
llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
- model="gpt-4-turbo-preview")
+ model="gpt-4o")
llm.register_function(
"get_current_weather",
fetch_weather_from_api,
diff --git a/examples/foundational/websocket-server/frames.proto b/examples/foundational/websocket-server/frames.proto
deleted file mode 100644
index 830e3062c..000000000
--- a/examples/foundational/websocket-server/frames.proto
+++ /dev/null
@@ -1,25 +0,0 @@
-syntax = "proto3";
-
-package pipecat_proto;
-
-message TextFrame {
- string text = 1;
-}
-
-message AudioFrame {
- bytes audio = 1;
-}
-
-message TranscriptionFrame {
- string text = 1;
- string participant_id = 2;
- string timestamp = 3;
-}
-
-message Frame {
- oneof frame {
- TextFrame text = 1;
- AudioFrame audio = 2;
- TranscriptionFrame transcription = 3;
- }
-}
diff --git a/examples/foundational/websocket-server/index.html b/examples/foundational/websocket-server/index.html
deleted file mode 100644
index a38e1e78b..000000000
--- a/examples/foundational/websocket-server/index.html
+++ /dev/null
@@ -1,134 +0,0 @@
-<!-- "WebSocket Audio Stream" demo page: markup elided in this excerpt -->
diff --git a/examples/foundational/websocket-server/sample.py b/examples/foundational/websocket-server/sample.py
deleted file mode 100644
index b3a4a731d..000000000
--- a/examples/foundational/websocket-server/sample.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import asyncio
-import aiohttp
-import logging
-import os
-from pipecat.pipeline.frame_processor import FrameProcessor
-from pipecat.pipeline.frames import TextFrame, TranscriptionFrame
-from pipecat.pipeline.pipeline import Pipeline
-from pipecat.services.elevenlabs_ai_services import ElevenLabsTTSService
-from pipecat.transports.websocket_transport import WebsocketTransport
-from pipecat.services.whisper_ai_services import WhisperSTTService
-
-logging.basicConfig(format="%(levelno)s %(asctime)s %(message)s")
-logger = logging.getLogger("pipecat")
-logger.setLevel(logging.DEBUG)
-
-
-class WhisperTranscriber(FrameProcessor):
- async def process_frame(self, frame):
- if isinstance(frame, TranscriptionFrame):
- print(f"Transcribed: {frame.text}")
- else:
- yield frame
-
-
-async def main():
- async with aiohttp.ClientSession() as session:
- transport = WebsocketTransport(
- mic_enabled=True,
- speaker_enabled=True,
- )
- tts = ElevenLabsTTSService(
- aiohttp_session=session,
- api_key=os.getenv("ELEVENLABS_API_KEY"),
- voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
- )
-
- pipeline = Pipeline([
- WhisperSTTService(),
- WhisperTranscriber(),
- tts,
- ])
-
- @transport.on_connection
- async def queue_frame():
- await pipeline.queue_frames([TextFrame("Hello there!")])
-
- await transport.run(pipeline)
-
-if __name__ == "__main__":
- asyncio.run(main())
diff --git a/examples/moondream-chatbot/bot.py b/examples/moondream-chatbot/bot.py
index 7830cf46a..3e9ced259 100644
--- a/examples/moondream-chatbot/bot.py
+++ b/examples/moondream-chatbot/bot.py
@@ -145,7 +145,7 @@ async def main(room_url: str, token):
llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
- model="gpt-4-turbo-preview")
+ model="gpt-4o")
ta = TalkingAnimation()
diff --git a/examples/simple-chatbot/bot.py b/examples/simple-chatbot/bot.py
index a63b215ab..80b60833f 100644
--- a/examples/simple-chatbot/bot.py
+++ b/examples/simple-chatbot/bot.py
@@ -117,7 +117,7 @@ async def main(room_url: str, token):
llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
- model="gpt-4-turbo-preview")
+ model="gpt-4o")
messages = [
{
diff --git a/examples/storytelling-chatbot/src/bot.py b/examples/storytelling-chatbot/src/bot.py
index c5a75e949..96bb7626a 100644
--- a/examples/storytelling-chatbot/src/bot.py
+++ b/examples/storytelling-chatbot/src/bot.py
@@ -56,7 +56,7 @@ async def main(room_url, token=None):
llm_service = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
- model="gpt-4-turbo"
+ model="gpt-4o"
)
tts_service = ElevenLabsTTSService(
diff --git a/examples/translation-chatbot/bot.py b/examples/translation-chatbot/bot.py
index 89ca461b1..38667d897 100644
--- a/examples/translation-chatbot/bot.py
+++ b/examples/translation-chatbot/bot.py
@@ -97,7 +97,8 @@ async def main(room_url: str, token):
)
llm = OpenAILLMService(
- api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4-turbo-preview"
+ api_key=os.getenv("OPENAI_API_KEY"),
+ model="gpt-4o"
)
sa = SentenceAggregator()
diff --git a/examples/websocket-server/README.md b/examples/websocket-server/README.md
new file mode 100644
index 000000000..8417f5a9d
--- /dev/null
+++ b/examples/websocket-server/README.md
@@ -0,0 +1,27 @@
+# Websocket Server
+
+This example shows how to use `WebsocketServerTransport` to communicate with a web client.
+
+## Get started
+
+```bash
+python3 -m venv venv
+source venv/bin/activate
+pip install -r requirements.txt
+```
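+
+`bot.py` loads its settings from the environment (it calls `load_dotenv()`), so
+a `.env` file along these lines is assumed (variable names taken from `bot.py`):
+
+```
+OPENAI_API_KEY=...
+ELEVENLABS_API_KEY=...
+ELEVENLABS_VOICE_ID=...
+```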
+
+## Run the bot
+
+```bash
+python bot.py
+```
+
+## Run the HTTP server
+
+This will host the static web client:
+
+```bash
+python -m http.server
+```
+
+Then, visit `http://localhost:8000` in your browser to start a session.
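+
+## Wire format
+
+Messages travel over the websocket as protobuf-encoded `Frame` messages (see
+`frames.proto`). A minimal sketch of what a client sends and how the server
+decodes it, using the generated bindings and the serializer from this repo:
+
+```python
+import pipecat.frames.protobufs.frames_pb2 as frames_pb2
+
+from pipecat.serializers.protobuf import ProtobufFrameSerializer
+
+# A text message as a client would send it: a Frame envelope with the
+# `text` member of the oneof set.
+message = frames_pb2.Frame(text=frames_pb2.TextFrame(text="hello"))
+data = message.SerializeToString()  # bytes on the websocket
+
+# The server turns those bytes back into a pipecat frame.
+frame = ProtobufFrameSerializer().deserialize(data)
+print(frame.text)  # "hello"
+```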
diff --git a/examples/websocket-server/bot.py b/examples/websocket-server/bot.py
new file mode 100644
index 000000000..ba61084de
--- /dev/null
+++ b/examples/websocket-server/bot.py
@@ -0,0 +1,94 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import aiohttp
+import asyncio
+import os
+import sys
+
+from pipecat.frames.frames import LLMMessagesFrame
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineParams, PipelineTask
+from pipecat.processors.aggregators.llm_response import (
+ LLMAssistantResponseAggregator,
+ LLMUserResponseAggregator
+)
+from pipecat.services.elevenlabs import ElevenLabsTTSService
+from pipecat.services.openai import OpenAILLMService
+from pipecat.services.whisper import WhisperSTTService
+from pipecat.transports.network.websocket_server import WebsocketServerParams, WebsocketServerTransport
+from pipecat.vad.silero import SileroVADAnalyzer
+
+from loguru import logger
+
+from dotenv import load_dotenv
+load_dotenv(override=True)
+
+logger.remove(0)
+logger.add(sys.stderr, level="DEBUG")
+
+
+async def main():
+ async with aiohttp.ClientSession() as session:
+ transport = WebsocketServerTransport(
+ params=WebsocketServerParams(
+ audio_in_enabled=True,
+ audio_out_enabled=True,
+ add_wav_header=True,
+ vad_enabled=True,
+ vad_analyzer=SileroVADAnalyzer(),
+ vad_audio_passthrough=True
+ )
+ )
+
+ llm = OpenAILLMService(
+ api_key=os.getenv("OPENAI_API_KEY"),
+ model="gpt-4o")
+
+ stt = WhisperSTTService()
+
+ tts = ElevenLabsTTSService(
+ aiohttp_session=session,
+ api_key=os.getenv("ELEVENLABS_API_KEY"),
+ voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
+ )
+
+ messages = [
+ {
+ "role": "system",
+            "content": "You are a helpful LLM in an audio call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
+ },
+ ]
+
+ tma_in = LLMUserResponseAggregator(messages)
+ tma_out = LLMAssistantResponseAggregator(messages)
+
+ pipeline = Pipeline([
+ transport.input(), # Websocket input from client
+ stt, # Speech-To-Text
+ tma_in, # User responses
+ llm, # LLM
+ tts, # Text-To-Speech
+ transport.output(), # Websocket output to client
+ tma_out # LLM responses
+ ])
+
+ task = PipelineTask(pipeline)
+
+ @transport.event_handler("on_client_connected")
+ async def on_client_connected(transport, client):
+ # Kick off the conversation.
+ messages.append(
+ {"role": "system", "content": "Please introduce yourself to the user."})
+ await task.queue_frames([LLMMessagesFrame(messages)])
+
+ runner = PipelineRunner()
+
+ await runner.run(task)
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/examples/websocket-server/frames.proto b/examples/websocket-server/frames.proto
new file mode 100644
index 000000000..5c5d81d4d
--- /dev/null
+++ b/examples/websocket-server/frames.proto
@@ -0,0 +1,43 @@
+//
+// Copyright (c) 2024, Daily
+//
+// SPDX-License-Identifier: BSD 2-Clause License
+//
+
+// Generate frames_pb2.py with:
+//
+// python -m grpc_tools.protoc --proto_path=./ --python_out=./protobufs frames.proto
+
+syntax = "proto3";
+
+package pipecat;
+
+message TextFrame {
+ uint64 id = 1;
+ string name = 2;
+ string text = 3;
+}
+
+message AudioRawFrame {
+ uint64 id = 1;
+ string name = 2;
+ bytes audio = 3;
+ uint32 sample_rate = 4;
+ uint32 num_channels = 5;
+}
+
+message TranscriptionFrame {
+ uint64 id = 1;
+ string name = 2;
+ string text = 3;
+ string user_id = 4;
+ string timestamp = 5;
+}
+
+message Frame {
+ oneof frame {
+ TextFrame text = 1;
+ AudioRawFrame audio = 2;
+ TranscriptionFrame transcription = 3;
+ }
+}
diff --git a/examples/websocket-server/index.html b/examples/websocket-server/index.html
new file mode 100644
index 000000000..514a4a821
--- /dev/null
+++ b/examples/websocket-server/index.html
@@ -0,0 +1,205 @@
+<!-- "Pipecat WebSocket Client Example" page with a "Loading, wait..." placeholder: markup elided in this excerpt -->
diff --git a/examples/websocket-server/requirements.txt b/examples/websocket-server/requirements.txt
new file mode 100644
index 000000000..77e5b9e91
--- /dev/null
+++ b/examples/websocket-server/requirements.txt
@@ -0,0 +1,2 @@
+python-dotenv
+pipecat-ai[openai,silero,websocket,whisper]
diff --git a/linux-py3.10-requirements.txt b/linux-py3.10-requirements.txt
index 204a0e932..d3687c17e 100644
--- a/linux-py3.10-requirements.txt
+++ b/linux-py3.10-requirements.txt
@@ -42,7 +42,7 @@ coloredlogs==15.0.1
# via onnxruntime
ctranslate2==4.2.1
# via faster-whisper
-daily-python==0.9.0
+daily-python==0.9.1
# via pipecat-ai (pyproject.toml)
distro==1.9.0
# via
@@ -226,6 +226,7 @@ protobuf==4.25.3
# googleapis-common-protos
# grpcio-status
# onnxruntime
+ # pipecat-ai (pyproject.toml)
# proto-plus
# pyht
pyasn1==0.6.0
@@ -259,7 +260,7 @@ pyyaml==6.0.1
# transformers
regex==2024.5.15
# via transformers
-requests==2.32.2
+requests==2.32.3
# via
# google-api-core
# huggingface-hub
diff --git a/macos-py3.10-requirements.txt b/macos-py3.10-requirements.txt
index 35ddcd8b6..e6f91d0f9 100644
--- a/macos-py3.10-requirements.txt
+++ b/macos-py3.10-requirements.txt
@@ -208,6 +208,7 @@ protobuf==4.25.3
# googleapis-common-protos
# grpcio-status
# onnxruntime
+ # pipecat-ai (pyproject.toml)
# proto-plus
# pyht
pyasn1==0.6.0
diff --git a/pyproject.toml b/pyproject.toml
index 90363e34b..aa3558f87 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"numpy~=1.26.4",
"loguru~=0.7.0",
"Pillow~=10.3.0",
+ "protobuf~=4.25.3",
"pyloudnorm~=0.1.1",
"typing-extensions~=4.11.0",
]
diff --git a/src/pipecat/frames/frames.proto b/src/pipecat/frames/frames.proto
index 18e59e492..5c5d81d4d 100644
--- a/src/pipecat/frames/frames.proto
+++ b/src/pipecat/frames/frames.proto
@@ -4,28 +4,40 @@
// SPDX-License-Identifier: BSD 2-Clause License
//
+// Generate frames_pb2.py with:
+//
+// python -m grpc_tools.protoc --proto_path=./ --python_out=./protobufs frames.proto
+
syntax = "proto3";
-package pipecat_proto;
+package pipecat;
message TextFrame {
- string text = 1;
+ uint64 id = 1;
+ string name = 2;
+ string text = 3;
}
-message AudioFrame {
- bytes data = 1;
+message AudioRawFrame {
+ uint64 id = 1;
+ string name = 2;
+ bytes audio = 3;
+ uint32 sample_rate = 4;
+ uint32 num_channels = 5;
}
message TranscriptionFrame {
- string text = 1;
- string participantId = 2;
- string timestamp = 3;
+ uint64 id = 1;
+ string name = 2;
+ string text = 3;
+ string user_id = 4;
+ string timestamp = 5;
}
message Frame {
- oneof frame {
- TextFrame text = 1;
- AudioFrame audio = 2;
- TranscriptionFrame transcription = 3;
- }
+ oneof frame {
+ TextFrame text = 1;
+ AudioRawFrame audio = 2;
+ TranscriptionFrame transcription = 3;
+ }
}
diff --git a/src/pipecat/frames/protobufs/frames_pb2.py b/src/pipecat/frames/protobufs/frames_pb2.py
index bdc34d385..5040efc97 100644
--- a/src/pipecat/frames/protobufs/frames_pb2.py
+++ b/src/pipecat/frames/protobufs/frames_pb2.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: frames.proto
-# Protobuf Python Version: 4.25.3
+# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -14,19 +14,19 @@
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\rpipecat_proto\"\x19\n\tTextFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\"\x1a\n\nAudioFrame\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"L\n\x12TranscriptionFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x15\n\rparticipantId\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"\xa2\x01\n\x05\x46rame\x12(\n\x04text\x18\x01 \x01(\x0b\x32\x18.pipecat_proto.TextFrameH\x00\x12*\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x19.pipecat_proto.AudioFrameH\x00\x12:\n\rtranscription\x18\x03 \x01(\x0b\x32!.pipecat_proto.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\x07pipecat\"3\n\tTextFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\"c\n\rAudioRawFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05\x61udio\x18\x03 \x01(\x0c\x12\x13\n\x0bsample_rate\x18\x04 \x01(\r\x12\x14\n\x0cnum_channels\x18\x05 \x01(\r\"`\n\x12TranscriptionFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0f\n\x07user_id\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\t\"\x93\x01\n\x05\x46rame\x12\"\n\x04text\x18\x01 \x01(\x0b\x32\x12.pipecat.TextFrameH\x00\x12\'\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x16.pipecat.AudioRawFrameH\x00\x12\x34\n\rtranscription\x18\x03 \x01(\x0b\x32\x1b.pipecat.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
- _globals['_TEXTFRAME']._serialized_start=31
- _globals['_TEXTFRAME']._serialized_end=56
- _globals['_AUDIOFRAME']._serialized_start=58
- _globals['_AUDIOFRAME']._serialized_end=84
- _globals['_TRANSCRIPTIONFRAME']._serialized_start=86
- _globals['_TRANSCRIPTIONFRAME']._serialized_end=162
- _globals['_FRAME']._serialized_start=165
- _globals['_FRAME']._serialized_end=327
+ _globals['_TEXTFRAME']._serialized_start=25
+ _globals['_TEXTFRAME']._serialized_end=76
+ _globals['_AUDIORAWFRAME']._serialized_start=78
+ _globals['_AUDIORAWFRAME']._serialized_end=177
+ _globals['_TRANSCRIPTIONFRAME']._serialized_start=179
+ _globals['_TRANSCRIPTIONFRAME']._serialized_end=275
+ _globals['_FRAME']._serialized_start=278
+ _globals['_FRAME']._serialized_end=425
# @@protoc_insertion_point(module_scope)
diff --git a/src/pipecat/pipeline/pipeline.py b/src/pipecat/pipeline/pipeline.py
index 5d1f9d3cf..2ba061a37 100644
--- a/src/pipecat/pipeline/pipeline.py
+++ b/src/pipecat/pipeline/pipeline.py
@@ -67,7 +67,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
await self._sink.process_frame(frame, FrameDirection.UPSTREAM)
async def _cleanup_processors(self):
- await asyncio.gather(*[p.cleanup() for p in self._processors])
+ for p in self._processors:
+ await p.cleanup()
def _link_processors(self):
prev = self._processors[0]
diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py
index 3bb750218..793aa8eb1 100644
--- a/src/pipecat/processors/frame_processor.py
+++ b/src/pipecat/processors/frame_processor.py
@@ -5,7 +5,7 @@
#
import asyncio
-from asyncio import AbstractEventLoop
+
from enum import Enum
from pipecat.frames.frames import ErrorFrame, Frame
@@ -21,12 +21,12 @@ class FrameDirection(Enum):
class FrameProcessor:
- def __init__(self):
+ def __init__(self, loop: asyncio.AbstractEventLoop | None = None):
self.id: int = obj_id()
self.name = f"{self.__class__.__name__}#{obj_count(self)}"
self._prev: "FrameProcessor" | None = None
self._next: "FrameProcessor" | None = None
- self._loop: AbstractEventLoop = asyncio.get_running_loop()
+ self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_running_loop()
async def cleanup(self):
pass
@@ -36,7 +36,7 @@ def link(self, processor: 'FrameProcessor'):
processor._prev = self
logger.debug(f"Linking {self} -> {self._next}")
- def get_event_loop(self) -> AbstractEventLoop:
+ def get_event_loop(self) -> asyncio.AbstractEventLoop:
return self._loop
async def process_frame(self, frame: Frame, direction: FrameDirection):
diff --git a/src/pipecat/serializers/abstract_frame_serializer.py b/src/pipecat/serializers/abstract_frame_serializer.py
deleted file mode 100644
index 8da0bd11d..000000000
--- a/src/pipecat/serializers/abstract_frame_serializer.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from abc import abstractmethod
-
-from pipecat.pipeline.frames import Frame
-
-
-class FrameSerializer:
- def __init__(self):
- pass
-
- @abstractmethod
- def serialize(self, frame: Frame) -> bytes:
- raise NotImplementedError
-
- @abstractmethod
- def deserialize(self, data: bytes) -> Frame:
- raise NotImplementedError
diff --git a/src/pipecat/serializers/base_serializer.py b/src/pipecat/serializers/base_serializer.py
new file mode 100644
index 000000000..c137f873d
--- /dev/null
+++ b/src/pipecat/serializers/base_serializer.py
@@ -0,0 +1,20 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from abc import ABC, abstractmethod
+
+from pipecat.frames.frames import Frame
+
+
+class FrameSerializer(ABC):
+
+ @abstractmethod
+ def serialize(self, frame: Frame) -> bytes:
+ pass
+
+ @abstractmethod
+ def deserialize(self, data: bytes) -> Frame:
+ pass
diff --git a/src/pipecat/serializers/protobuf_serializer.py b/src/pipecat/serializers/protobuf.py
similarity index 72%
rename from src/pipecat/serializers/protobuf_serializer.py
rename to src/pipecat/serializers/protobuf.py
index 04b348b86..50692a51b 100644
--- a/src/pipecat/serializers/protobuf_serializer.py
+++ b/src/pipecat/serializers/protobuf.py
@@ -1,14 +1,21 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
import dataclasses
-from typing import Text
-from pipecat.pipeline.frames import AudioFrame, Frame, TextFrame, TranscriptionFrame
-import pipecat.pipeline.protobufs.frames_pb2 as frame_protos
-from pipecat.serializers.abstract_frame_serializer import FrameSerializer
+
+import pipecat.frames.protobufs.frames_pb2 as frame_protos
+
+from pipecat.frames.frames import AudioRawFrame, Frame, TextFrame, TranscriptionFrame
+from pipecat.serializers.base_serializer import FrameSerializer
class ProtobufFrameSerializer(FrameSerializer):
SERIALIZABLE_TYPES = {
TextFrame: "text",
- AudioFrame: "audio",
+ AudioRawFrame: "audio",
TranscriptionFrame: "transcription"
}
@@ -29,7 +36,8 @@ def serialize(self, frame: Frame) -> bytes:
setattr(getattr(proto_frame, proto_optional_name), field.name,
getattr(frame, field.name))
- return proto_frame.SerializeToString()
+ result = proto_frame.SerializeToString()
+ return result
def deserialize(self, data: bytes) -> Frame:
"""Returns a Frame object from a Frame protobuf. Used to convert frames
@@ -61,4 +69,22 @@ def deserialize(self, data: bytes) -> Frame:
args_dict = {}
for field in proto.DESCRIPTOR.fields_by_name[which].message_type.fields:
args_dict[field.name] = getattr(args, field.name)
- return class_name(**args_dict)
+
+ # Remove special fields if needed
+ id = getattr(args, "id")
+ name = getattr(args, "name")
+ if not id:
+ del args_dict["id"]
+ if not name:
+ del args_dict["name"]
+
+ # Create the instance
+ instance = class_name(**args_dict)
+
+ # Set special fields
+ if id:
+ setattr(instance, "id", getattr(args, "id"))
+ if name:
+ setattr(instance, "name", getattr(args, "name"))
+
+ return instance
diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py
index a7f74ccc3..fc879d887 100644
--- a/src/pipecat/services/ai_services.py
+++ b/src/pipecat/services/ai_services.py
@@ -196,7 +196,7 @@ def __init__(self):
super().__init__()
# Renders the image. Returns an Image object.
- @ abstractmethod
+ @abstractmethod
async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
pass
@@ -215,7 +215,7 @@ def __init__(self):
super().__init__()
self._describe_text = None
- @ abstractmethod
+ @abstractmethod
async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]:
pass
diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py
index 96c855fa4..ac5fa1d99 100644
--- a/src/pipecat/services/openai.py
+++ b/src/pipecat/services/openai.py
@@ -229,7 +229,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
class OpenAILLMService(BaseOpenAILLMService):
- def __init__(self, model="gpt-4", **kwargs):
+ def __init__(self, model="gpt-4o", **kwargs):
super().__init__(model, **kwargs)
diff --git a/src/pipecat/storage/__init__.py b/src/pipecat/storage/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/pipecat/storage/search.py b/src/pipecat/storage/search.py
deleted file mode 100644
index 5740d1cae..000000000
--- a/src/pipecat/storage/search.py
+++ /dev/null
@@ -1,9 +0,0 @@
-class SearchIndexer():
- def __init__(self, story_id):
- pass
-
- def index_text(self, text):
- pass
-
- def index_image(self, text):
- pass
diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py
index 71733c59c..ba4f8ae4e 100644
--- a/src/pipecat/transports/base_input.py
+++ b/src/pipecat/transports/base_input.py
@@ -21,7 +21,7 @@
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame)
from pipecat.transports.base_transport import TransportParams
-from pipecat.vad.vad_analyzer import VADState
+from pipecat.vad.vad_analyzer import VADAnalyzer, VADState
from loguru import logger
@@ -59,10 +59,7 @@ async def start(self, frame: StartFrame):
if self._params.audio_in_enabled or self._params.vad_enabled:
loop = self.get_event_loop()
- self._audio_in_thread = loop.run_in_executor(
- self._in_executor, self._audio_in_thread_handler)
- self._audio_out_thread = loop.run_in_executor(
- self._in_executor, self._audio_out_thread_handler)
+ self._audio_thread = loop.run_in_executor(self._in_executor, self._audio_thread_handler)
async def stop(self):
if not self._running:
@@ -73,15 +70,14 @@ async def stop(self):
# Wait for the threads to finish.
if self._params.audio_in_enabled or self._params.vad_enabled:
- await self._audio_in_thread
- await self._audio_out_thread
+ await self._audio_thread
self._push_frame_task.cancel()
- def vad_analyze(self, audio_frames: bytes) -> VADState:
- pass
+ def vad_analyzer(self) -> VADAnalyzer | None:
+ return self._params.vad_analyzer
- def read_raw_audio_frames(self, frame_count: int) -> bytes:
+ def read_next_audio_frame(self) -> AudioRawFrame | None:
pass
#
@@ -150,8 +146,15 @@ async def _handle_interruptions(self, frame: Frame):
# Audio input
#
+ def _vad_analyze(self, audio_frames: bytes) -> VADState:
+ state = VADState.QUIET
+ vad_analyzer = self.vad_analyzer()
+ if vad_analyzer:
+ state = vad_analyzer.analyze_audio(audio_frames)
+ return state
+
def _handle_vad(self, audio_frames: bytes, vad_state: VADState):
- new_vad_state = self.vad_analyze(audio_frames)
+ new_vad_state = self._vad_analyze(audio_frames)
if new_vad_state != vad_state and new_vad_state != VADState.STARTING and new_vad_state != VADState.STOPPING:
frame = None
if new_vad_state == VADState.SPEAKING:
@@ -167,44 +170,25 @@ def _handle_vad(self, audio_frames: bytes, vad_state: VADState):
vad_state = new_vad_state
return vad_state
- def _audio_in_thread_handler(self):
- sample_rate = self._params.audio_in_sample_rate
- num_channels = self._params.audio_in_channels
- num_frames = int(sample_rate / 100) # 10ms of audio
- while self._running:
- try:
- audio_frames = self.read_raw_audio_frames(num_frames)
- if len(audio_frames) > 0:
- frame = AudioRawFrame(
- audio=audio_frames,
- sample_rate=sample_rate,
- num_channels=num_channels)
- self._audio_in_queue.put(frame)
- except BaseException as e:
- logger.error(f"Error reading audio frames: {e}")
-
- def _audio_out_thread_handler(self):
+ def _audio_thread_handler(self):
vad_state: VADState = VADState.QUIET
while self._running:
try:
- frame = self._audio_in_queue.get(timeout=1)
-
- audio_passthrough = True
-
- # Check VAD and push event if necessary. We just care about changes
- # from QUIET to SPEAKING and vice versa.
- if self._params.vad_enabled:
- vad_state = self._handle_vad(frame.audio, vad_state)
- audio_passthrough = self._params.vad_audio_passthrough
-
- # Push audio downstream if passthrough.
- if audio_passthrough:
- future = asyncio.run_coroutine_threadsafe(
- self._internal_push_frame(frame), self.get_event_loop())
- future.result()
-
- self._audio_in_queue.task_done()
- except queue.Empty:
- pass
+ frame = self.read_next_audio_frame()
+
+ if frame:
+ audio_passthrough = True
+
+ # Check VAD and push event if necessary. We just care about
+ # changes from QUIET to SPEAKING and vice versa.
+ if self._params.vad_enabled:
+ vad_state = self._handle_vad(frame.audio, vad_state)
+ audio_passthrough = self._params.vad_audio_passthrough
+
+ # Push audio downstream if passthrough.
+ if audio_passthrough:
+ future = asyncio.run_coroutine_threadsafe(
+ self._internal_push_frame(frame), self.get_event_loop())
+ future.result()
except BaseException as e:
- logger.error(f"Error pushing audio frames: {e}")
+ logger.error(f"Error reading audio frames: {e}")
diff --git a/src/pipecat/transports/base_transport.py b/src/pipecat/transports/base_transport.py
index 7f22d2c2c..7034b81eb 100644
--- a/src/pipecat/transports/base_transport.py
+++ b/src/pipecat/transports/base_transport.py
@@ -4,6 +4,9 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
+import asyncio
+import inspect
+
from abc import ABC, abstractmethod
from pydantic import ConfigDict
@@ -12,6 +15,8 @@
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.vad.vad_analyzer import VADAnalyzer
+from loguru import logger
+
class TransportParams(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -36,6 +41,10 @@ class TransportParams(BaseModel):
class BaseTransport(ABC):
+ def __init__(self, loop: asyncio.AbstractEventLoop | None):
+ self._loop = loop or asyncio.get_running_loop()
+ self._event_handlers: dict = {}
+
@abstractmethod
def input(self) -> FrameProcessor:
raise NotImplementedError
@@ -43,3 +52,30 @@ def input(self) -> FrameProcessor:
@abstractmethod
def output(self) -> FrameProcessor:
raise NotImplementedError
+
+ def event_handler(self, event_name: str):
+ def decorator(handler):
+ self._add_event_handler(event_name, handler)
+ return handler
+ return decorator
+
+ def _register_event_handler(self, event_name: str):
+ if event_name in self._event_handlers:
+ raise Exception(f"Event handler {event_name} already registered")
+ self._event_handlers[event_name] = []
+
+ def _add_event_handler(self, event_name: str, handler):
+ if event_name not in self._event_handlers:
+ raise Exception(f"Event handler {event_name} not registered")
+ self._event_handlers[event_name].append(handler)
+
+ async def _call_event_handler(self, event_name: str, *args, **kwargs):
+ try:
+ for handler in self._event_handlers[event_name]:
+ if inspect.iscoroutinefunction(handler):
+ await handler(self, *args, **kwargs)
+ else:
+ handler(self, *args, **kwargs)
+ except Exception as e:
+ logger.error(f"Exception in event handler {event_name}: {e}")
+ raise e
diff --git a/src/pipecat/transports/local/audio.py b/src/pipecat/transports/local/audio.py
index 771715111..14b8bd5d3 100644
--- a/src/pipecat/transports/local/audio.py
+++ b/src/pipecat/transports/local/audio.py
@@ -6,7 +6,7 @@
import asyncio
-from pipecat.frames.frames import StartFrame
+from pipecat.frames.frames import AudioRawFrame, StartFrame
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
@@ -35,8 +35,14 @@ def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams):
frames_per_buffer=params.audio_in_sample_rate,
input=True)
- def read_raw_audio_frames(self, frame_count: int) -> bytes:
- return self._in_stream.read(frame_count, exception_on_overflow=False)
+ def read_next_audio_frame(self) -> AudioRawFrame | None:
+ sample_rate = self._params.audio_in_sample_rate
+ num_channels = self._params.audio_in_channels
+ num_frames = int(sample_rate / 100) # 10ms of audio
+
+ audio = self._in_stream.read(num_frames, exception_on_overflow=False)
+
+ return AudioRawFrame(audio=audio, sample_rate=sample_rate, num_channels=num_channels)
async def start(self, frame: StartFrame):
await super().start(frame)
diff --git a/src/pipecat/transports/local/tk.py b/src/pipecat/transports/local/tk.py
index 782c01dae..808837998 100644
--- a/src/pipecat/transports/local/tk.py
+++ b/src/pipecat/transports/local/tk.py
@@ -9,7 +9,7 @@
import numpy as np
import tkinter as tk
-from pipecat.frames.frames import ImageRawFrame, StartFrame
+from pipecat.frames.frames import AudioRawFrame, ImageRawFrame, StartFrame
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
@@ -45,8 +45,14 @@ def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams):
frames_per_buffer=params.audio_in_sample_rate,
input=True)
- def read_raw_audio_frames(self, frame_count: int) -> bytes:
- return self._in_stream.read(frame_count, exception_on_overflow=False)
+ def read_next_audio_frame(self) -> AudioRawFrame | None:
+ sample_rate = self._params.audio_in_sample_rate
+ num_channels = self._params.audio_in_channels
+ num_frames = int(sample_rate / 100) # 10ms of audio
+
+ audio = self._in_stream.read(num_frames, exception_on_overflow=False)
+
+ return AudioRawFrame(audio=audio, sample_rate=sample_rate, num_channels=num_channels)
async def start(self, frame: StartFrame):
await super().start(frame)
diff --git a/src/pipecat/transports/network/websocket_server.py b/src/pipecat/transports/network/websocket_server.py
new file mode 100644
index 000000000..8456c4536
--- /dev/null
+++ b/src/pipecat/transports/network/websocket_server.py
@@ -0,0 +1,206 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+
+import asyncio
+import io
+import queue
+import wave
+import websockets
+
+from typing import Awaitable, Callable
+from pydantic.main import BaseModel
+
+from pipecat.frames.frames import AudioRawFrame, StartFrame
+from pipecat.processors.frame_processor import FrameProcessor
+from pipecat.serializers.base_serializer import FrameSerializer
+from pipecat.serializers.protobuf import ProtobufFrameSerializer
+from pipecat.transports.base_input import BaseInputTransport
+from pipecat.transports.base_output import BaseOutputTransport
+from pipecat.transports.base_transport import BaseTransport, TransportParams
+
+from loguru import logger
+
+
+class WebsocketServerParams(TransportParams):
+ add_wav_header: bool = False
+    audio_frame_size: int = 6400  # 200ms of 16-bit mono audio at 16kHz
+ serializer: FrameSerializer = ProtobufFrameSerializer()
+
+
+class WebsocketServerCallbacks(BaseModel):
+ on_client_connected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]]
+ on_client_disconnected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]]
+
+
+class WebsocketServerInputTransport(BaseInputTransport):
+
+ def __init__(
+ self,
+ host: str,
+ port: int,
+ params: WebsocketServerParams,
+ callbacks: WebsocketServerCallbacks):
+ super().__init__(params)
+
+ self._host = host
+ self._port = port
+ self._params = params
+ self._callbacks = callbacks
+
+ self._websocket: websockets.WebSocketServerProtocol | None = None
+
+ self._client_audio_queue = queue.Queue()
+ self._stop_server_event = asyncio.Event()
+
+ async def start(self, frame: StartFrame):
+ self._server_task = self.get_event_loop().create_task(self._server_task_handler())
+ await super().start(frame)
+
+ async def stop(self):
+ self._stop_server_event.set()
+ await self._server_task
+ await super().stop()
+
+ def read_next_audio_frame(self) -> AudioRawFrame | None:
+ try:
+ return self._client_audio_queue.get(timeout=1)
+ except queue.Empty:
+ return None
+
+ async def _server_task_handler(self):
+ logger.info(f"Starting websocket server on {self._host}:{self._port}")
+ async with websockets.serve(self._client_handler, self._host, self._port) as server:
+ await self._stop_server_event.wait()
+
+ async def _client_handler(self, websocket: websockets.WebSocketServerProtocol, path):
+ logger.info(f"New client connection from {websocket.remote_address}")
+ if self._websocket:
+ await self._websocket.close()
+            logger.warning("Only one client allowed, using new connection")
+
+ self._websocket = websocket
+
+ # Notify
+ await self._callbacks.on_client_connected(websocket)
+
+ # Handle incoming messages
+ async for message in websocket:
+ frame = self._params.serializer.deserialize(message)
+ if isinstance(frame, AudioRawFrame) and self._params.audio_in_enabled:
+ self._client_audio_queue.put_nowait(frame)
+ else:
+ await self._internal_push_frame(frame)
+
+ # Notify disconnection
+ await self._callbacks.on_client_disconnected(websocket)
+
+ await self._websocket.close()
+ self._websocket = None
+
+ logger.info(f"Client {websocket.remote_address} disconnected")
+
+
+class WebsocketServerOutputTransport(BaseOutputTransport):
+
+ def __init__(self, params: WebsocketServerParams):
+ super().__init__(params)
+
+ self._params = params
+
+ self._websocket: websockets.WebSocketServerProtocol | None = None
+
+ self._audio_buffer = bytes()
+
+ async def set_client_connection(self, websocket: websockets.WebSocketServerProtocol | None):
+ if self._websocket:
+ await self._websocket.close()
+ logger.warning("Only one client allowed, using new connection")
+ self._websocket = websocket
+
+ def write_raw_audio_frames(self, frames: bytes):
+ self._audio_buffer += frames
+ while len(self._audio_buffer) >= self._params.audio_frame_size:
+ frame = AudioRawFrame(
+ audio=self._audio_buffer[:self._params.audio_frame_size],
+ sample_rate=self._params.audio_out_sample_rate,
+ num_channels=self._params.audio_out_channels
+ )
+
+ if self._params.add_wav_header:
+ content = io.BytesIO()
+ ww = wave.open(content, "wb")
+ ww.setsampwidth(2)
+ ww.setnchannels(frame.num_channels)
+ ww.setframerate(frame.sample_rate)
+ ww.writeframes(frame.audio)
+ ww.close()
+ content.seek(0)
+ wav_frame = AudioRawFrame(
+ content.read(),
+ sample_rate=frame.sample_rate,
+ num_channels=frame.num_channels)
+ frame = wav_frame
+
+ proto = self._params.serializer.serialize(frame)
+
+ future = asyncio.run_coroutine_threadsafe(
+ self._websocket.send(proto), self.get_event_loop())
+ future.result()
+
+ self._audio_buffer = self._audio_buffer[self._params.audio_frame_size:]
+
+
+class WebsocketServerTransport(BaseTransport):
+
+ def __init__(
+ self,
+ host: str = "localhost",
+ port: int = 8765,
+ params: WebsocketServerParams = WebsocketServerParams(),
+ loop: asyncio.AbstractEventLoop | None = None):
+ super().__init__(loop)
+ self._host = host
+ self._port = port
+ self._params = params
+
+ self._callbacks = WebsocketServerCallbacks(
+ on_client_connected=self._on_client_connected,
+ on_client_disconnected=self._on_client_disconnected
+ )
+ self._input: WebsocketServerInputTransport | None = None
+ self._output: WebsocketServerOutputTransport | None = None
+ self._websocket: websockets.WebSocketServerProtocol | None = None
+
+ # Register supported handlers. The user will only be able to register
+ # these handlers.
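+        #
+        # For example (as in examples/websocket-server/bot.py):
+        #
+        #   @transport.event_handler("on_client_connected")
+        #   async def on_client_connected(transport, client):
+        #       ...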
+ self._register_event_handler("on_client_connected")
+ self._register_event_handler("on_client_disconnected")
+
+ def input(self) -> FrameProcessor:
+ if not self._input:
+ self._input = WebsocketServerInputTransport(
+ self._host, self._port, self._params, self._callbacks)
+ return self._input
+
+ def output(self) -> FrameProcessor:
+ if not self._output:
+ self._output = WebsocketServerOutputTransport(self._params)
+ return self._output
+
+ async def _on_client_connected(self, websocket):
+ if self._output:
+ await self._output.set_client_connection(websocket)
+ await self._call_event_handler("on_client_connected", websocket)
+ else:
+ logger.error("A WebsocketServerTransport output is missing in the pipeline")
+
+ async def _on_client_disconnected(self, websocket):
+ if self._output:
+ await self._output.set_client_connection(None)
+ await self._call_event_handler("on_client_disconnected", websocket)
+ else:
+ logger.error("A WebsocketServerTransport output is missing in the pipeline")
diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py
index e4c0a9762..f20eaa4dc 100644
--- a/src/pipecat/transports/services/daily.py
+++ b/src/pipecat/transports/services/daily.py
@@ -6,15 +6,12 @@
import aiohttp
import asyncio
-from concurrent.futures import ThreadPoolExecutor
-import inspect
import queue
import time
-import types
from dataclasses import dataclass
-from functools import partial
from typing import Any, Callable, Mapping
+from concurrent.futures import ThreadPoolExecutor
from daily import (
CallClient,
@@ -139,7 +136,8 @@ def __init__(
token: str | None,
bot_name: str,
params: DailyParams,
- callbacks: DailyCallbacks):
+ callbacks: DailyCallbacks,
+ loop: asyncio.AbstractEventLoop):
super().__init__()
if not self._daily_initialized:
@@ -151,6 +149,7 @@ def __init__(
self._bot_name: str = bot_name
self._params: DailyParams = params
self._callbacks = callbacks
+ self._loop = loop
self._participant_id: str = ""
self._video_renderers = {}
@@ -189,15 +188,22 @@ def set_callbacks(self, callbacks: DailyCallbacks):
def send_message(self, frame: DailyTransportMessageFrame):
self._client.send_app_message(frame.message, frame.participant_id)
- def read_raw_audio_frames(self, frame_count: int) -> bytes:
+ def read_next_audio_frame(self) -> AudioRawFrame | None:
+ sample_rate = self._params.audio_in_sample_rate
+ num_channels = self._params.audio_in_channels
+
if self._other_participant_has_joined:
- return self._speaker.read_frames(frame_count)
+ num_frames = int(sample_rate / 100) # 10ms of audio
+
+ audio = self._speaker.read_frames(num_frames)
+
+ return AudioRawFrame(audio=audio, sample_rate=sample_rate, num_channels=num_channels)
else:
            # If no one has ever joined the meeting `read_frames()` would
            # block, so instead we just wait a bit. daily-python should
            # probably return silence instead.
time.sleep(0.01)
- return b''
+ return None
def write_raw_audio_frames(self, frames: bytes):
self._mic.write_frames(frames)
@@ -212,8 +218,7 @@ async def join(self):
self._joining = True
- loop = asyncio.get_running_loop()
- await loop.run_in_executor(self._executor, self._join)
+ await self._loop.run_in_executor(self._executor, self._join)
def _join(self):
logger.info(f"Joining {self._room_url}")
@@ -304,8 +309,7 @@ async def leave(self):
self._joined = False
self._leaving = True
- loop = asyncio.get_running_loop()
- await loop.run_in_executor(self._executor, self._leave)
+ await self._loop.run_in_executor(self._executor, self._leave)
def _leave(self):
logger.info(f"Leaving {self._room_url}")
@@ -335,8 +339,7 @@ def _handle_leave_response(self):
self._callbacks.on_error(error_msg)
async def cleanup(self):
- loop = asyncio.get_running_loop()
- await loop.run_in_executor(self._executor, self._cleanup)
+ await self._loop.run_in_executor(self._executor, self._cleanup)
def _cleanup(self):
if self._client:
@@ -471,7 +474,7 @@ def __init__(self, client: DailyTransportClient, params: DailyParams):
self._video_renderers = {}
self._camera_in_queue = queue.Queue()
- self._vad_analyzer = params.vad_analyzer
+ self._vad_analyzer: VADAnalyzer | None = params.vad_analyzer
if params.vad_enabled and not params.vad_analyzer:
self._vad_analyzer = WebRTCVADAnalyzer(
sample_rate=self._params.audio_in_sample_rate,
@@ -485,8 +488,7 @@ async def start(self, frame: StartFrame):
# This will set _running=True
await super().start(frame)
# Create camera in thread (runs if _running is true).
- loop = asyncio.get_running_loop()
- self._camera_in_thread = loop.run_in_executor(
+ self._camera_in_thread = self._loop.run_in_executor(
self._in_executor, self._camera_in_thread_handler)
async def stop(self):
@@ -503,14 +505,11 @@ async def cleanup(self):
await super().cleanup()
await self._client.cleanup()
- def vad_analyze(self, audio_frames: bytes) -> VADState:
- state = VADState.QUIET
- if self._vad_analyzer:
- state = self._vad_analyzer.analyze_audio(audio_frames)
- return state
+ def vad_analyzer(self) -> VADAnalyzer | None:
+ return self._vad_analyzer
- def read_raw_audio_frames(self, frame_count: int) -> bytes:
- return self._client.read_raw_audio_frames(frame_count)
+ def read_next_audio_frame(self) -> AudioRawFrame | None:
+ return self._client.read_next_audio_frame()
#
# FrameProcessor
@@ -642,7 +641,15 @@ def write_frame_to_camera(self, frame: ImageRawFrame):
class DailyTransport(BaseTransport):
- def __init__(self, room_url: str, token: str | None, bot_name: str, params: DailyParams):
+ def __init__(
+ self,
+ room_url: str,
+ token: str | None,
+ bot_name: str,
+ params: DailyParams,
+ loop: asyncio.AbstractEventLoop | None = None):
+ super().__init__(loop)
+
callbacks = DailyCallbacks(
on_joined=self._on_joined,
on_left=self._on_left,
@@ -660,12 +667,10 @@ def __init__(self, room_url: str, token: str | None, bot_name: str, params: Dail
)
self._params = params
- self._client = DailyTransportClient(room_url, token, bot_name, params, callbacks)
+ self._client = DailyTransportClient(
+ room_url, token, bot_name, params, callbacks, self._loop)
self._input: DailyInputTransport | None = None
self._output: DailyOutputTransport | None = None
- self._loop = asyncio.get_running_loop()
-
- self._event_handlers: dict = {}
# Register supported handlers. The user will only be able to register
# these handlers.
@@ -741,10 +746,10 @@ def capture_participant_video(
participant_id, framerate, video_source, color_format)
def _on_joined(self, participant):
- self.on_joined(participant)
+ self._call_async_event_handler("on_joined", participant)
def _on_left(self):
- self.on_left()
+ self._call_async_event_handler("on_left")
def _on_error(self, error):
# TODO(aleix): Report error to input/output transports. The one managing
@@ -754,10 +759,10 @@ def _on_error(self, error):
def _on_app_message(self, message: Any, sender: str):
if self._input:
self._input.push_app_message(message, sender)
- self.on_app_message(message, sender)
+ self._call_async_event_handler("on_app_message", message, sender)
def _on_call_state_updated(self, state: str):
- self.on_call_state_updated(state)
+ self._call_async_event_handler("on_call_state_updated", state)
async def _handle_dialin_ready(self, sip_endpoint: str):
if not self._params.dialin_settings:
@@ -793,28 +798,28 @@ async def _handle_dialin_ready(self, sip_endpoint: str):
def _on_dialin_ready(self, sip_endpoint):
if self._params.dialin_settings:
asyncio.run_coroutine_threadsafe(self._handle_dialin_ready(sip_endpoint), self._loop)
- self.on_dialin_ready(sip_endpoint)
+ self._call_async_event_handler("on_dialin_ready", sip_endpoint)
def _on_dialout_connected(self, data):
- self.on_dialout_connected(data)
+ self._call_async_event_handler("on_dialout_connected", data)
def _on_dialout_stopped(self, data):
- self.on_dialout_stopped(data)
+ self._call_async_event_handler("on_dialout_stopped", data)
def _on_dialout_error(self, data):
- self.on_dialout_error(data)
+ self._call_async_event_handler("on_dialout_error", data)
def _on_dialout_warning(self, data):
- self.on_dialout_warning(data)
+ self._call_async_event_handler("on_dialout_warning", data)
def _on_participant_joined(self, participant):
- self.on_participant_joined(participant)
+ self._call_async_event_handler("on_participant_joined", participant)
def _on_participant_left(self, participant, reason):
- self.on_participant_left(participant, reason)
+ self._call_async_event_handler("on_participant_left", participant, reason)
def _on_first_participant_joined(self, participant):
- self.on_first_participant_joined(participant)
+ self._call_async_event_handler("on_first_participant_joined", participant)
def _on_transcription_message(self, participant_id, message):
text = message["text"]
@@ -829,84 +834,7 @@ def _on_transcription_message(self, participant_id, message):
if self._input:
self._input.push_transcription_frame(frame)
- #
- # Decorators (event handlers)
- #
-
- def on_joined(self, participant):
- pass
-
- def on_left(self):
- pass
-
- def on_app_message(self, message, sender):
- pass
-
- def on_call_state_updated(self, state):
- pass
-
- def on_dialin_ready(self, sip_endpoint):
- pass
-
- def on_dialout_connected(self, data):
- pass
-
- def on_dialout_stopped(self, data):
- pass
-
- def on_dialout_error(self, data):
- pass
-
- def on_dialout_warning(self, data):
- pass
-
- def on_first_participant_joined(self, participant):
- pass
-
- def on_participant_joined(self, participant):
- pass
-
- def on_participant_left(self, participant, reason):
- pass
-
- def event_handler(self, event_name: str):
- def decorator(handler):
- self._add_event_handler(event_name, handler)
- return handler
- return decorator
-
- def _register_event_handler(self, event_name: str):
- methods = inspect.getmembers(self, predicate=inspect.ismethod)
- if event_name not in [method[0] for method in methods]:
- raise Exception(f"Event handler {event_name} not found")
-
- self._event_handlers[event_name] = [getattr(self, event_name)]
-
- patch_method = types.MethodType(partial(self._patch_method, event_name), self)
- setattr(self, event_name, patch_method)
-
- def _add_event_handler(self, event_name: str, handler):
- if event_name not in self._event_handlers:
- raise Exception(f"Event handler {event_name} not registered")
- self._event_handlers[event_name].append(types.MethodType(handler, self))
-
- def _patch_method(self, event_name, *args, **kwargs):
- try:
- for handler in self._event_handlers[event_name]:
- if inspect.iscoroutinefunction(handler):
- # Beware, if handler() calls another event handler it
- # will deadlock. You shouldn't do that anyways.
- future = asyncio.run_coroutine_threadsafe(
- handler(*args[1:], **kwargs), self._loop)
-
- # wait for the coroutine to finish. This will also
- # raise any exceptions raised by the coroutine.
- future.result()
- else:
- handler(*args[1:], **kwargs)
- except Exception as e:
- logger.error(f"Exception in event handler {event_name}: {e}")
- raise e
-
- # def start_recording(self):
- # self.client.start_recording()
+ def _call_async_event_handler(self, event_name: str, *args, **kwargs):
+ future = asyncio.run_coroutine_threadsafe(
+ self._call_event_handler(event_name, *args, **kwargs), self._loop)
+ future.result()