Skip to content

Commit

Permalink
Serializer class now handles protobuf serialization directly
Browse files Browse the repository at this point in the history
  • Loading branch information
Moishe committed Mar 25, 2024
1 parent 8ab7ff6 commit c6a7fa6
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 118 deletions.
3 changes: 2 additions & 1 deletion examples/foundational/10-wake-word.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ async def handle_transcriptions():
isa.run(
tma_out.run(
llm.run(
tma_in.run(ncf.run(tf.run(transport.get_receive_frames())))
tma_in.run(
ncf.run(tf.run(transport.get_receive_frames())))
)
)
),
Expand Down
4 changes: 2 additions & 2 deletions examples/foundational/websocket-server/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ <h1>WebSocket Audio Stream</h1>
const parsedFrame = frame.decode(new Uint8Array(arrayBuffer));
if (!parsedFrame?.audio) return false;

const frameCount = parsedFrame.audio.audio.length / 2;
const frameCount = parsedFrame.audio.data.length / 2;
const audioOutBuffer = audioContext.createBuffer(1, frameCount, SAMPLE_RATE);
const nowBuffering = audioOutBuffer.getChannelData(0);
const view = new Int16Array(parsedFrame.audio.audio.buffer);
const view = new Int16Array(parsedFrame.audio.data.buffer);

