diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index e678623cb..652df99d2 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -22,11 +22,23 @@ jobs:
steps:
- name: Checkout repo
uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.10'
+ - name: Setup virtual environment
+ run: |
+ python -m venv .venv
+ - name: Install basic Python dependencies
+ run: |
+ source .venv/bin/activate
+ python -m pip install --upgrade pip
+ pip install -r requirements.txt
- name: autopep8
id: autopep8
- uses: peter-evans/autopep8@v2
- with:
- args: --exit-code -r -d -a -a src/
+        run: |
+          source .venv/bin/activate
+          set +e
+          autopep8 --exit-code -r -d --exclude "*_pb2.py" -a -a src/
+          echo "exit-code=$?" >> "$GITHUB_OUTPUT"
- name: Fail if autopep8 requires changes
if: steps.autopep8.outputs.exit-code == 2
run: exit 1
diff --git a/examples/foundational/websocket-server/index.html b/examples/foundational/websocket-server/index.html
new file mode 100644
index 000000000..51b409573
--- /dev/null
+++ b/examples/foundational/websocket-server/index.html
@@ -0,0 +1,142 @@
+<!DOCTYPE html>
+<html lang="en">
+
+<head>
+    <meta charset="UTF-8">
+    <title>WebSocket Audio Stream</title>
+</head>
+
+<body>
+    <h1>WebSocket Audio Stream</h1>
+    <!-- client script (websocket connection and audio playback) elided -->
+</body>
+
+</html>
diff --git a/examples/foundational/websocket-server/sample.py b/examples/foundational/websocket-server/sample.py
new file mode 100644
index 000000000..a8d10d996
--- /dev/null
+++ b/examples/foundational/websocket-server/sample.py
@@ -0,0 +1,70 @@
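+# Foundational websocket-server example: on each client connection, a greeting
+# is queued; ElevenLabs TTS turns it into audio and AudioChunker slices the
+# audio into fixed-size frames for the wire. See index.html for a client page.
+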
+import asyncio
+import aiohttp
+import logging
+import os
+import wave
+
+from dailyai.pipeline.frame_processor import FrameProcessor
+from dailyai.pipeline.frames import AudioFrame, EndFrame, EndPipeFrame, TextFrame
+from dailyai.pipeline.pipeline import Pipeline
+from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
+from dailyai.services.websocket_transport_service import WebsocketTransport
+from dailyai.services.whisper_ai_services import WhisperSTTService
+
+logging.basicConfig(format="%(levelno)s %(asctime)s %(message)s")
+logger = logging.getLogger("dailyai")
+logger.setLevel(logging.DEBUG)
+
+
+async def main():
+ async with aiohttp.ClientSession() as session:
+ transport = WebsocketTransport(
+ mic_enabled=True, speaker_enabled=True)
+
+ tts = ElevenLabsTTSService(
+ aiohttp_session=session,
+ api_key=os.getenv("ELEVENLABS_API_KEY"),
+ voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
+ )
+
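+        # Debug helper: dump audio to disk. Note that wave.open(..., "wb")
+        # truncates, so output.wav only ever holds the most recent frame.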
+ class WriteToWav(FrameProcessor):
+ async def process_frame(self, frame):
+ if isinstance(frame, AudioFrame):
+ with wave.open("output.wav", "wb") as f:
+ f.setnchannels(1)
+ f.setsampwidth(2)
+ f.setframerate(16000)
+ f.writeframes(frame.data)
+ yield frame
+
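+        # Re-slices incoming audio into fixed 1600-byte frames (50 ms of
+        # 16 kHz, 16-bit mono audio) and flushes the remainder on EndFrame.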
+ class AudioChunker(FrameProcessor):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._buffer = bytes()
+
+ async def process_frame(self, frame):
+ if isinstance(frame, AudioFrame):
+ self._buffer += frame.data
+ if len(self._buffer) >= 1600:
+ yield AudioFrame(self._buffer[:1600])
+ self._buffer = self._buffer[1600:]
+                elif isinstance(frame, (EndFrame, EndPipeFrame)):
+ if self._buffer:
+ yield AudioFrame(self._buffer)
+ self._buffer = bytes()
+ else:
+ yield frame
+
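+        # tts converts TextFrames into AudioFrames; AudioChunker normalizes
+        # their size before the transport sends them to the client.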
+ # pipeline = Pipeline([WriteToWav(), WhisperSTTService(), tts])
+ pipeline = Pipeline([tts, AudioChunker()])
+
+ @transport.on_connection
+ async def queue_frame():
+ await pipeline.queue_frames([TextFrame("Hello there!")])
+
+        await transport.run(pipeline)
+ del tts
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
index e1eac802b..a559076f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,8 @@ dependencies = [
"torch",
"torchaudio",
"pyaudio",
- "typing-extensions"
+ "typing-extensions",
+ "websockets"
]
[project.urls]
diff --git a/src/dailyai/pipeline/frames.proto b/src/dailyai/pipeline/frames.proto
new file mode 100644
index 000000000..7ecea6d25
--- /dev/null
+++ b/src/dailyai/pipeline/frames.proto
@@ -0,0 +1,25 @@
+syntax = "proto3";
+
+package dailyai_proto;
+
+message TextFrame {
+ string text = 1;
+}
+
+message AudioFrame {
+ bytes audio = 1;
+}
+
+message TranscriptionFrame {
+ string text = 1;
+ string participant_id = 2;
+ string timestamp = 3;
+}
+
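+// Wire envelope: exactly one payload is set per Frame, so different frame
+// types can be multiplexed over a single websocket connection.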
+message Frame {
+ oneof frame {
+ TextFrame text = 1;
+ AudioFrame audio = 2;
+ TranscriptionFrame transcription = 3;
+ }
+}
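+
+// The checked-in frames_pb2.py was generated from this file; to regenerate
+// (output path is an assumption, adjust to your layout):
+//   protoc --proto_path=src/dailyai/pipeline --python_out=src/dailyai/pipeline/protobufs frames.proto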
diff --git a/src/dailyai/pipeline/frames.py b/src/dailyai/pipeline/frames.py
index 6c39fb2c7..893a50aea 100644
--- a/src/dailyai/pipeline/frames.py
+++ b/src/dailyai/pipeline/frames.py
@@ -2,12 +2,20 @@
from typing import Any, List
from dailyai.services.openai_llm_context import OpenAILLMContext
+import dailyai.pipeline.protobufs.frames_pb2 as frame_protos
class Frame:
def __str__(self):
return f"{self.__class__.__name__}"
+ def to_proto(self):
+ raise NotImplementedError
+
+ @staticmethod
+ def from_proto(proto):
+ raise NotImplementedError
+
class ControlFrame(Frame):
# Control frames should contain no instance data, so
@@ -61,12 +69,36 @@ class LLMResponseEndFrame(ControlFrame):
@dataclass()
class AudioFrame(Frame):
"""A chunk of audio. Will be played by the transport if the transport's mic
- has been enabled."""
+ has been enabled.
+
+ >>> str(AudioFrame(data=b'1234567890'))
+ 'AudioFrame, size: 10 B'
+
+ >>> AudioFrame.from_proto(AudioFrame(data=b'1234567890').to_proto())
+ AudioFrame(data=b'1234567890')
+
+ The to_proto() function will always return a top-level Frame protobuf, so
+ it can be sent across the wire with other frames of various types.
+
+ >>> type(AudioFrame(data=b'1234567890').to_proto())
+    <class 'frames_pb2.Frame'>
+ """
data: bytes
def __str__(self):
return f"{self.__class__.__name__}, size: {len(self.data)} B"
+ def to_proto(self) -> frame_protos.Frame:
+ frame = frame_protos.Frame()
+ frame.audio.audio = self.data
+ return frame
+
+ @staticmethod
+ def from_proto(proto: frame_protos.Frame):
+ if proto.WhichOneof("frame") != "audio":
+ raise ValueError("Proto does not contain an audio frame")
+ return AudioFrame(data=proto.audio.audio)
+
@dataclass()
class ImageFrame(Frame):
@@ -93,20 +125,71 @@ def __str__(self):
@dataclass()
class TextFrame(Frame):
"""A chunk of text. Emitted by LLM services, consumed by TTS services, can
- be used to send text through pipelines."""
+ be used to send text through pipelines.
+
+ >>> str(TextFrame.from_proto(TextFrame(text='hello world').to_proto()))
+ 'TextFrame: "hello world"'
+
+ The to_proto() function will always return a top-level Frame protobuf, so
+ it can be sent across the wire with other frames of various types.
+
+ >>> type(TextFrame(text='hello world').to_proto())
+    <class 'frames_pb2.Frame'>
+ """
text: str
def __str__(self):
return f'{self.__class__.__name__}: "{self.text}"'
+ def to_proto(self) -> frame_protos.Frame:
+ proto_frame = frame_protos.Frame()
+ proto_frame.text.text = self.text
+ return proto_frame
+
+ @staticmethod
+    def from_proto(proto: frame_protos.Frame):
+ return TextFrame(text=proto.text.text)
+
@dataclass()
class TranscriptionQueueFrame(TextFrame):
"""A text frame with transcription-specific data. Will be placed in the
- transport's receive queue when a participant speaks."""
+ transport's receive queue when a participant speaks.
+
+ >>> transcription_frame = TranscriptionQueueFrame(text="Hello there!", participantId="123", timestamp="2021-01-01")
+ >>> transcription_frame
+ TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')
+
+ >>> TranscriptionQueueFrame.from_proto(transcription_frame.to_proto())
+ TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')
+
+ The to_proto() function will always return a top-level Frame protobuf, so
+ it can be sent across the wire with other frames of various types.
+
+ >>> type(transcription_frame.to_proto())
+    <class 'frames_pb2.Frame'>
+ """
participantId: str
timestamp: str
+ def __str__(self):
+ return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"
+
+ def to_proto(self) -> frame_protos.Frame:
+ frame = frame_protos.Frame()
+ frame.transcription.text = self.text
+ frame.transcription.participant_id = self.participantId
+ frame.transcription.timestamp = self.timestamp
+ return frame
+
+ @staticmethod
+ def from_proto(proto: frame_protos.Frame):
+ return TranscriptionQueueFrame(
+ text=proto.transcription.text,
+ participantId=proto.transcription.participant_id,
+ timestamp=proto.transcription.timestamp
+ )
+
@dataclass()
class LLMMessagesQueueFrame(Frame):
@@ -179,3 +262,46 @@ class LLMFunctionCallFrame(Frame):
"""Emitted when the LLM has received an entire function call completion."""
function_name: str
arguments: str
+
+
+def FrameFromProto(proto: frame_protos.Frame) -> Frame:
+ """Returns a Frame object from a Frame protobuf. Used to convert frames
+ passed over the wire as protobufs to Frame objects used in pipelines
+ and frame processors.
+
+ >>> FrameFromProto(AudioFrame(data=b'1234567890').to_proto())
+ AudioFrame(data=b'1234567890')
+
+ >>> FrameFromProto(TextFrame(text='hello world').to_proto())
+ TextFrame(text='hello world')
+
+ >>> FrameFromProto(TranscriptionQueueFrame(text="Hello there!", participantId="123", timestamp="2021-01-01").to_proto())
+ TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')
+ """
+ if proto.WhichOneof("frame") == "audio":
+ return AudioFrame.from_proto(proto)
+ elif proto.WhichOneof("frame") == "text":
+ return TextFrame.from_proto(proto)
+ elif proto.WhichOneof("frame") == "transcription":
+ return TranscriptionQueueFrame.from_proto(proto)
+ else:
+ raise ValueError(
+ "Proto does not contain a valid frame. You may need to add a new case to FrameFromProto.")
+
+
+if __name__ == "__main__":
+ audio_frame = AudioFrame(data=b'1234567890')
+ print(audio_frame)
+ print(audio_frame.to_proto().SerializeToString())
+ print(AudioFrame.from_proto(audio_frame.to_proto()))
+
+ text_frame = TextFrame(text="Hello there!")
+ print(text_frame)
+ print(text_frame.to_proto().SerializeToString())
+ print(TextFrame.from_proto(text_frame.to_proto()))
+
+    transcription_frame = TranscriptionQueueFrame(
+        text="Hello there!", participantId="123", timestamp="2021-01-01")
+    print(transcription_frame)
+    print(transcription_frame.to_proto().SerializeToString())
+    print(TranscriptionQueueFrame.from_proto(transcription_frame.to_proto()))
diff --git a/src/dailyai/pipeline/protobufs/frames_pb2.py b/src/dailyai/pipeline/protobufs/frames_pb2.py
new file mode 100644
index 000000000..b923f8fe3
--- /dev/null
+++ b/src/dailyai/pipeline/protobufs/frames_pb2.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: frames.proto
+# Protobuf Python Version: 4.25.3
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\rdailyai_proto\"\x19\n\tTextFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\"\x1b\n\nAudioFrame\x12\r\n\x05\x61udio\x18\x01 \x01(\x0c\"M\n\x12TranscriptionFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x16\n\x0eparticipant_id\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"\xa2\x01\n\x05\x46rame\x12(\n\x04text\x18\x01 \x01(\x0b\x32\x18.dailyai_proto.TextFrameH\x00\x12*\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x19.dailyai_proto.AudioFrameH\x00\x12:\n\rtranscription\x18\x03 \x01(\x0b\x32!.dailyai_proto.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_TEXTFRAME']._serialized_start = 31
+ _globals['_TEXTFRAME']._serialized_end = 56
+ _globals['_AUDIOFRAME']._serialized_start = 58
+ _globals['_AUDIOFRAME']._serialized_end = 85
+ _globals['_TRANSCRIPTIONFRAME']._serialized_start = 87
+ _globals['_TRANSCRIPTIONFRAME']._serialized_end = 164
+ _globals['_FRAME']._serialized_start = 167
+ _globals['_FRAME']._serialized_end = 329
+# @@protoc_insertion_point(module_scope)
diff --git a/src/dailyai/services/local_stt_service.py b/src/dailyai/services/local_stt_service.py
index 727bafbc6..98d37dc10 100644
--- a/src/dailyai/services/local_stt_service.py
+++ b/src/dailyai/services/local_stt_service.py
@@ -42,6 +42,7 @@ def _new_wave(self):
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
"""Processes a frame of audio data, either buffering or transcribing it."""
if not isinstance(frame, AudioFrame):
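+            # Pass non-audio frames through so downstream processors see them.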
+ yield frame
return
data = frame.data
diff --git a/src/dailyai/services/websocket_transport_service.py b/src/dailyai/services/websocket_transport_service.py
new file mode 100644
index 000000000..53bd40958
--- /dev/null
+++ b/src/dailyai/services/websocket_transport_service.py
@@ -0,0 +1,87 @@
+import asyncio
+import time
+import wave
+import websockets
+
+from dailyai.pipeline.frames import AudioFrame, EndFrame, FrameFromProto
+from dailyai.pipeline.pipeline import Pipeline
+from dailyai.services.base_transport_service import BaseTransportService
+import dailyai.pipeline.protobufs.frames_pb2 as frame_protos
+
+
+class WebsocketTransport(BaseTransportService):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self._sample_width = kwargs.get("sample_width") or 2
+ self._n_channels = kwargs.get("n_channels") or 1
+ self._port = kwargs.get("port") or 8765
+ self._host = kwargs.get("host") or "localhost"
+
+ if self._camera_enabled:
+ raise ValueError(
+ "Camera is not supported in WebsocketTransportService")
+
+ if self._speaker_enabled:
+ self._speaker_buffer_pending = bytearray()
+
+ self._server: websockets.WebSocketServer | None = None
+ self._websocket: websockets.WebSocketServerProtocol | None = None
+
+ self._connection_handlers = []
+
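+    # Wires the pipeline's sink (and optionally its source) to this
+    # transport's queues, then runs the websocket server, the expiration
+    # timer, and the outbound sender concurrently until stopped.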
+ async def run(self, pipeline: Pipeline, override_pipeline_source_queue=True):
+ self._stop_server_event = asyncio.Event()
+ pipeline.set_sink(self.send_queue)
+ if override_pipeline_source_queue:
+ pipeline.set_source(self.receive_queue)
+
+ async def timeout_task():
+ sleep_time = self._expiration - time.time()
+ await asyncio.sleep(sleep_time)
+ self._stop_server_event.set()
+
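+        # Drains the pipeline's output queue, serializing each frame to its
+        # protobuf form and pushing it to the connected client.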
+ async def send_task():
+ while not self._stop_server_event.is_set():
+ frame = await self.send_queue.get()
+ if isinstance(frame, EndFrame):
+ self._stop_server_event.set()
+ break
+ if self._websocket:
+ await self._websocket.send(frame.to_proto().SerializeToString())
+
+ async def start_server() -> None:
+ async with websockets.serve(
+ self._websocket_handler, self._host, self._port) as server:
+ print("Server started")
+ await self._stop_server_event.wait()
+ print("Server stopped")
+            await self.receive_queue.put(EndFrame())
+
+ await asyncio.gather(start_server(), timeout_task(), send_task(), pipeline.run_pipeline())
+
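+    # Registers a coroutine to run when a client connects. Returning the
+    # handler lets this double as a decorator (@transport.on_connection).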
+    def on_connection(self, handler):
+        self._connection_handlers.append(handler)
+        return handler
+
+ async def _websocket_handler(self, websocket: websockets.WebSocketServerProtocol, path):
+ if self._websocket:
+ await self._websocket.close()
+ self._logger.warning(
+ "Got another websocket connection; closing first.")
+
+ for handler in self._connection_handlers:
+ await handler()
+
+ self._websocket = websocket
+ async for message in websocket:
+ """
+ generic_frame = frame_protos.Frame.ParseFromString(message)
+ frame = FrameFromProto(generic_frame)
+ if isinstance(frame, AudioFrame):
+ with wave.open("output.wav", "wb") as f:
+ f.setnchannels(1)
+ f.setsampwidth(2)
+ f.setframerate(16000)
+ f.writeframes(frame.data)
+ await self.receive_queue.put(frame)
+ """
+ pass
diff --git a/src/dailyai/services/whisper_ai_services.py b/src/dailyai/services/whisper_ai_services.py
index cc657e6cf..2ed1c3563 100644
--- a/src/dailyai/services/whisper_ai_services.py
+++ b/src/dailyai/services/whisper_ai_services.py
@@ -52,4 +52,5 @@ async def run_stt(self, audio: BinaryIO) -> str:
res: str = ""
for segment in segments:
res += f"{segment.text} "
+ print("Transcription: ", segment.text)
return res