diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index e678623cb..da097d237 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -22,11 +22,23 @@ jobs: steps: - name: Checkout repo uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Setup virtual environment + run: | + python -m venv .venv + - name: Install basic Python dependencies + run: | + source .venv/bin/activate + python -m pip install --upgrade pip + pip install -r requirements.txt - name: autopep8 id: autopep8 - uses: peter-evans/autopep8@v2 - with: - args: --exit-code -r -d -a -a src/ + run: | + source .venv/bin/activate + autopep8 --max-line-length 100 --exit-code -r -d --exclude "*_pb2.py" -a -a src/ - name: Fail if autopep8 requires changes if: steps.autopep8.outputs.exit-code == 2 run: exit 1 diff --git a/README.md b/README.md index feed35dd6..21e38a31d 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ You can use [use-package](https://github.com/jwiegley/use-package) to install [p :defer t :hook ((python-mode . py-autopep8-mode)) :config - (setq py-autopep8-options '("-a" "-a"))) + (setq py-autopep8-options '("-a" "-a", "--max-line-length=100"))) ``` `autopep8` was installed in the `venv` environment described before, so you should be able to use [pyvenv-auto](https://github.com/ryotaro612/pyvenv-auto) to automatically load that environment inside Emacs. @@ -152,6 +152,7 @@ Install the }, "autopep8.args": [ "-a", - "-a" + "-a", + "--max-line-length=100" ], ``` diff --git a/examples/foundational/10-wake-word.py b/examples/foundational/10-wake-word.py index 00c3acdc6..cd82a59e3 100644 --- a/examples/foundational/10-wake-word.py +++ b/examples/foundational/10-wake-word.py @@ -172,7 +172,8 @@ async def handle_transcriptions(): isa.run( tma_out.run( llm.run( - tma_in.run(ncf.run(tf.run(transport.get_receive_frames()))) + tma_in.run( + ncf.run(tf.run(transport.get_receive_frames()))) ) ) ), diff --git a/examples/foundational/websocket-server/frames.proto b/examples/foundational/websocket-server/frames.proto new file mode 100644 index 000000000..7ecea6d25 --- /dev/null +++ b/examples/foundational/websocket-server/frames.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package dailyai_proto; + +message TextFrame { + string text = 1; +} + +message AudioFrame { + bytes audio = 1; +} + +message TranscriptionFrame { + string text = 1; + string participant_id = 2; + string timestamp = 3; +} + +message Frame { + oneof frame { + TextFrame text = 1; + AudioFrame audio = 2; + TranscriptionFrame transcription = 3; + } +} diff --git a/examples/foundational/websocket-server/index.html b/examples/foundational/websocket-server/index.html new file mode 100644 index 000000000..77be13518 --- /dev/null +++ b/examples/foundational/websocket-server/index.html @@ -0,0 +1,134 @@ + + + + + + + + WebSocket Audio Stream + + + +

WebSocket Audio Stream

+ + + + + + diff --git a/examples/foundational/websocket-server/sample.py b/examples/foundational/websocket-server/sample.py new file mode 100644 index 000000000..72390831e --- /dev/null +++ b/examples/foundational/websocket-server/sample.py @@ -0,0 +1,50 @@ +import asyncio +import aiohttp +import logging +import os +from dailyai.pipeline.frame_processor import FrameProcessor +from dailyai.pipeline.frames import TextFrame, TranscriptionQueueFrame +from dailyai.pipeline.pipeline import Pipeline +from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.websocket_transport_service import WebsocketTransport +from dailyai.services.whisper_ai_services import WhisperSTTService + +logging.basicConfig(format="%(levelno)s %(asctime)s %(message)s") +logger = logging.getLogger("dailyai") +logger.setLevel(logging.DEBUG) + + +class WhisperTranscriber(FrameProcessor): + async def process_frame(self, frame): + if isinstance(frame, TranscriptionQueueFrame): + print(f"Transcribed: {frame.text}") + else: + yield frame + + +async def main(): + async with aiohttp.ClientSession() as session: + transport = WebsocketTransport( + mic_enabled=True, + speaker_enabled=True, + ) + tts = ElevenLabsTTSService( + aiohttp_session=session, + api_key=os.getenv("ELEVENLABS_API_KEY"), + voice_id=os.getenv("ELEVENLABS_VOICE_ID"), + ) + + pipeline = Pipeline([ + WhisperSTTService(), + WhisperTranscriber(), + tts, + ]) + + @transport.on_connection + async def queue_frame(): + await pipeline.queue_frames([TextFrame("Hello there!")]) + + await transport.run(pipeline) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index e1eac802b..a559076f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,8 @@ dependencies = [ "torch", "torchaudio", "pyaudio", - "typing-extensions" + "typing-extensions", + "websockets" ] [project.urls] diff --git a/src/dailyai/pipeline/frame_processor.py b/src/dailyai/pipeline/frame_processor.py index 3d361b4d9..3f42b6987 100644 --- a/src/dailyai/pipeline/frame_processor.py +++ b/src/dailyai/pipeline/frame_processor.py @@ -23,8 +23,6 @@ async def process_frame( self, frame: Frame ) -> AsyncGenerator[Frame, None]: """Process a single frame and yield 0 or more frames.""" - if isinstance(frame, ControlFrame): - yield frame yield frame @abstractmethod diff --git a/src/dailyai/pipeline/frames.proto b/src/dailyai/pipeline/frames.proto new file mode 100644 index 000000000..b19fbccbf --- /dev/null +++ b/src/dailyai/pipeline/frames.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package dailyai_proto; + +message TextFrame { + string text = 1; +} + +message AudioFrame { + bytes data = 1; +} + +message TranscriptionFrame { + string text = 1; + string participantId = 2; + string timestamp = 3; +} + +message Frame { + oneof frame { + TextFrame text = 1; + AudioFrame audio = 2; + TranscriptionFrame transcription = 3; + } +} diff --git a/src/dailyai/pipeline/frames.py b/src/dailyai/pipeline/frames.py index 6c39fb2c7..c5fdcbbb4 100644 --- a/src/dailyai/pipeline/frames.py +++ b/src/dailyai/pipeline/frames.py @@ -2,6 +2,7 @@ from typing import Any, List from dailyai.services.openai_llm_context import OpenAILLMContext +import dailyai.pipeline.protobufs.frames_pb2 as frame_protos class Frame: @@ -107,6 +108,22 @@ class TranscriptionQueueFrame(TextFrame): participantId: str timestamp: str + def __str__(self): + return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}" + + +class TTSStartFrame(ControlFrame): + """Used to indicate the beginning of a TTS response. Following AudioFrames + are part of the TTS response until an TTEndFrame. These frames can be used + for aggregating audio frames in a transport to optimize the size of frames + sent to the session, without needing to control this in the TTS service.""" + pass + + +class TTSEndFrame(ControlFrame): + """Indicates the end of a TTS response.""" + pass + @dataclass() class LLMMessagesQueueFrame(Frame): diff --git a/src/dailyai/pipeline/pipeline.py b/src/dailyai/pipeline/pipeline.py index b055f3d81..fbb4488db 100644 --- a/src/dailyai/pipeline/pipeline.py +++ b/src/dailyai/pipeline/pipeline.py @@ -24,7 +24,7 @@ def __init__( queues. If this pipeline is run by a transport, its sink and source queues will be overridden. """ - self.processors: List[FrameProcessor] = processors + self._processors: List[FrameProcessor] = processors self.source: asyncio.Queue[Frame] = source or asyncio.Queue() self.sink: asyncio.Queue[Frame] = sink or asyncio.Queue() @@ -40,6 +40,9 @@ def set_sink(self, sink: asyncio.Queue[Frame]): has processed a frame, its output will be placed on this queue.""" self.sink = sink + def add_processor(self, processor: FrameProcessor): + self._processors.append(processor) + async def get_next_source_frame(self) -> AsyncGenerator[Frame, None]: """Convenience function to get the next frame from the source queue. This lets us consistently have an AsyncGenerator yield frames, from either the @@ -80,7 +83,7 @@ async def run_pipeline(self): while True: initial_frame = await self.source.get() async for frame in self._run_pipeline_recursively( - initial_frame, self.processors + initial_frame, self._processors ): await self.sink.put(frame) @@ -91,7 +94,7 @@ async def run_pipeline(self): except asyncio.CancelledError: # this means there's been an interruption, do any cleanup necessary # here. - for processor in self.processors: + for processor in self._processors: await processor.interrupted() pass diff --git a/src/dailyai/pipeline/protobufs/frames_pb2.py b/src/dailyai/pipeline/protobufs/frames_pb2.py new file mode 100644 index 000000000..ce71723d3 --- /dev/null +++ b/src/dailyai/pipeline/protobufs/frames_pb2.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: frames.proto +# Protobuf Python Version: 4.25.3 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\rdailyai_proto\"\x19\n\tTextFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\"\x1a\n\nAudioFrame\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"L\n\x12TranscriptionFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x15\n\rparticipantId\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"\xa2\x01\n\x05\x46rame\x12(\n\x04text\x18\x01 \x01(\x0b\x32\x18.dailyai_proto.TextFrameH\x00\x12*\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x19.dailyai_proto.AudioFrameH\x00\x12:\n\rtranscription\x18\x03 \x01(\x0b\x32!.dailyai_proto.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_TEXTFRAME']._serialized_start=31 + _globals['_TEXTFRAME']._serialized_end=56 + _globals['_AUDIOFRAME']._serialized_start=58 + _globals['_AUDIOFRAME']._serialized_end=84 + _globals['_TRANSCRIPTIONFRAME']._serialized_start=86 + _globals['_TRANSCRIPTIONFRAME']._serialized_end=162 + _globals['_FRAME']._serialized_start=165 + _globals['_FRAME']._serialized_end=327 +# @@protoc_insertion_point(module_scope) diff --git a/src/dailyai/serializers/abstract_frame_serializer.py b/src/dailyai/serializers/abstract_frame_serializer.py new file mode 100644 index 000000000..cf0831bcb --- /dev/null +++ b/src/dailyai/serializers/abstract_frame_serializer.py @@ -0,0 +1,16 @@ +from abc import abstractmethod + +from dailyai.pipeline.frames import Frame + + +class FrameSerializer: + def __init__(self): + pass + + @abstractmethod + def serialize(self, frame: Frame) -> bytes: + raise NotImplementedError() + + @abstractmethod + def deserialize(self, data: bytes) -> Frame: + raise NotImplementedError diff --git a/src/dailyai/serializers/protobuf_serializer.py b/src/dailyai/serializers/protobuf_serializer.py new file mode 100644 index 000000000..1b7e3ded0 --- /dev/null +++ b/src/dailyai/serializers/protobuf_serializer.py @@ -0,0 +1,64 @@ +import dataclasses +from typing import Text +from dailyai.pipeline.frames import AudioFrame, Frame, TextFrame, TranscriptionQueueFrame +import dailyai.pipeline.protobufs.frames_pb2 as frame_protos +from dailyai.serializers.abstract_frame_serializer import FrameSerializer + + +class ProtobufFrameSerializer(FrameSerializer): + SERIALIZABLE_TYPES = { + TextFrame: "text", + AudioFrame: "audio", + TranscriptionQueueFrame: "transcription" + } + + SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()} + + def __init__(self): + pass + + def serialize(self, frame: Frame) -> bytes: + proto_frame = frame_protos.Frame() + if type(frame) not in self.SERIALIZABLE_TYPES: + raise ValueError( + f"Frame type {type(frame)} is not serializable. You may need to add it to ProtobufFrameSerializer.SERIALIZABLE_FIELDS.") + + # ignoring linter errors; we check that type(frame) is in this dict above + proto_optional_name = self.SERIALIZABLE_TYPES[type(frame)] # type: ignore + for field in dataclasses.fields(frame): # type: ignore + setattr(getattr(proto_frame, proto_optional_name), field.name, + getattr(frame, field.name)) + + return proto_frame.SerializeToString() + + def deserialize(self, data: bytes) -> Frame: + """Returns a Frame object from a Frame protobuf. Used to convert frames + passed over the wire as protobufs to Frame objects used in pipelines + and frame processors. + + >>> serializer = ProtobufFrameSerializer() + >>> serializer.deserialize( + ... serializer.serialize(AudioFrame(data=b'1234567890'))) + AudioFrame(data=b'1234567890') + + >>> serializer.deserialize( + ... serializer.serialize(TextFrame(text='hello world'))) + TextFrame(text='hello world') + + >>> serializer.deserialize(serializer.serialize(TranscriptionQueueFrame( + ... text="Hello there!", participantId="123", timestamp="2021-01-01"))) + TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01') + """ + + proto = frame_protos.Frame.FromString(data) + which = proto.WhichOneof("frame") + if which not in self.SERIALIZABLE_FIELDS: + raise ValueError( + "Proto does not contain a valid frame. You may need to add a new case to ProtobufFrameSerializer.deserialize.") + + class_name = self.SERIALIZABLE_FIELDS[which] + args = getattr(proto, which) + args_dict = {} + for field in proto.DESCRIPTOR.fields_by_name[which].message_type.fields: + args_dict[field.name] = getattr(args, field.name) + return class_name(**args_dict) diff --git a/src/dailyai/services/ai_services.py b/src/dailyai/services/ai_services.py index 5db115e9a..b9f77a685 100644 --- a/src/dailyai/services/ai_services.py +++ b/src/dailyai/services/ai_services.py @@ -10,6 +10,8 @@ EndPipeFrame, ImageFrame, Frame, + TTSEndFrame, + TTSStartFrame, TextFrame, TranscriptionQueueFrame, ) @@ -47,12 +49,18 @@ async def run_tts(self, text) -> AsyncGenerator[bytes, None]: # yield empty bytes here, so linting can infer what this method does yield bytes() + async def wrap_tts(self, text) -> AsyncGenerator[Frame, None]: + yield TTSStartFrame() + async for audio_chunk in self.run_tts(text): + yield AudioFrame(audio_chunk) + yield TTSEndFrame() + yield TextFrame(text) + async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: if isinstance(frame, EndFrame) or isinstance(frame, EndPipeFrame): if self.current_sentence: - async for audio_chunk in self.run_tts(self.current_sentence): - yield AudioFrame(audio_chunk) - yield TextFrame(self.current_sentence) + async for cleanup_frame in self.wrap_tts(self.current_sentence): + yield cleanup_frame if not isinstance(frame, TextFrame): yield frame @@ -68,12 +76,8 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: self.current_sentence = "" if text: - async for audio_chunk in self.run_tts(text): - yield AudioFrame(audio_chunk) - - # note we pass along the text frame *after* the audio, so the text - # frame is completed after the audio is processed. - yield TextFrame(text) + async for frame in self.wrap_tts(text): + yield frame class ImageGenService(AIService): diff --git a/src/dailyai/services/local_stt_service.py b/src/dailyai/services/local_stt_service.py index 727bafbc6..98d37dc10 100644 --- a/src/dailyai/services/local_stt_service.py +++ b/src/dailyai/services/local_stt_service.py @@ -42,6 +42,7 @@ def _new_wave(self): async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: """Processes a frame of audio data, either buffering or transcribing it.""" if not isinstance(frame, AudioFrame): + yield frame return data = frame.data diff --git a/src/dailyai/services/websocket_transport_service.py b/src/dailyai/services/websocket_transport_service.py new file mode 100644 index 000000000..a9df7c07d --- /dev/null +++ b/src/dailyai/services/websocket_transport_service.py @@ -0,0 +1,117 @@ +import asyncio +import time +from typing import AsyncGenerator, List +import websockets + +from dailyai.pipeline.frame_processor import FrameProcessor +from dailyai.pipeline.frames import AudioFrame, ControlFrame, EndFrame, Frame, TTSEndFrame, TTSStartFrame, TextFrame +from dailyai.pipeline.pipeline import Pipeline +from dailyai.serializers.protobuf_serializer import ProtobufFrameSerializer +from dailyai.services.base_transport_service import BaseTransportService + + +class WebSocketFrameProcessor(FrameProcessor): + """This FrameProcessor filters and mutates frames before they're sent over the websocket. + This is necessary to aggregate audio frames into sizes that are cleanly playable by the client""" + + def __init__( + self, + audio_frame_size: int | None = None, + sendable_frames: List[Frame] | None = None): + super().__init__() + if not audio_frame_size: + raise ValueError("audio_frame_size must be provided") + + self._audio_frame_size = audio_frame_size + self._sendable_frames = sendable_frames or [TextFrame, AudioFrame] + self._audio_buffer = bytes() + self._in_tts_audio = False + + async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: + if isinstance(frame, TTSStartFrame): + self._in_tts_audio = True + elif isinstance(frame, AudioFrame): + if self._in_tts_audio: + self._audio_buffer += frame.data + while len(self._audio_buffer) >= self._audio_frame_size: + yield AudioFrame(self._audio_buffer[:self._audio_frame_size]) + self._audio_buffer = self._audio_buffer[self._audio_frame_size:] + elif isinstance(frame, TTSEndFrame): + self._in_tts_audio = False + if self._audio_buffer: + yield AudioFrame(self._audio_buffer) + self._audio_buffer = bytes() + elif type(frame) in self._sendable_frames or isinstance(frame, ControlFrame): + yield frame + + +class WebsocketTransport(BaseTransportService): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._sample_width = kwargs.get("sample_width", 2) + self._n_channels = kwargs.get("n_channels", 1) + self._port = kwargs.get("port", 8765) + self._host = kwargs.get("host", "localhost") + self._audio_frame_size = kwargs.get("audio_frame_size", 16000) + self._sendable_frames = kwargs.get( + "sendable_frames", [ + TextFrame, AudioFrame, TTSEndFrame, TTSStartFrame]) + self._serializer = kwargs.get("serializer", ProtobufFrameSerializer()) + + self._server: websockets.WebSocketServer | None = None + self._websocket: websockets.WebSocketServerProtocol | None = None + + self._connection_handlers = [] + + async def run(self, pipeline: Pipeline, override_pipeline_source_queue=True): + self._stop_server_event = asyncio.Event() + pipeline.set_sink(self.send_queue) + if override_pipeline_source_queue: + pipeline.set_source(self.receive_queue) + + pipeline.add_processor(WebSocketFrameProcessor( + audio_frame_size=self._audio_frame_size, + sendable_frames=self._sendable_frames)) + + async def timeout(): + sleep_time = self._expiration - time.time() + await asyncio.sleep(sleep_time) + self._stop_server_event.set() + + async def send_task(): + while not self._stop_server_event.is_set(): + frame = await self.send_queue.get() + if isinstance(frame, EndFrame): + self._stop_server_event.set() + break + if self._websocket and frame: + proto = self._serializer.serialize(frame) + await self._websocket.send(proto) + + async def start_server(): + async with websockets.serve(self._websocket_handler, self._host, self._port) as server: + self._logger.debug("Websocket server started.") + await self._stop_server_event.wait() + self._logger.debug("Websocket server stopped.") + await self.receive_queue.put(EndFrame()) + + timeout_task = asyncio.create_task(timeout()) + await asyncio.gather(start_server(), send_task(), pipeline.run_pipeline()) + timeout_task.cancel() + + def on_connection(self, handler): + self._connection_handlers.append(handler) + + async def _websocket_handler(self, websocket: websockets.WebSocketServerProtocol, path): + if self._websocket: + await self._websocket.close() + self._logger.warning( + "Got another websocket connection; closing first.") + + for handler in self._connection_handlers: + await handler() + + self._websocket = websocket + async for message in websocket: + frame = self._serializer.deserialize(message) + await self.receive_queue.put(frame) diff --git a/tests/test_protobuf_serializer.py b/tests/test_protobuf_serializer.py new file mode 100644 index 000000000..9d49fdff5 --- /dev/null +++ b/tests/test_protobuf_serializer.py @@ -0,0 +1,30 @@ +import unittest + +from dailyai.pipeline.frames import AudioFrame, TextFrame, TranscriptionQueueFrame +from dailyai.serializers.protobuf_serializer import ProtobufFrameSerializer + + +class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.serializer = ProtobufFrameSerializer() + + async def test_roundtrip(self): + text_frame = TextFrame(text='hello world') + frame = self.serializer.deserialize( + self.serializer.serialize(text_frame)) + self.assertEqual(frame, TextFrame(text='hello world')) + + transcription_frame = TranscriptionQueueFrame( + text="Hello there!", participantId="123", timestamp="2021-01-01") + frame = self.serializer.deserialize( + self.serializer.serialize(transcription_frame)) + self.assertEqual(frame, transcription_frame) + + audio_frame = AudioFrame(data=b'1234567890') + frame = self.serializer.deserialize( + self.serializer.serialize(audio_frame)) + self.assertEqual(frame, audio_frame) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_websocket_transport.py b/tests/test_websocket_transport.py new file mode 100644 index 000000000..4b0d3d9b3 --- /dev/null +++ b/tests/test_websocket_transport.py @@ -0,0 +1,113 @@ +import asyncio +import unittest +from unittest.mock import AsyncMock, patch, Mock + +from dailyai.pipeline.frames import AudioFrame, EndFrame, TextFrame, TTSEndFrame, TTSStartFrame +from dailyai.pipeline.pipeline import Pipeline +from dailyai.services.websocket_transport_service import WebSocketFrameProcessor, WebsocketTransport + + +class TestWebSocketTransportService(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.transport = WebsocketTransport(host="localhost", port=8765) + self.pipeline = Pipeline([]) + self.sample_frame = TextFrame("Hello there!") + self.serialized_sample_frame = self.transport._serializer.serialize( + self.sample_frame) + + async def queue_frame(self): + await asyncio.sleep(0.1) + await self.pipeline.queue_frames([self.sample_frame, EndFrame()]) + + async def test_websocket_handler(self): + mock_websocket = AsyncMock() + + with patch("websockets.serve", return_value=AsyncMock()) as mock_serve: + mock_serve.return_value.__anext__.return_value = ( + mock_websocket, "/") + + await self.transport._websocket_handler(mock_websocket, "/") + + await asyncio.gather(self.transport.run(self.pipeline), self.queue_frame()) + self.assertEqual(mock_websocket.send.call_count, 1) + + self.assertEqual( + mock_websocket.send.call_args[0][0], self.serialized_sample_frame) + + async def test_on_connection_decorator(self): + mock_websocket = AsyncMock() + + connection_handler_called = asyncio.Event() + + @self.transport.on_connection + async def connection_handler(): + connection_handler_called.set() + + with patch("websockets.serve", return_value=AsyncMock()): + await self.transport._websocket_handler(mock_websocket, "/") + + self.assertTrue(connection_handler_called.is_set()) + + async def test_frame_processor(self): + processor = WebSocketFrameProcessor(audio_frame_size=4) + + source_frames = [ + TTSStartFrame(), + AudioFrame(b"1234"), + AudioFrame(b"5678"), + TTSEndFrame(), + TextFrame("hello world") + ] + + frames = [] + for frame in source_frames: + async for output_frame in processor.process_frame(frame): + frames.append(output_frame) + + self.assertEqual(len(frames), 3) + self.assertIsInstance(frames[0], AudioFrame) + self.assertEqual(frames[0].data, b"1234") + self.assertIsInstance(frames[1], AudioFrame) + self.assertEqual(frames[1].data, b"5678") + self.assertIsInstance(frames[2], TextFrame) + self.assertEqual(frames[2].text, "hello world") + + async def test_serializer_parameter(self): + mock_websocket = AsyncMock() + + # Test with ProtobufFrameSerializer (default) + with patch("websockets.serve", return_value=AsyncMock()) as mock_serve: + mock_serve.return_value.__anext__.return_value = ( + mock_websocket, "/") + + await self.transport._websocket_handler(mock_websocket, "/") + + await asyncio.gather(self.transport.run(self.pipeline), self.queue_frame()) + self.assertEqual(mock_websocket.send.call_count, 1) + self.assertEqual( + mock_websocket.send.call_args[0][0], + self.serialized_sample_frame, + ) + + # Test with a mock serializer + mock_serializer = Mock() + mock_serializer.serialize.return_value = b"mock_serialized_data" + self.transport = WebsocketTransport( + host="localhost", port=8765, serializer=mock_serializer + ) + mock_websocket.reset_mock() + with patch("websockets.serve", return_value=AsyncMock()) as mock_serve: + mock_serve.return_value.__anext__.return_value = ( + mock_websocket, "/") + + await self.transport._websocket_handler(mock_websocket, "/") + await asyncio.gather(self.transport.run(self.pipeline), self.queue_frame()) + self.assertEqual(mock_websocket.send.call_count, 1) + self.assertEqual( + mock_websocket.send.call_args[0][0], b"mock_serialized_data") + mock_serializer.serialize.assert_called_once_with( + TextFrame("Hello there!")) + + +if __name__ == "__main__": + unittest.main()