for (let i = 0; i < frameCount; i++) {
const word = view[i];
Expand Down
4 changes: 2 additions & 2 deletions src/dailyai/pipeline/frames.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ message TextFrame {
}

message AudioFrame {
bytes audio = 1;
bytes data = 1;
}

message TranscriptionFrame {
string text = 1;
string participant_id = 2;
string participantId = 2;
string timestamp = 3;
}

Expand Down
93 changes: 3 additions & 90 deletions src/dailyai/pipeline/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,6 @@ class Frame:
def __str__(self):
return f"{self.__class__.__name__}"

def to_proto(self):
"""Converts a Frame object to a Frame protobuf. Used to send frames via the
websocket transport, or other clients that send Frames to their clients.
If you see this exception, you may need to implement a to_proto method on your
Frame subclass."""
raise NotImplementedError

@staticmethod
def from_proto(proto):
"""Converts a Frame protobuf to a Frame object. Used to receive frames via the
websocket transport, or other clients that receive Frames from their clients.
If you see this exception, you may need to implement a from_proto method on your
Frame subclass, and add your class to FrameFromProto below."""
raise NotImplementedError


class ControlFrame(Frame):
# Control frames should contain no instance data, so
Expand Down Expand Up @@ -77,36 +62,12 @@ class LLMResponseEndFrame(ControlFrame):
@dataclass()
class AudioFrame(Frame):
"""A chunk of audio. Will be played by the transport if the transport's mic
has been enabled.
>>> str(AudioFrame(data=b'1234567890'))
'AudioFrame, size: 10 B'
>>> AudioFrame.from_proto(AudioFrame(data=b'1234567890').to_proto())
AudioFrame(data=b'1234567890')
The to_proto() function will always return a top-level Frame protobuf, so
it can be sent across the wire with other frames of various types.
>>> type(AudioFrame(data=b'1234567890').to_proto())
<class 'frames_pb2.Frame'>
"""
has been enabled."""
data: bytes

def __str__(self):
return f"{self.__class__.__name__}, size: {len(self.data)} B"

def to_proto(self) -> frame_protos.Frame:
frame = frame_protos.Frame()
frame.audio.audio = self.data
return frame

@staticmethod
def from_proto(proto: frame_protos.Frame):
if proto.WhichOneof("frame") != "audio":
raise ValueError("Proto does not contain an audio frame")
return AudioFrame(data=proto.audio.audio)


@dataclass()
class ImageFrame(Frame):
Expand All @@ -133,71 +94,23 @@ def __str__(self):
@dataclass()
class TextFrame(Frame):
"""A chunk of text. Emitted by LLM services, consumed by TTS services, can
be used to send text through pipelines.
>>> str(TextFrame.from_proto(TextFrame(text='hello world').to_proto()))
'TextFrame: "hello world"'
The to_proto() function will always return a top-level Frame protobuf, so
it can be sent across the wire with other frames of various types.
>>> type(TextFrame(text='hello world').to_proto())
<class 'frames_pb2.Frame'>
"""
be used to send text through pipelines."""
text: str

def __str__(self):
return f'{self.__class__.__name__}: "{self.text}"'

def to_proto(self) -> frame_protos.Frame:
proto_frame = frame_protos.Frame()
proto_frame.text.text = self.text
return proto_frame

@staticmethod
def from_proto(proto: frame_protos.TextFrame):
return TextFrame(text=proto.text.text)


@dataclass()
class TranscriptionQueueFrame(TextFrame):
"""A text frame with transcription-specific data. Will be placed in the
transport's receive queue when a participant speaks.
>>> transcription_frame = TranscriptionQueueFrame(text="Hello there!", participantId="123", timestamp="2021-01-01")
>>> transcription_frame
TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')
>>> TranscriptionQueueFrame.from_proto(transcription_frame.to_proto())
TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')
The to_proto() function will always return a top-level Frame protobuf, so
it can be sent across the wire with other frames of various types.
>>> type(transcription_frame.to_proto())
<class 'frames_pb2.Frame'>
"""
transport's receive queue when a participant speaks."""
participantId: str
timestamp: str

def __str__(self):
return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"

def to_proto(self) -> frame_protos.Frame:
frame = frame_protos.Frame()
frame.transcription.text = self.text
frame.transcription.participant_id = self.participantId
frame.transcription.timestamp = self.timestamp
return frame

@staticmethod
def from_proto(proto: frame_protos.Frame):
return TranscriptionQueueFrame(
text=proto.transcription.text,
participantId=proto.transcription.participant_id,
timestamp=proto.transcription.timestamp
)


class TTSStartFrame(ControlFrame):
"""Used to indicate the beginning of a TTS response. Following AudioFrames
Expand Down
22 changes: 12 additions & 10 deletions src/dailyai/pipeline/protobufs/frames_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 31 additions & 9 deletions src/dailyai/serializers/protobuf_serializer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,34 @@
import dataclasses
from typing import Text
from dailyai.pipeline.frames import AudioFrame, Frame, TextFrame, TranscriptionQueueFrame
import dailyai.pipeline.protobufs.frames_pb2 as frame_protos
from dailyai.serializers.abstract_frame_serializer import FrameSerializer


class ProtobufFrameSerializer(FrameSerializer):
SERIALIZABLE_TYPES = {
TextFrame: "text",
AudioFrame: "audio",
TranscriptionQueueFrame: "transcription"
}

SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()}

def __init__(self):
pass

def serialize(self, frame: Frame) -> bytes:
return frame.to_proto().SerializeToString()
def serialize(self, frame: TextFrame | AudioFrame | TranscriptionQueueFrame) -> bytes:
    """Serialize a supported frame into the bytes of a ``Frame`` protobuf.

    The frame's exact class selects which member of the top-level ``Frame``
    message is populated; every dataclass field of the frame is then copied
    onto the protobuf field of the same name.

    Args:
        frame: The frame to serialize. Its exact type must be a key of
            ``SERIALIZABLE_TYPES``.

    Returns:
        The result of ``SerializeToString()`` on the populated protobuf.

    Raises:
        ValueError: If the frame's type is not registered in
            ``SERIALIZABLE_TYPES``.
    """
    # Match on the exact class rather than isinstance():
    # TranscriptionQueueFrame subclasses TextFrame but has its own proto field.
    frame_type = type(frame)
    if frame_type not in self.SERIALIZABLE_TYPES:
        # SERIALIZABLE_FIELDS is derived from SERIALIZABLE_TYPES, so the
        # latter is the mapping a maintainer actually has to extend.
        raise ValueError(
            f"Frame type {frame_type} is not serializable. You may need to add it to ProtobufFrameSerializer.SERIALIZABLE_TYPES.")

    proto_frame = frame_protos.Frame()
    # Resolve the target sub-message once instead of on every field.
    proto_message = getattr(proto_frame, self.SERIALIZABLE_TYPES[frame_type])
    for field in dataclasses.fields(frame):
        setattr(proto_message, field.name, getattr(frame, field.name))

    return proto_frame.SerializeToString()

def deserialize(self, data: bytes) -> Frame:
"""Returns a Frame object from a Frame protobuf. Used to convert frames
Expand All @@ -30,12 +50,14 @@ def deserialize(self, data: bytes) -> Frame:
"""

proto = frame_protos.Frame.FromString(data)
if proto.WhichOneof("frame") == "audio":
return AudioFrame.from_proto(proto)
elif proto.WhichOneof("frame") == "text":
return TextFrame.from_proto(proto)
elif proto.WhichOneof("frame") == "transcription":
return TranscriptionQueueFrame.from_proto(proto)
else:
which = proto.WhichOneof("frame")
if which not in self.SERIALIZABLE_FIELDS:
raise ValueError(
"Proto does not contain a valid frame. You may need to add a new case to ProtobufFrameSerializer.deserialize.")

class_name = self.SERIALIZABLE_FIELDS[which]
args = getattr(proto, which)
args_dict = {}
for field in proto.DESCRIPTOR.fields_by_name[which].message_type.fields:
args_dict[field.name] = getattr(args, field.name)
return class_name(**args_dict)
30 changes: 30 additions & 0 deletions tests/test_protobuf_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import unittest

from dailyai.pipeline.frames import AudioFrame, TextFrame, TranscriptionQueueFrame
from dailyai.serializers.protobuf_serializer import ProtobufFrameSerializer


class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase):
    """Round-trip tests for ProtobufFrameSerializer."""

    def setUp(self):
        # A fresh serializer per test keeps the tests independent.
        self.serializer = ProtobufFrameSerializer()

    async def test_roundtrip(self):
        """Each serializable frame must equal itself after serialize -> deserialize."""
        frames = [
            TextFrame(text='hello world'),
            TranscriptionQueueFrame(
                text="Hello there!", participantId="123", timestamp="2021-01-01"),
            AudioFrame(data=b'1234567890'),
        ]
        for original in frames:
            # subTest reports every failing frame type instead of stopping at
            # the first failed assertion.
            with self.subTest(frame_type=type(original).__name__):
                restored = self.serializer.deserialize(
                    self.serializer.serialize(original))
                self.assertEqual(restored, original)


if __name__ == "__main__":
unittest.main()
12 changes: 8 additions & 4 deletions tests/test_websocket_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ class TestWebSocketTransportService(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.transport = WebsocketTransport(host="localhost", port=8765)
self.pipeline = Pipeline([])
self.sample_frame = TextFrame("Hello there!")
self.serialized_sample_frame = self.transport._serializer.serialize(
self.sample_frame)

async def queue_frame(self):
await asyncio.sleep(0.1)
await self.pipeline.queue_frames([TextFrame("Hello there!"), EndFrame()])
await self.pipeline.queue_frames([self.sample_frame, EndFrame()])

async def test_websocket_handler(self):
mock_websocket = AsyncMock()
Expand All @@ -27,8 +30,9 @@ async def test_websocket_handler(self):

await asyncio.gather(self.transport.run(self.pipeline), self.queue_frame())
self.assertEqual(mock_websocket.send.call_count, 1)
self.assertEqual(mock_websocket.send.call_args[0][0], TextFrame(
"Hello there!").to_proto().SerializeToString())

self.assertEqual(
mock_websocket.send.call_args[0][0], self.serialized_sample_frame)

async def test_on_connection_decorator(self):
mock_websocket = AsyncMock()
Expand Down Expand Up @@ -82,7 +86,7 @@ async def test_serializer_parameter(self):
self.assertEqual(mock_websocket.send.call_count, 1)
self.assertEqual(
mock_websocket.send.call_args[0][0],
TextFrame("Hello there!").to_proto().SerializeToString(),
self.serialized_sample_frame,
)

# Test with a mock serializer
Expand Down

0 comments on commit c6a7fa6

Please sign in to comment.