diff --git a/examples/foundational/10-wake-word.py b/examples/foundational/10-wake-word.py
index 00c3acdc6..cd82a59e3 100644
--- a/examples/foundational/10-wake-word.py
+++ b/examples/foundational/10-wake-word.py
@@ -172,7 +172,8 @@ async def handle_transcriptions():
isa.run(
tma_out.run(
llm.run(
- tma_in.run(ncf.run(tf.run(transport.get_receive_frames())))
+ tma_in.run(
+ ncf.run(tf.run(transport.get_receive_frames())))
)
)
),
diff --git a/examples/foundational/websocket-server/index.html b/examples/foundational/websocket-server/index.html
index eafb4d39b..77be13518 100644
--- a/examples/foundational/websocket-server/index.html
+++ b/examples/foundational/websocket-server/index.html
@@ -49,10 +49,10 @@
WebSocket Audio Stream
const parsedFrame = frame.decode(new Uint8Array(arrayBuffer));
if (!parsedFrame?.audio) return false;
- const frameCount = parsedFrame.audio.audio.length / 2;
+ const frameCount = parsedFrame.audio.data.length / 2;
const audioOutBuffer = audioContext.createBuffer(1, frameCount, SAMPLE_RATE);
const nowBuffering = audioOutBuffer.getChannelData(0);
- const view = new Int16Array(parsedFrame.audio.audio.buffer);
+ const view = new Int16Array(parsedFrame.audio.data.buffer);
for (let i = 0; i < frameCount; i++) {
const word = view[i];
diff --git a/src/dailyai/pipeline/frames.proto b/src/dailyai/pipeline/frames.proto
index 7ecea6d25..b19fbccbf 100644
--- a/src/dailyai/pipeline/frames.proto
+++ b/src/dailyai/pipeline/frames.proto
@@ -7,12 +7,12 @@ message TextFrame {
}
message AudioFrame {
- bytes audio = 1;
+ bytes data = 1;
}
message TranscriptionFrame {
string text = 1;
- string participant_id = 2;
+ string participantId = 2;
string timestamp = 3;
}
diff --git a/src/dailyai/pipeline/frames.py b/src/dailyai/pipeline/frames.py
index 6f9738cd4..c5fdcbbb4 100644
--- a/src/dailyai/pipeline/frames.py
+++ b/src/dailyai/pipeline/frames.py
@@ -9,21 +9,6 @@ class Frame:
def __str__(self):
return f"{self.__class__.__name__}"
- def to_proto(self):
- """Converts a Frame object to a Frame protobuf. Used to send frames via the
- websocket transport, or other clients that send Frames to their clients.
- If you see this exception, you may need to implement a to_proto method on your
- Frame subclass."""
- raise NotImplementedError
-
- @staticmethod
- def from_proto(proto):
- """Converts a Frame protobuf to a Frame object. Used to receive frames via the
- websocket transport, or other clients that receive Frames from their clients.
- If you see this exception, you may need to implement a from_proto method on your
- Frame subclass, and add your class to FrameFromProto below."""
- raise NotImplementedError
-
class ControlFrame(Frame):
# Control frames should contain no instance data, so
@@ -77,36 +62,12 @@ class LLMResponseEndFrame(ControlFrame):
@dataclass()
class AudioFrame(Frame):
"""A chunk of audio. Will be played by the transport if the transport's mic
- has been enabled.
-
- >>> str(AudioFrame(data=b'1234567890'))
- 'AudioFrame, size: 10 B'
-
- >>> AudioFrame.from_proto(AudioFrame(data=b'1234567890').to_proto())
- AudioFrame(data=b'1234567890')
-
- The to_proto() function will always return a top-level Frame protobuf, so
- it can be sent across the wire with other frames of various types.
-
- >>> type(AudioFrame(data=b'1234567890').to_proto())
-
- """
+ has been enabled."""
data: bytes
def __str__(self):
return f"{self.__class__.__name__}, size: {len(self.data)} B"
- def to_proto(self) -> frame_protos.Frame:
- frame = frame_protos.Frame()
- frame.audio.audio = self.data
- return frame
-
- @staticmethod
- def from_proto(proto: frame_protos.Frame):
- if proto.WhichOneof("frame") != "audio":
- raise ValueError("Proto does not contain an audio frame")
- return AudioFrame(data=proto.audio.audio)
-
@dataclass()
class ImageFrame(Frame):
@@ -133,71 +94,23 @@ def __str__(self):
@dataclass()
class TextFrame(Frame):
"""A chunk of text. Emitted by LLM services, consumed by TTS services, can
- be used to send text through pipelines.
-
- >>> str(TextFrame.from_proto(TextFrame(text='hello world').to_proto()))
- 'TextFrame: "hello world"'
-
- The to_proto() function will always return a top-level Frame protobuf, so
- it can be sent across the wire with other frames of various types.
-
- >>> type(TextFrame(text='hello world').to_proto())
-
- """
+ be used to send text through pipelines."""
text: str
def __str__(self):
return f'{self.__class__.__name__}: "{self.text}"'
- def to_proto(self) -> frame_protos.Frame:
- proto_frame = frame_protos.Frame()
- proto_frame.text.text = self.text
- return proto_frame
-
- @staticmethod
- def from_proto(proto: frame_protos.TextFrame):
- return TextFrame(text=proto.text.text)
-
@dataclass()
class TranscriptionQueueFrame(TextFrame):
"""A text frame with transcription-specific data. Will be placed in the
- transport's receive queue when a participant speaks.
-
- >>> transcription_frame = TranscriptionQueueFrame(text="Hello there!", participantId="123", timestamp="2021-01-01")
- >>> transcription_frame
- TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')
-
- >>> TranscriptionQueueFrame.from_proto(transcription_frame.to_proto())
- TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')
-
- The to_proto() function will always return a top-level Frame protobuf, so
- it can be sent across the wire with other frames of various types.
-
- >>> type(transcription_frame.to_proto())
-
- """
+ transport's receive queue when a participant speaks."""
participantId: str
timestamp: str
def __str__(self):
return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"
- def to_proto(self) -> frame_protos.Frame:
- frame = frame_protos.Frame()
- frame.transcription.text = self.text
- frame.transcription.participant_id = self.participantId
- frame.transcription.timestamp = self.timestamp
- return frame
-
- @staticmethod
- def from_proto(proto: frame_protos.Frame):
- return TranscriptionQueueFrame(
- text=proto.transcription.text,
- participantId=proto.transcription.participant_id,
- timestamp=proto.transcription.timestamp
- )
-
class TTSStartFrame(ControlFrame):
"""Used to indicate the beginning of a TTS response. Following AudioFrames
diff --git a/src/dailyai/pipeline/protobufs/frames_pb2.py b/src/dailyai/pipeline/protobufs/frames_pb2.py
index b923f8fe3..ce71723d3 100644
--- a/src/dailyai/pipeline/protobufs/frames_pb2.py
+++ b/src/dailyai/pipeline/protobufs/frames_pb2.py
@@ -12,19 +12,21 @@
_sym_db = _symbol_database.Default()
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\rdailyai_proto\"\x19\n\tTextFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\"\x1b\n\nAudioFrame\x12\r\n\x05\x61udio\x18\x01 \x01(\x0c\"M\n\x12TranscriptionFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x16\n\x0eparticipant_id\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"\xa2\x01\n\x05\x46rame\x12(\n\x04text\x18\x01 \x01(\x0b\x32\x18.dailyai_proto.TextFrameH\x00\x12*\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x19.dailyai_proto.AudioFrameH\x00\x12:\n\rtranscription\x18\x03 \x01(\x0b\x32!.dailyai_proto.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3')
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\rdailyai_proto\"\x19\n\tTextFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\"\x1a\n\nAudioFrame\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"L\n\x12TranscriptionFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x15\n\rparticipantId\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"\xa2\x01\n\x05\x46rame\x12(\n\x04text\x18\x01 \x01(\x0b\x32\x18.dailyai_proto.TextFrameH\x00\x12*\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x19.dailyai_proto.AudioFrameH\x00\x12:\n\rtranscription\x18\x03 \x01(\x0b\x32!.dailyai_proto.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
- DESCRIPTOR._options = None
- _globals['_TEXTFRAME']._serialized_start = 31
- _globals['_TEXTFRAME']._serialized_end = 56
- _globals['_AUDIOFRAME']._serialized_start = 58
- _globals['_AUDIOFRAME']._serialized_end = 85
- _globals['_TRANSCRIPTIONFRAME']._serialized_start = 87
- _globals['_TRANSCRIPTIONFRAME']._serialized_end = 164
- _globals['_FRAME']._serialized_start = 167
- _globals['_FRAME']._serialized_end = 329
+ DESCRIPTOR._options = None
+ _globals['_TEXTFRAME']._serialized_start=31
+ _globals['_TEXTFRAME']._serialized_end=56
+ _globals['_AUDIOFRAME']._serialized_start=58
+ _globals['_AUDIOFRAME']._serialized_end=84
+ _globals['_TRANSCRIPTIONFRAME']._serialized_start=86
+ _globals['_TRANSCRIPTIONFRAME']._serialized_end=162
+ _globals['_FRAME']._serialized_start=165
+ _globals['_FRAME']._serialized_end=327
# @@protoc_insertion_point(module_scope)
diff --git a/src/dailyai/serializers/protobuf_serializer.py b/src/dailyai/serializers/protobuf_serializer.py
index 68332b291..8c00c9c94 100644
--- a/src/dailyai/serializers/protobuf_serializer.py
+++ b/src/dailyai/serializers/protobuf_serializer.py
@@ -1,14 +1,34 @@
+import dataclasses
+
from dailyai.pipeline.frames import AudioFrame, Frame, TextFrame, TranscriptionQueueFrame
import dailyai.pipeline.protobufs.frames_pb2 as frame_protos
from dailyai.serializers.abstract_frame_serializer import FrameSerializer
class ProtobufFrameSerializer(FrameSerializer):
+ SERIALIZABLE_TYPES = {
+ TextFrame: "text",
+ AudioFrame: "audio",
+ TranscriptionQueueFrame: "transcription"
+ }
+
+ SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()}
+
def __init__(self):
pass
- def serialize(self, frame: Frame) -> bytes:
- return frame.to_proto().SerializeToString()
+ def serialize(self, frame: TextFrame | AudioFrame | TranscriptionQueueFrame) -> bytes:
+ proto_frame = frame_protos.Frame()
+ if type(frame) not in self.SERIALIZABLE_TYPES:
+ raise ValueError(
+ f"Frame type {type(frame)} is not serializable. You may need to add it to ProtobufFrameSerializer.SERIALIZABLE_FIELDS.")
+
+ proto_optional_name = self.SERIALIZABLE_TYPES[type(frame)]
+ for field in dataclasses.fields(frame):
+ setattr(getattr(proto_frame, proto_optional_name), field.name,
+ getattr(frame, field.name))
+
+ return proto_frame.SerializeToString()
def deserialize(self, data: bytes) -> Frame:
"""Returns a Frame object from a Frame protobuf. Used to convert frames
@@ -30,12 +50,14 @@ def deserialize(self, data: bytes) -> Frame:
"""
proto = frame_protos.Frame.FromString(data)
- if proto.WhichOneof("frame") == "audio":
- return AudioFrame.from_proto(proto)
- elif proto.WhichOneof("frame") == "text":
- return TextFrame.from_proto(proto)
- elif proto.WhichOneof("frame") == "transcription":
- return TranscriptionQueueFrame.from_proto(proto)
- else:
+ which = proto.WhichOneof("frame")
+ if which not in self.SERIALIZABLE_FIELDS:
raise ValueError(
"Proto does not contain a valid frame. You may need to add a new case to ProtobufFrameSerializer.deserialize.")
+
+ class_name = self.SERIALIZABLE_FIELDS[which]
+ args = getattr(proto, which)
+ args_dict = {}
+ for field in proto.DESCRIPTOR.fields_by_name[which].message_type.fields:
+ args_dict[field.name] = getattr(args, field.name)
+ return class_name(**args_dict)
diff --git a/tests/test_protobuf_serializer.py b/tests/test_protobuf_serializer.py
new file mode 100644
index 000000000..9d49fdff5
--- /dev/null
+++ b/tests/test_protobuf_serializer.py
@@ -0,0 +1,30 @@
+import unittest
+
+from dailyai.pipeline.frames import AudioFrame, TextFrame, TranscriptionQueueFrame
+from dailyai.serializers.protobuf_serializer import ProtobufFrameSerializer
+
+
+class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ self.serializer = ProtobufFrameSerializer()
+
+ async def test_roundtrip(self):
+ text_frame = TextFrame(text='hello world')
+ frame = self.serializer.deserialize(
+ self.serializer.serialize(text_frame))
+ self.assertEqual(frame, TextFrame(text='hello world'))
+
+ transcription_frame = TranscriptionQueueFrame(
+ text="Hello there!", participantId="123", timestamp="2021-01-01")
+ frame = self.serializer.deserialize(
+ self.serializer.serialize(transcription_frame))
+ self.assertEqual(frame, transcription_frame)
+
+ audio_frame = AudioFrame(data=b'1234567890')
+ frame = self.serializer.deserialize(
+ self.serializer.serialize(audio_frame))
+ self.assertEqual(frame, audio_frame)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_websocket_transport.py b/tests/test_websocket_transport.py
index cbf728338..4b0d3d9b3 100644
--- a/tests/test_websocket_transport.py
+++ b/tests/test_websocket_transport.py
@@ -11,10 +11,13 @@ class TestWebSocketTransportService(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.transport = WebsocketTransport(host="localhost", port=8765)
self.pipeline = Pipeline([])
+ self.sample_frame = TextFrame("Hello there!")
+ self.serialized_sample_frame = self.transport._serializer.serialize(
+ self.sample_frame)
async def queue_frame(self):
await asyncio.sleep(0.1)
- await self.pipeline.queue_frames([TextFrame("Hello there!"), EndFrame()])
+ await self.pipeline.queue_frames([self.sample_frame, EndFrame()])
async def test_websocket_handler(self):
mock_websocket = AsyncMock()
@@ -27,8 +30,9 @@ async def test_websocket_handler(self):
await asyncio.gather(self.transport.run(self.pipeline), self.queue_frame())
self.assertEqual(mock_websocket.send.call_count, 1)
- self.assertEqual(mock_websocket.send.call_args[0][0], TextFrame(
- "Hello there!").to_proto().SerializeToString())
+
+ self.assertEqual(
+ mock_websocket.send.call_args[0][0], self.serialized_sample_frame)
async def test_on_connection_decorator(self):
mock_websocket = AsyncMock()
@@ -82,7 +86,7 @@ async def test_serializer_parameter(self):
self.assertEqual(mock_websocket.send.call_count, 1)
self.assertEqual(
mock_websocket.send.call_args[0][0],
- TextFrame("Hello there!").to_proto().SerializeToString(),
+ self.serialized_sample_frame,
)
# Test with a mock serializer