diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index e678623cb..da097d237 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -22,11 +22,23 @@ jobs:
steps:
- name: Checkout repo
uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.10'
+ - name: Setup virtual environment
+ run: |
+ python -m venv .venv
+ - name: Install basic Python dependencies
+ run: |
+ source .venv/bin/activate
+ python -m pip install --upgrade pip
+ pip install -r requirements.txt
- name: autopep8
id: autopep8
- uses: peter-evans/autopep8@v2
- with:
- args: --exit-code -r -d -a -a src/
+ run: |
+   source .venv/bin/activate
+   set +e
+   autopep8 --max-line-length 100 --exit-code -r -d --exclude "*_pb2.py" -a -a src/
+   echo "exit-code=$?" >> "$GITHUB_OUTPUT"
- name: Fail if autopep8 requires changes
if: steps.autopep8.outputs.exit-code == 2
run: exit 1
diff --git a/README.md b/README.md
index feed35dd6..21e38a31d 100644
--- a/README.md
+++ b/README.md
@@ -127,7 +127,7 @@ You can use [use-package](https://github.com/jwiegley/use-package) to install [p
:defer t
:hook ((python-mode . py-autopep8-mode))
:config
- (setq py-autopep8-options '("-a" "-a")))
+ (setq py-autopep8-options '("-a" "-a" "--max-line-length=100")))
```
`autopep8` is installed in the `venv` environment described above, so you can use [pyvenv-auto](https://github.com/ryotaro612/pyvenv-auto) to load that environment automatically inside Emacs.
@@ -152,6 +152,7 @@ Install the
},
"autopep8.args": [
"-a",
- "-a"
+ "-a",
+ "--max-line-length=100"
],
```
diff --git a/examples/foundational/10-wake-word.py b/examples/foundational/10-wake-word.py
index 00c3acdc6..cd82a59e3 100644
--- a/examples/foundational/10-wake-word.py
+++ b/examples/foundational/10-wake-word.py
@@ -172,7 +172,8 @@ async def handle_transcriptions():
isa.run(
tma_out.run(
llm.run(
- tma_in.run(ncf.run(tf.run(transport.get_receive_frames())))
+ tma_in.run(
+ ncf.run(tf.run(transport.get_receive_frames())))
)
)
),
diff --git a/examples/foundational/websocket-server/frames.proto b/examples/foundational/websocket-server/frames.proto
new file mode 100644
index 000000000..7ecea6d25
--- /dev/null
+++ b/examples/foundational/websocket-server/frames.proto
@@ -0,0 +1,25 @@
+syntax = "proto3";
+
+package dailyai_proto;
+
+message TextFrame {
+ string text = 1;
+}
+
+message AudioFrame {
+ bytes data = 1;
+}
+
+message TranscriptionFrame {
+ string text = 1;
+ string participantId = 2;
+ string timestamp = 3;
+}
+
+message Frame {
+ oneof frame {
+ TextFrame text = 1;
+ AudioFrame audio = 2;
+ TranscriptionFrame transcription = 3;
+ }
+}
diff --git a/examples/foundational/websocket-server/index.html b/examples/foundational/websocket-server/index.html
new file mode 100644
index 000000000..77be13518
--- /dev/null
+++ b/examples/foundational/websocket-server/index.html
@@ -0,0 +1,134 @@
+<!-- "WebSocket Audio Stream" browser client page (markup and inline script not recoverable from this extraction) -->
diff --git a/examples/foundational/websocket-server/sample.py b/examples/foundational/websocket-server/sample.py
new file mode 100644
index 000000000..72390831e
--- /dev/null
+++ b/examples/foundational/websocket-server/sample.py
@@ -0,0 +1,50 @@
+import asyncio
+import aiohttp
+import logging
+import os
+from dailyai.pipeline.frame_processor import FrameProcessor
+from dailyai.pipeline.frames import TextFrame, TranscriptionQueueFrame
+from dailyai.pipeline.pipeline import Pipeline
+from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
+from dailyai.services.websocket_transport_service import WebsocketTransport
+from dailyai.services.whisper_ai_services import WhisperSTTService
+
+logging.basicConfig(format="%(levelno)s %(asctime)s %(message)s")
+logger = logging.getLogger("dailyai")
+logger.setLevel(logging.DEBUG)
+
+
+class WhisperTranscriber(FrameProcessor):
+ async def process_frame(self, frame):
+ if isinstance(frame, TranscriptionQueueFrame):
+ print(f"Transcribed: {frame.text}")
+ else:
+ yield frame
+
+
+async def main():
+ async with aiohttp.ClientSession() as session:
+ transport = WebsocketTransport(
+ mic_enabled=True,
+ speaker_enabled=True,
+ )
+ tts = ElevenLabsTTSService(
+ aiohttp_session=session,
+ api_key=os.getenv("ELEVENLABS_API_KEY"),
+ voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
+ )
+
+ pipeline = Pipeline([
+ WhisperSTTService(),
+ WhisperTranscriber(),
+ tts,
+ ])
+
+ @transport.on_connection
+ async def queue_frame():
+ await pipeline.queue_frames([TextFrame("Hello there!")])
+
+ await transport.run(pipeline)
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
index e1eac802b..a559076f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,8 @@ dependencies = [
"torch",
"torchaudio",
"pyaudio",
- "typing-extensions"
+ "typing-extensions",
+ "websockets"
]
[project.urls]
diff --git a/src/dailyai/pipeline/frame_processor.py b/src/dailyai/pipeline/frame_processor.py
index 3d361b4d9..3f42b6987 100644
--- a/src/dailyai/pipeline/frame_processor.py
+++ b/src/dailyai/pipeline/frame_processor.py
@@ -23,8 +23,6 @@ async def process_frame(
self, frame: Frame
) -> AsyncGenerator[Frame, None]:
"""Process a single frame and yield 0 or more frames."""
- if isinstance(frame, ControlFrame):
- yield frame
yield frame
@abstractmethod
diff --git a/src/dailyai/pipeline/frames.proto b/src/dailyai/pipeline/frames.proto
new file mode 100644
index 000000000..b19fbccbf
--- /dev/null
+++ b/src/dailyai/pipeline/frames.proto
@@ -0,0 +1,25 @@
+syntax = "proto3";
+
+package dailyai_proto;
+
+message TextFrame {
+ string text = 1;
+}
+
+message AudioFrame {
+ bytes data = 1;
+}
+
+message TranscriptionFrame {
+ string text = 1;
+ string participantId = 2;
+ string timestamp = 3;
+}
+
+message Frame {
+ oneof frame {
+ TextFrame text = 1;
+ AudioFrame audio = 2;
+ TranscriptionFrame transcription = 3;
+ }
+}
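The checked-in `frames_pb2.py` that appears later in this diff is generated from this proto. A minimal regeneration sketch, assuming `grpcio-tools` is installed and the command is run from the repo root (the paths mirror where the files live in this diff; the equivalent `protoc` CLI invocation works the same way):

```python
# Regenerate frames_pb2.py from frames.proto (assumes `pip install grpcio-tools`).
from grpc_tools import protoc

protoc.main([
    "protoc",
    "-Isrc/dailyai/pipeline",                        # directory containing frames.proto
    "--python_out=src/dailyai/pipeline/protobufs",   # where frames_pb2.py is checked in
    "src/dailyai/pipeline/frames.proto",
])
```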
diff --git a/src/dailyai/pipeline/frames.py b/src/dailyai/pipeline/frames.py
index 6c39fb2c7..c5fdcbbb4 100644
--- a/src/dailyai/pipeline/frames.py
+++ b/src/dailyai/pipeline/frames.py
@@ -2,6 +2,7 @@
from typing import Any, List
from dailyai.services.openai_llm_context import OpenAILLMContext
+import dailyai.pipeline.protobufs.frames_pb2 as frame_protos
class Frame:
@@ -107,6 +108,22 @@ class TranscriptionQueueFrame(TextFrame):
participantId: str
timestamp: str
+ def __str__(self):
+ return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"
+
+
+class TTSStartFrame(ControlFrame):
+ """Used to indicate the beginning of a TTS response. Following AudioFrames
+ are part of the TTS response until a TTSEndFrame. These frames can be used
+ for aggregating audio frames in a transport to optimize the size of frames
+ sent to the session, without needing to control this in the TTS service."""
+ pass
+
+
+class TTSEndFrame(ControlFrame):
+ """Indicates the end of a TTS response."""
+ pass
+
@dataclass()
class LLMMessagesQueueFrame(Frame):
diff --git a/src/dailyai/pipeline/pipeline.py b/src/dailyai/pipeline/pipeline.py
index b055f3d81..fbb4488db 100644
--- a/src/dailyai/pipeline/pipeline.py
+++ b/src/dailyai/pipeline/pipeline.py
@@ -24,7 +24,7 @@ def __init__(
queues. If this pipeline is run by a transport, its sink and source queues
will be overridden.
"""
- self.processors: List[FrameProcessor] = processors
+ self._processors: List[FrameProcessor] = processors
self.source: asyncio.Queue[Frame] = source or asyncio.Queue()
self.sink: asyncio.Queue[Frame] = sink or asyncio.Queue()
@@ -40,6 +40,9 @@ def set_sink(self, sink: asyncio.Queue[Frame]):
has processed a frame, its output will be placed on this queue."""
self.sink = sink
+ def add_processor(self, processor: FrameProcessor):
+ self._processors.append(processor)
+
async def get_next_source_frame(self) -> AsyncGenerator[Frame, None]:
"""Convenience function to get the next frame from the source queue. This
lets us consistently have an AsyncGenerator yield frames, from either the
@@ -80,7 +83,7 @@ async def run_pipeline(self):
while True:
initial_frame = await self.source.get()
async for frame in self._run_pipeline_recursively(
- initial_frame, self.processors
+ initial_frame, self._processors
):
await self.sink.put(frame)
@@ -91,7 +94,7 @@ async def run_pipeline(self):
except asyncio.CancelledError:
# this means there's been an interruption, do any cleanup necessary
# here.
- for processor in self.processors:
+ for processor in self._processors:
await processor.interrupted()
pass
diff --git a/src/dailyai/pipeline/protobufs/frames_pb2.py b/src/dailyai/pipeline/protobufs/frames_pb2.py
new file mode 100644
index 000000000..ce71723d3
--- /dev/null
+++ b/src/dailyai/pipeline/protobufs/frames_pb2.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: frames.proto
+# Protobuf Python Version: 4.25.3
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\rdailyai_proto\"\x19\n\tTextFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\"\x1a\n\nAudioFrame\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"L\n\x12TranscriptionFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x15\n\rparticipantId\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"\xa2\x01\n\x05\x46rame\x12(\n\x04text\x18\x01 \x01(\x0b\x32\x18.dailyai_proto.TextFrameH\x00\x12*\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x19.dailyai_proto.AudioFrameH\x00\x12:\n\rtranscription\x18\x03 \x01(\x0b\x32!.dailyai_proto.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_TEXTFRAME']._serialized_start=31
+ _globals['_TEXTFRAME']._serialized_end=56
+ _globals['_AUDIOFRAME']._serialized_start=58
+ _globals['_AUDIOFRAME']._serialized_end=84
+ _globals['_TRANSCRIPTIONFRAME']._serialized_start=86
+ _globals['_TRANSCRIPTIONFRAME']._serialized_end=162
+ _globals['_FRAME']._serialized_start=165
+ _globals['_FRAME']._serialized_end=327
+# @@protoc_insertion_point(module_scope)
diff --git a/src/dailyai/serializers/abstract_frame_serializer.py b/src/dailyai/serializers/abstract_frame_serializer.py
new file mode 100644
index 000000000..cf0831bcb
--- /dev/null
+++ b/src/dailyai/serializers/abstract_frame_serializer.py
@@ -0,0 +1,16 @@
+from abc import ABC, abstractmethod
+
+from dailyai.pipeline.frames import Frame
+
+
+class FrameSerializer(ABC):
+ def __init__(self):
+ pass
+
+ @abstractmethod
+ def serialize(self, frame: Frame) -> bytes:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def deserialize(self, data: bytes) -> Frame:
+ raise NotImplementedError()
diff --git a/src/dailyai/serializers/protobuf_serializer.py b/src/dailyai/serializers/protobuf_serializer.py
new file mode 100644
index 000000000..1b7e3ded0
--- /dev/null
+++ b/src/dailyai/serializers/protobuf_serializer.py
@@ -0,0 +1,64 @@
+import dataclasses
+from dailyai.pipeline.frames import AudioFrame, Frame, TextFrame, TranscriptionQueueFrame
+import dailyai.pipeline.protobufs.frames_pb2 as frame_protos
+from dailyai.serializers.abstract_frame_serializer import FrameSerializer
+
+
+class ProtobufFrameSerializer(FrameSerializer):
+ SERIALIZABLE_TYPES = {
+ TextFrame: "text",
+ AudioFrame: "audio",
+ TranscriptionQueueFrame: "transcription"
+ }
+
+ SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()}
+
+ def __init__(self):
+ pass
+
+ def serialize(self, frame: Frame) -> bytes:
+ proto_frame = frame_protos.Frame()
+ if type(frame) not in self.SERIALIZABLE_TYPES:
+ raise ValueError(
+ f"Frame type {type(frame)} is not serializable. You may need to add it to ProtobufFrameSerializer.SERIALIZABLE_FIELDS.")
+
+ # ignoring linter errors; we check that type(frame) is in this dict above
+ proto_optional_name = self.SERIALIZABLE_TYPES[type(frame)] # type: ignore
+ for field in dataclasses.fields(frame): # type: ignore
+ setattr(getattr(proto_frame, proto_optional_name), field.name,
+ getattr(frame, field.name))
+
+ return proto_frame.SerializeToString()
+
+ def deserialize(self, data: bytes) -> Frame:
+ """Returns a Frame object from a Frame protobuf. Used to convert frames
+ passed over the wire as protobufs to Frame objects used in pipelines
+ and frame processors.
+
+ >>> serializer = ProtobufFrameSerializer()
+ >>> serializer.deserialize(
+ ... serializer.serialize(AudioFrame(data=b'1234567890')))
+ AudioFrame(data=b'1234567890')
+
+ >>> serializer.deserialize(
+ ... serializer.serialize(TextFrame(text='hello world')))
+ TextFrame(text='hello world')
+
+ >>> serializer.deserialize(serializer.serialize(TranscriptionQueueFrame(
+ ... text="Hello there!", participantId="123", timestamp="2021-01-01")))
+ TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')
+ """
+
+ proto = frame_protos.Frame.FromString(data)
+ which = proto.WhichOneof("frame")
+ if which not in self.SERIALIZABLE_FIELDS:
+ raise ValueError(
+ "Proto does not contain a valid frame. You may need to add a new case to ProtobufFrameSerializer.deserialize.")
+
+ class_name = self.SERIALIZABLE_FIELDS[which]
+ args = getattr(proto, which)
+ args_dict = {}
+ for field in proto.DESCRIPTOR.fields_by_name[which].message_type.fields:
+ args_dict[field.name] = getattr(args, field.name)
+ return class_name(**args_dict)
diff --git a/src/dailyai/services/ai_services.py b/src/dailyai/services/ai_services.py
index 5db115e9a..b9f77a685 100644
--- a/src/dailyai/services/ai_services.py
+++ b/src/dailyai/services/ai_services.py
@@ -10,6 +10,8 @@
EndPipeFrame,
ImageFrame,
Frame,
+ TTSEndFrame,
+ TTSStartFrame,
TextFrame,
TranscriptionQueueFrame,
)
@@ -47,12 +49,18 @@ async def run_tts(self, text) -> AsyncGenerator[bytes, None]:
# yield empty bytes here, so linting can infer what this method does
yield bytes()
+ async def wrap_tts(self, text) -> AsyncGenerator[Frame, None]:
+ yield TTSStartFrame()
+ async for audio_chunk in self.run_tts(text):
+ yield AudioFrame(audio_chunk)
+ yield TTSEndFrame()
+ yield TextFrame(text)
+
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
if isinstance(frame, EndFrame) or isinstance(frame, EndPipeFrame):
if self.current_sentence:
- async for audio_chunk in self.run_tts(self.current_sentence):
- yield AudioFrame(audio_chunk)
- yield TextFrame(self.current_sentence)
+ async for cleanup_frame in self.wrap_tts(self.current_sentence):
+ yield cleanup_frame
if not isinstance(frame, TextFrame):
yield frame
@@ -68,12 +76,8 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
self.current_sentence = ""
if text:
- async for audio_chunk in self.run_tts(text):
- yield AudioFrame(audio_chunk)
-
- # note we pass along the text frame *after* the audio, so the text
- # frame is completed after the audio is processed.
- yield TextFrame(text)
+ async for frame in self.wrap_tts(text):
+ yield frame
class ImageGenService(AIService):
diff --git a/src/dailyai/services/local_stt_service.py b/src/dailyai/services/local_stt_service.py
index 727bafbc6..98d37dc10 100644
--- a/src/dailyai/services/local_stt_service.py
+++ b/src/dailyai/services/local_stt_service.py
@@ -42,6 +42,7 @@ def _new_wave(self):
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
"""Processes a frame of audio data, either buffering or transcribing it."""
if not isinstance(frame, AudioFrame):
+ yield frame
return
data = frame.data
diff --git a/src/dailyai/services/websocket_transport_service.py b/src/dailyai/services/websocket_transport_service.py
new file mode 100644
index 000000000..a9df7c07d
--- /dev/null
+++ b/src/dailyai/services/websocket_transport_service.py
@@ -0,0 +1,117 @@
+import asyncio
+import time
+from typing import AsyncGenerator, List
+import websockets
+
+from dailyai.pipeline.frame_processor import FrameProcessor
+from dailyai.pipeline.frames import AudioFrame, ControlFrame, EndFrame, Frame, TTSEndFrame, TTSStartFrame, TextFrame
+from dailyai.pipeline.pipeline import Pipeline
+from dailyai.serializers.protobuf_serializer import ProtobufFrameSerializer
+from dailyai.services.base_transport_service import BaseTransportService
+
+
+class WebSocketFrameProcessor(FrameProcessor):
+ """This FrameProcessor filters and mutates frames before they're sent over the websocket.
+ This is necessary to aggregate audio frames into sizes that are cleanly playable by the client."""
+
+ def __init__(
+ self,
+ audio_frame_size: int | None = None,
+ sendable_frames: List[Frame] | None = None):
+ super().__init__()
+ if not audio_frame_size:
+ raise ValueError("audio_frame_size must be provided")
+
+ self._audio_frame_size = audio_frame_size
+ self._sendable_frames = sendable_frames or [TextFrame, AudioFrame]
+ self._audio_buffer = bytes()
+ self._in_tts_audio = False
+
+ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
+ if isinstance(frame, TTSStartFrame):
+ self._in_tts_audio = True
+ elif isinstance(frame, AudioFrame):
+ if self._in_tts_audio:
+ self._audio_buffer += frame.data
+ while len(self._audio_buffer) >= self._audio_frame_size:
+ yield AudioFrame(self._audio_buffer[:self._audio_frame_size])
+ self._audio_buffer = self._audio_buffer[self._audio_frame_size:]
+ elif isinstance(frame, TTSEndFrame):
+ self._in_tts_audio = False
+ if self._audio_buffer:
+ yield AudioFrame(self._audio_buffer)
+ self._audio_buffer = bytes()
+ elif type(frame) in self._sendable_frames or isinstance(frame, ControlFrame):
+ yield frame
+
+
+class WebsocketTransport(BaseTransportService):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self._sample_width = kwargs.get("sample_width", 2)
+ self._n_channels = kwargs.get("n_channels", 1)
+ self._port = kwargs.get("port", 8765)
+ self._host = kwargs.get("host", "localhost")
+ self._audio_frame_size = kwargs.get("audio_frame_size", 16000)
+ self._sendable_frames = kwargs.get(
+ "sendable_frames", [
+ TextFrame, AudioFrame, TTSEndFrame, TTSStartFrame])
+ self._serializer = kwargs.get("serializer", ProtobufFrameSerializer())
+
+ self._server: websockets.WebSocketServer | None = None
+ self._websocket: websockets.WebSocketServerProtocol | None = None
+
+ self._connection_handlers = []
+
+ async def run(self, pipeline: Pipeline, override_pipeline_source_queue=True):
+ self._stop_server_event = asyncio.Event()
+ pipeline.set_sink(self.send_queue)
+ if override_pipeline_source_queue:
+ pipeline.set_source(self.receive_queue)
+
+ pipeline.add_processor(WebSocketFrameProcessor(
+ audio_frame_size=self._audio_frame_size,
+ sendable_frames=self._sendable_frames))
+
+ async def timeout():
+ sleep_time = self._expiration - time.time()
+ await asyncio.sleep(sleep_time)
+ self._stop_server_event.set()
+
+ async def send_task():
+ while not self._stop_server_event.is_set():
+ frame = await self.send_queue.get()
+ if isinstance(frame, EndFrame):
+ self._stop_server_event.set()
+ break
+ if self._websocket and frame:
+ proto = self._serializer.serialize(frame)
+ await self._websocket.send(proto)
+
+ async def start_server():
+ async with websockets.serve(self._websocket_handler, self._host, self._port) as server:
+ self._logger.debug("Websocket server started.")
+ await self._stop_server_event.wait()
+ self._logger.debug("Websocket server stopped.")
+ await self.receive_queue.put(EndFrame())
+
+ timeout_task = asyncio.create_task(timeout())
+ await asyncio.gather(start_server(), send_task(), pipeline.run_pipeline())
+ timeout_task.cancel()
+
+ def on_connection(self, handler):
+ self._connection_handlers.append(handler)
+
+ async def _websocket_handler(self, websocket: websockets.WebSocketServerProtocol, path):
+ if self._websocket:
+ await self._websocket.close()
+ self._logger.warning(
+ "Got another websocket connection; closing first.")
+
+ for handler in self._connection_handlers:
+ await handler()
+
+ self._websocket = websocket
+ async for message in websocket:
+ frame = self._serializer.deserialize(message)
+ await self.receive_queue.put(frame)
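For reference, a minimal Python client sketch for `WebsocketTransport`. The host, port, and frame handling here are illustrative: they match the defaults above, and the audio payload is 16-bit mono PCM per the `sample_width`/`n_channels` defaults, at whatever rate the TTS service produces.

```python
# Minimal client sketch (assumes a WebsocketTransport server is running on the defaults).
import asyncio

import websockets

import dailyai.pipeline.protobufs.frames_pb2 as frame_protos


async def main():
    # Defaults match WebsocketTransport: host="localhost", port=8765.
    async with websockets.connect("ws://localhost:8765") as ws:
        async for message in ws:
            frame = frame_protos.Frame.FromString(message)
            which = frame.WhichOneof("frame")
            if which == "text":
                print(f"text: {frame.text.text}")
            elif which == "audio":
                # Raw PCM; a real client would hand this to an audio player.
                print(f"audio: {len(frame.audio.data)} bytes")

asyncio.run(main())
```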
diff --git a/tests/test_protobuf_serializer.py b/tests/test_protobuf_serializer.py
new file mode 100644
index 000000000..9d49fdff5
--- /dev/null
+++ b/tests/test_protobuf_serializer.py
@@ -0,0 +1,30 @@
+import unittest
+
+from dailyai.pipeline.frames import AudioFrame, TextFrame, TranscriptionQueueFrame
+from dailyai.serializers.protobuf_serializer import ProtobufFrameSerializer
+
+
+class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ self.serializer = ProtobufFrameSerializer()
+
+ async def test_roundtrip(self):
+ text_frame = TextFrame(text='hello world')
+ frame = self.serializer.deserialize(
+ self.serializer.serialize(text_frame))
+ self.assertEqual(frame, TextFrame(text='hello world'))
+
+ transcription_frame = TranscriptionQueueFrame(
+ text="Hello there!", participantId="123", timestamp="2021-01-01")
+ frame = self.serializer.deserialize(
+ self.serializer.serialize(transcription_frame))
+ self.assertEqual(frame, transcription_frame)
+
+ audio_frame = AudioFrame(data=b'1234567890')
+ frame = self.serializer.deserialize(
+ self.serializer.serialize(audio_frame))
+ self.assertEqual(frame, audio_frame)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_websocket_transport.py b/tests/test_websocket_transport.py
new file mode 100644
index 000000000..4b0d3d9b3
--- /dev/null
+++ b/tests/test_websocket_transport.py
@@ -0,0 +1,113 @@
+import asyncio
+import unittest
+from unittest.mock import AsyncMock, patch, Mock
+
+from dailyai.pipeline.frames import AudioFrame, EndFrame, TextFrame, TTSEndFrame, TTSStartFrame
+from dailyai.pipeline.pipeline import Pipeline
+from dailyai.services.websocket_transport_service import WebSocketFrameProcessor, WebsocketTransport
+
+
+class TestWebSocketTransportService(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ self.transport = WebsocketTransport(host="localhost", port=8765)
+ self.pipeline = Pipeline([])
+ self.sample_frame = TextFrame("Hello there!")
+ self.serialized_sample_frame = self.transport._serializer.serialize(
+ self.sample_frame)
+
+ async def queue_frame(self):
+ await asyncio.sleep(0.1)
+ await self.pipeline.queue_frames([self.sample_frame, EndFrame()])
+
+ async def test_websocket_handler(self):
+ mock_websocket = AsyncMock()
+
+ with patch("websockets.serve", return_value=AsyncMock()) as mock_serve:
+ mock_serve.return_value.__anext__.return_value = (
+ mock_websocket, "/")
+
+ await self.transport._websocket_handler(mock_websocket, "/")
+
+ await asyncio.gather(self.transport.run(self.pipeline), self.queue_frame())
+ self.assertEqual(mock_websocket.send.call_count, 1)
+
+ self.assertEqual(
+ mock_websocket.send.call_args[0][0], self.serialized_sample_frame)
+
+ async def test_on_connection_decorator(self):
+ mock_websocket = AsyncMock()
+
+ connection_handler_called = asyncio.Event()
+
+ @self.transport.on_connection
+ async def connection_handler():
+ connection_handler_called.set()
+
+ with patch("websockets.serve", return_value=AsyncMock()):
+ await self.transport._websocket_handler(mock_websocket, "/")
+
+ self.assertTrue(connection_handler_called.is_set())
+
+ async def test_frame_processor(self):
+ processor = WebSocketFrameProcessor(audio_frame_size=4)
+
+ source_frames = [
+ TTSStartFrame(),
+ AudioFrame(b"1234"),
+ AudioFrame(b"5678"),
+ TTSEndFrame(),
+ TextFrame("hello world")
+ ]
+
+ frames = []
+ for frame in source_frames:
+ async for output_frame in processor.process_frame(frame):
+ frames.append(output_frame)
+
+ self.assertEqual(len(frames), 3)
+ self.assertIsInstance(frames[0], AudioFrame)
+ self.assertEqual(frames[0].data, b"1234")
+ self.assertIsInstance(frames[1], AudioFrame)
+ self.assertEqual(frames[1].data, b"5678")
+ self.assertIsInstance(frames[2], TextFrame)
+ self.assertEqual(frames[2].text, "hello world")
+
+ async def test_serializer_parameter(self):
+ mock_websocket = AsyncMock()
+
+ # Test with ProtobufFrameSerializer (default)
+ with patch("websockets.serve", return_value=AsyncMock()) as mock_serve:
+ mock_serve.return_value.__anext__.return_value = (
+ mock_websocket, "/")
+
+ await self.transport._websocket_handler(mock_websocket, "/")
+
+ await asyncio.gather(self.transport.run(self.pipeline), self.queue_frame())
+ self.assertEqual(mock_websocket.send.call_count, 1)
+ self.assertEqual(
+ mock_websocket.send.call_args[0][0],
+ self.serialized_sample_frame,
+ )
+
+ # Test with a mock serializer
+ mock_serializer = Mock()
+ mock_serializer.serialize.return_value = b"mock_serialized_data"
+ self.transport = WebsocketTransport(
+ host="localhost", port=8765, serializer=mock_serializer
+ )
+ mock_websocket.reset_mock()
+ with patch("websockets.serve", return_value=AsyncMock()) as mock_serve:
+ mock_serve.return_value.__anext__.return_value = (
+ mock_websocket, "/")
+
+ await self.transport._websocket_handler(mock_websocket, "/")
+ await asyncio.gather(self.transport.run(self.pipeline), self.queue_frame())
+ self.assertEqual(mock_websocket.send.call_count, 1)
+ self.assertEqual(
+ mock_websocket.send.call_args[0][0], b"mock_serialized_data")
+ mock_serializer.serialize.assert_called_once_with(
+ TextFrame("Hello there!"))
+
+
+if __name__ == "__main__":
+ unittest.main()