diff --git a/examples/foundational/10-wake-word.py b/examples/foundational/10-wake-word.py index 00c3acdc6..cd82a59e3 100644 --- a/examples/foundational/10-wake-word.py +++ b/examples/foundational/10-wake-word.py @@ -172,7 +172,8 @@ async def handle_transcriptions(): isa.run( tma_out.run( llm.run( - tma_in.run(ncf.run(tf.run(transport.get_receive_frames()))) + tma_in.run( + ncf.run(tf.run(transport.get_receive_frames()))) ) ) ), diff --git a/examples/foundational/websocket-server/index.html b/examples/foundational/websocket-server/index.html index eafb4d39b..77be13518 100644 --- a/examples/foundational/websocket-server/index.html +++ b/examples/foundational/websocket-server/index.html @@ -49,10 +49,10 @@

WebSocket Audio Stream

const parsedFrame = frame.decode(new Uint8Array(arrayBuffer)); if (!parsedFrame?.audio) return false; - const frameCount = parsedFrame.audio.audio.length / 2; + const frameCount = parsedFrame.audio.data.length / 2; const audioOutBuffer = audioContext.createBuffer(1, frameCount, SAMPLE_RATE); const nowBuffering = audioOutBuffer.getChannelData(0); - const view = new Int16Array(parsedFrame.audio.audio.buffer); + const view = new Int16Array(parsedFrame.audio.data.buffer); for (let i = 0; i < frameCount; i++) { const word = view[i]; diff --git a/src/dailyai/pipeline/frames.proto b/src/dailyai/pipeline/frames.proto index 7ecea6d25..b19fbccbf 100644 --- a/src/dailyai/pipeline/frames.proto +++ b/src/dailyai/pipeline/frames.proto @@ -7,12 +7,12 @@ message TextFrame { } message AudioFrame { - bytes audio = 1; + bytes data = 1; } message TranscriptionFrame { string text = 1; - string participant_id = 2; + string participantId = 2; string timestamp = 3; } diff --git a/src/dailyai/pipeline/frames.py b/src/dailyai/pipeline/frames.py index 6f9738cd4..c5fdcbbb4 100644 --- a/src/dailyai/pipeline/frames.py +++ b/src/dailyai/pipeline/frames.py @@ -9,21 +9,6 @@ class Frame: def __str__(self): return f"{self.__class__.__name__}" - def to_proto(self): - """Converts a Frame object to a Frame protobuf. Used to send frames via the - websocket transport, or other clients that send Frames to their clients. - If you see this exception, you may need to implement a to_proto method on your - Frame subclass.""" - raise NotImplementedError - - @staticmethod - def from_proto(proto): - """Converts a Frame protobuf to a Frame object. Used to receive frames via the - websocket transport, or other clients that receive Frames from their clients. 
- If you see this exception, you may need to implement a from_proto method on your - Frame subclass, and add your class to FrameFromProto below.""" - raise NotImplementedError - class ControlFrame(Frame): # Control frames should contain no instance data, so @@ -77,36 +62,12 @@ class LLMResponseEndFrame(ControlFrame): @dataclass() class AudioFrame(Frame): """A chunk of audio. Will be played by the transport if the transport's mic - has been enabled. - - >>> str(AudioFrame(data=b'1234567890')) - 'AudioFrame, size: 10 B' - - >>> AudioFrame.from_proto(AudioFrame(data=b'1234567890').to_proto()) - AudioFrame(data=b'1234567890') - - The to_proto() function will always return a top-level Frame protobuf, so - it can be sent across the wire with other frames of various types. - - >>> type(AudioFrame(data=b'1234567890').to_proto()) - - """ + has been enabled.""" data: bytes def __str__(self): return f"{self.__class__.__name__}, size: {len(self.data)} B" - def to_proto(self) -> frame_protos.Frame: - frame = frame_protos.Frame() - frame.audio.audio = self.data - return frame - - @staticmethod - def from_proto(proto: frame_protos.Frame): - if proto.WhichOneof("frame") != "audio": - raise ValueError("Proto does not contain an audio frame") - return AudioFrame(data=proto.audio.audio) - @dataclass() class ImageFrame(Frame): @@ -133,71 +94,23 @@ def __str__(self): @dataclass() class TextFrame(Frame): """A chunk of text. Emitted by LLM services, consumed by TTS services, can - be used to send text through pipelines. - - >>> str(TextFrame.from_proto(TextFrame(text='hello world').to_proto())) - 'TextFrame: "hello world"' - - The to_proto() function will always return a top-level Frame protobuf, so - it can be sent across the wire with other frames of various types. 
- - >>> type(TextFrame(text='hello world').to_proto()) - - """ + be used to send text through pipelines.""" text: str def __str__(self): return f'{self.__class__.__name__}: "{self.text}"' - def to_proto(self) -> frame_protos.Frame: - proto_frame = frame_protos.Frame() - proto_frame.text.text = self.text - return proto_frame - - @staticmethod - def from_proto(proto: frame_protos.TextFrame): - return TextFrame(text=proto.text.text) - @dataclass() class TranscriptionQueueFrame(TextFrame): """A text frame with transcription-specific data. Will be placed in the - transport's receive queue when a participant speaks. - - >>> transcription_frame = TranscriptionQueueFrame(text="Hello there!", participantId="123", timestamp="2021-01-01") - >>> transcription_frame - TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01') - - >>> TranscriptionQueueFrame.from_proto(transcription_frame.to_proto()) - TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01') - - The to_proto() function will always return a top-level Frame protobuf, so - it can be sent across the wire with other frames of various types. 
- - >>> type(transcription_frame.to_proto()) - - """ + transport's receive queue when a participant speaks.""" participantId: str timestamp: str def __str__(self): return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}" - def to_proto(self) -> frame_protos.Frame: - frame = frame_protos.Frame() - frame.transcription.text = self.text - frame.transcription.participant_id = self.participantId - frame.transcription.timestamp = self.timestamp - return frame - - @staticmethod - def from_proto(proto: frame_protos.Frame): - return TranscriptionQueueFrame( - text=proto.transcription.text, - participantId=proto.transcription.participant_id, - timestamp=proto.transcription.timestamp - ) - class TTSStartFrame(ControlFrame): """Used to indicate the beginning of a TTS response. Following AudioFrames diff --git a/src/dailyai/pipeline/protobufs/frames_pb2.py b/src/dailyai/pipeline/protobufs/frames_pb2.py index b923f8fe3..ce71723d3 100644 --- a/src/dailyai/pipeline/protobufs/frames_pb2.py +++ b/src/dailyai/pipeline/protobufs/frames_pb2.py @@ -12,19 +12,21 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\rdailyai_proto\"\x19\n\tTextFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\"\x1b\n\nAudioFrame\x12\r\n\x05\x61udio\x18\x01 \x01(\x0c\"M\n\x12TranscriptionFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x16\n\x0eparticipant_id\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"\xa2\x01\n\x05\x46rame\x12(\n\x04text\x18\x01 \x01(\x0b\x32\x18.dailyai_proto.TextFrameH\x00\x12*\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x19.dailyai_proto.AudioFrameH\x00\x12:\n\rtranscription\x18\x03 \x01(\x0b\x32!.dailyai_proto.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3') + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\rdailyai_proto\"\x19\n\tTextFrame\x12\x0c\n\x04text\x18\x01 
\x01(\t\"\x1a\n\nAudioFrame\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"L\n\x12TranscriptionFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x15\n\rparticipantId\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"\xa2\x01\n\x05\x46rame\x12(\n\x04text\x18\x01 \x01(\x0b\x32\x18.dailyai_proto.TextFrameH\x00\x12*\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x19.dailyai_proto.AudioFrameH\x00\x12:\n\rtranscription\x18\x03 \x01(\x0b\x32!.dailyai_proto.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals['_TEXTFRAME']._serialized_start = 31 - _globals['_TEXTFRAME']._serialized_end = 56 - _globals['_AUDIOFRAME']._serialized_start = 58 - _globals['_AUDIOFRAME']._serialized_end = 85 - _globals['_TRANSCRIPTIONFRAME']._serialized_start = 87 - _globals['_TRANSCRIPTIONFRAME']._serialized_end = 164 - _globals['_FRAME']._serialized_start = 167 - _globals['_FRAME']._serialized_end = 329 + DESCRIPTOR._options = None + _globals['_TEXTFRAME']._serialized_start=31 + _globals['_TEXTFRAME']._serialized_end=56 + _globals['_AUDIOFRAME']._serialized_start=58 + _globals['_AUDIOFRAME']._serialized_end=84 + _globals['_TRANSCRIPTIONFRAME']._serialized_start=86 + _globals['_TRANSCRIPTIONFRAME']._serialized_end=162 + _globals['_FRAME']._serialized_start=165 + _globals['_FRAME']._serialized_end=327 # @@protoc_insertion_point(module_scope) diff --git a/src/dailyai/serializers/protobuf_serializer.py b/src/dailyai/serializers/protobuf_serializer.py index 68332b291..8c00c9c94 100644 --- a/src/dailyai/serializers/protobuf_serializer.py +++ b/src/dailyai/serializers/protobuf_serializer.py @@ -1,14 +1,34 @@ +import dataclasses +from typing import Text from dailyai.pipeline.frames import AudioFrame, Frame, TextFrame, TranscriptionQueueFrame import 
dailyai.pipeline.protobufs.frames_pb2 as frame_protos from dailyai.serializers.abstract_frame_serializer import FrameSerializer class ProtobufFrameSerializer(FrameSerializer): + SERIALIZABLE_TYPES = { + TextFrame: "text", + AudioFrame: "audio", + TranscriptionQueueFrame: "transcription" + } + + SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()} + def __init__(self): pass - def serialize(self, frame: Frame) -> bytes: - return frame.to_proto().SerializeToString() + def serialize(self, frame: TextFrame | AudioFrame | TranscriptionQueueFrame) -> bytes: + proto_frame = frame_protos.Frame() + if type(frame) not in self.SERIALIZABLE_TYPES: + raise ValueError( + f"Frame type {type(frame)} is not serializable. You may need to add it to ProtobufFrameSerializer.SERIALIZABLE_TYPES.") + + proto_optional_name = self.SERIALIZABLE_TYPES[type(frame)] + for field in dataclasses.fields(frame): + setattr(getattr(proto_frame, proto_optional_name), field.name, + getattr(frame, field.name)) + + return proto_frame.SerializeToString() def deserialize(self, data: bytes) -> Frame: """Returns a Frame object from a Frame protobuf. Used to convert frames @@ -30,12 +50,14 @@ def deserialize(self, data: bytes) -> Frame: """ proto = frame_protos.Frame.FromString(data) - if proto.WhichOneof("frame") == "audio": - return AudioFrame.from_proto(proto) - elif proto.WhichOneof("frame") == "text": - return TextFrame.from_proto(proto) - elif proto.WhichOneof("frame") == "transcription": - return TranscriptionQueueFrame.from_proto(proto) - else: + which = proto.WhichOneof("frame") + if which not in self.SERIALIZABLE_FIELDS: raise ValueError( "Proto does not contain a valid frame. 
You may need to add a new case to ProtobufFrameSerializer.deserialize.") + + class_name = self.SERIALIZABLE_FIELDS[which] + args = getattr(proto, which) + args_dict = {} + for field in proto.DESCRIPTOR.fields_by_name[which].message_type.fields: + args_dict[field.name] = getattr(args, field.name) + return class_name(**args_dict) diff --git a/tests/test_protobuf_serializer.py b/tests/test_protobuf_serializer.py new file mode 100644 index 000000000..9d49fdff5 --- /dev/null +++ b/tests/test_protobuf_serializer.py @@ -0,0 +1,30 @@ +import unittest + +from dailyai.pipeline.frames import AudioFrame, TextFrame, TranscriptionQueueFrame +from dailyai.serializers.protobuf_serializer import ProtobufFrameSerializer + + +class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.serializer = ProtobufFrameSerializer() + + async def test_roundtrip(self): + text_frame = TextFrame(text='hello world') + frame = self.serializer.deserialize( + self.serializer.serialize(text_frame)) + self.assertEqual(frame, TextFrame(text='hello world')) + + transcription_frame = TranscriptionQueueFrame( + text="Hello there!", participantId="123", timestamp="2021-01-01") + frame = self.serializer.deserialize( + self.serializer.serialize(transcription_frame)) + self.assertEqual(frame, transcription_frame) + + audio_frame = AudioFrame(data=b'1234567890') + frame = self.serializer.deserialize( + self.serializer.serialize(audio_frame)) + self.assertEqual(frame, audio_frame) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_websocket_transport.py b/tests/test_websocket_transport.py index cbf728338..4b0d3d9b3 100644 --- a/tests/test_websocket_transport.py +++ b/tests/test_websocket_transport.py @@ -11,10 +11,13 @@ class TestWebSocketTransportService(unittest.IsolatedAsyncioTestCase): def setUp(self): self.transport = WebsocketTransport(host="localhost", port=8765) self.pipeline = Pipeline([]) + self.sample_frame = TextFrame("Hello there!") + 
self.serialized_sample_frame = self.transport._serializer.serialize( + self.sample_frame) async def queue_frame(self): await asyncio.sleep(0.1) - await self.pipeline.queue_frames([TextFrame("Hello there!"), EndFrame()]) + await self.pipeline.queue_frames([self.sample_frame, EndFrame()]) async def test_websocket_handler(self): mock_websocket = AsyncMock() @@ -27,8 +30,9 @@ async def test_websocket_handler(self): await asyncio.gather(self.transport.run(self.pipeline), self.queue_frame()) self.assertEqual(mock_websocket.send.call_count, 1) - self.assertEqual(mock_websocket.send.call_args[0][0], TextFrame( - "Hello there!").to_proto().SerializeToString()) + + self.assertEqual( + mock_websocket.send.call_args[0][0], self.serialized_sample_frame) async def test_on_connection_decorator(self): mock_websocket = AsyncMock() @@ -82,7 +86,7 @@ async def test_serializer_parameter(self): self.assertEqual(mock_websocket.send.call_count, 1) self.assertEqual( mock_websocket.send.call_args[0][0], - TextFrame("Hello there!").to_proto().SerializeToString(), + self.serialized_sample_frame, ) # Test with a mock serializer