add audio frame aggregation to websocket transport
Moishe committed Mar 24, 2024
1 parent 89a4fc9 commit 0b6f3ab
Showing 5 changed files with 108 additions and 94 deletions.
74 changes: 19 additions & 55 deletions examples/foundational/websocket-server/sample.py
@@ -3,85 +3,49 @@
import logging
import os
import wave

from dailyai.pipeline.frame_processor import FrameProcessor
from dailyai.pipeline.frames import AudioFrame, EndFrame, EndPipeFrame, TextFrame, TranscriptionQueueFrame
from dailyai.pipeline.frames import AudioFrame, EndFrame, EndPipeFrame, TTSEndFrame, TextFrame, TranscriptionQueueFrame
from dailyai.pipeline.pipeline import Pipeline
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
from dailyai.services.websocket_transport_service import WebsocketTransport
from dailyai.services.whisper_ai_services import WhisperSTTService

logging.basicConfig(format=f"%(levelno)s %(asctime)s %(message)s")
logging.basicConfig(format="%(levelno)s %(asctime)s %(message)s")
logger = logging.getLogger("dailyai")
logger.setLevel(logging.DEBUG)


class WhisperTranscriber(FrameProcessor):
async def process_frame(self, frame):
if isinstance(frame, TranscriptionQueueFrame):
print(f"Transcribed: {frame.text}")
else:
yield frame


async def main():
async with aiohttp.ClientSession() as session:
transport = WebsocketTransport(
mic_enabled=True, speaker_enabled=True, duration_minutes=120)

mic_enabled=True,
speaker_enabled=True,
)
tts = ElevenLabsTTSService(
aiohttp_session=session,
api_key=os.getenv("ELEVENLABS_API_KEY"),
voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
)

class AudioWriter(FrameProcessor):
SIZE = 160000

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._buffer = bytes()
self._counter = 0

async def process_frame(self, frame):
if isinstance(frame, AudioFrame):
self._buffer += frame.data
if len(self._buffer) >= AudioWriter.SIZE:
with wave.open(f"output-{self._counter}.wav", "wb") as f:
f.setnchannels(1)
f.setsampwidth(2)
f.setframerate(16000)
f.writeframes(self._buffer)
self._counter += 1
self._buffer = self._buffer[AudioWriter.SIZE:]
yield frame
else:
yield frame

class AudioChunker(FrameProcessor):
def __init__(self):
super().__init__()
self._buffer = bytes()

async def process_frame(self, frame):
if isinstance(frame, AudioFrame):
self._buffer += frame.data
while len(self._buffer) >= 16000:
yield AudioFrame(self._buffer[:16000])
self._buffer = self._buffer[16000:]
else:
yield frame

class WhisperTranscriber(FrameProcessor):
async def process_frame(self, frame):
if isinstance(frame, TranscriptionQueueFrame):
print(f"Transcribed: {frame.text}")
else:
yield frame

# pipeline = Pipeline([WriteToWav(), WhisperSTTService(), tts])
pipeline = Pipeline(
[AudioWriter(), WhisperSTTService(), WhisperTranscriber(), tts, AudioChunker()])
pipeline = Pipeline([
WhisperSTTService(),
WhisperTranscriber(),
tts,
])

@transport.on_connection
async def queue_frame():
await pipeline.queue_frames([TextFrame("Hello there!")])

await asyncio.gather(transport.run(pipeline))
del tts

await transport.run(pipeline)

if __name__ == "__main__":
asyncio.run(main())
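
For reference, a minimal client for this sample (a sketch, not part of the commit): it assumes the transport defaults of localhost:8765 from websocket_transport_service.py below, and that every websocket message is a serialized Frame protobuf, as send_task there implies.

import asyncio
import websockets

import dailyai.pipeline.protobufs.frames_pb2 as frame_protos


async def receive_frames():
    # Each message from the server is one protobuf-serialized Frame.
    async with websockets.connect("ws://localhost:8765") as ws:
        async for message in ws:
            print(frame_protos.Frame.FromString(message))

asyncio.run(receive_frames())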
42 changes: 21 additions & 21 deletions src/dailyai/pipeline/frames.py
@@ -10,10 +10,18 @@ def __str__(self):
return f"{self.__class__.__name__}"

def to_proto(self):
"""Converts a Frame object to a Frame protobuf. Used to send frames via the
websocket transport, or other clients that send Frames to their clients.
If you see this exception, you may need to implement a to_proto method on your
Frame subclass."""
raise NotImplementedError

@staticmethod
def from_proto(proto):
"""Converts a Frame protobuf to a Frame object. Used to receive frames via the
websocket transport, or other clients that receive Frames from their clients.
If you see this exception, you may need to implement a from_proto method on your
Frame subclass, and add your class to FrameFromProto below."""
raise NotImplementedError


@@ -191,6 +199,19 @@ def from_proto(proto: frame_protos.Frame):
)


class TTSStartFrame(ControlFrame):
"""Used to indicate the beginning of a TTS response. Following AudioFrames
are part of the TTS response until a TTSEndFrame. These frames can be used
for aggregating audio frames in a transport to optimize the size of frames
sent to the session, without needing to control this in the TTS service."""
pass


class TTSEndFrame(ControlFrame):
"""Indicates the end of a TTS response."""
pass


@dataclass()
class LLMMessagesQueueFrame(Frame):
"""A frame containing a list of LLM messages. Used to signal that an LLM
@@ -287,24 +308,3 @@ def FrameFromProto(proto: frame_protos.Frame) -> Frame:
else:
raise ValueError(
"Proto does not contain a valid frame. You may need to add a new case to FrameFromProto.")


if __name__ == "__main__":
audio_frame = AudioFrame(data=b'1234567890')
print(audio_frame)
print(audio_frame.to_proto().SerializeToString())
print(AudioFrame.from_proto(audio_frame.to_proto()))

text_frame = TextFrame(text="Hello there!")
print(text_frame)
print(text_frame.to_proto().SerializeToString())
serialized = text_frame.to_proto().SerializeToString()
print(type(serialized))
print(frame_protos.Frame.FromString(serialized))
print(TextFrame.from_proto(text_frame.to_proto()))

transcripton_frame = TranscriptionQueueFrame(
text="Hello there!", participantId="123", timestamp="2021-01-01")
print(transcripton_frame)
print(transcripton_frame.to_proto().SerializeToString())
print(TranscriptionQueueFrame.from_proto(transcripton_frame.to_proto()))
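
The __main__ block removed above still works as a quick round-trip check; a sketch of the same idea with assertions (assuming, as those prints imply, that from_proto returns the matching Frame subclass):

from dailyai.pipeline.frames import AudioFrame, TextFrame

# Serialize each frame to its protobuf and back; payloads should survive.
audio_frame = AudioFrame(data=b"1234567890")
assert AudioFrame.from_proto(audio_frame.to_proto()).data == audio_frame.data

text_frame = TextFrame(text="Hello there!")
assert TextFrame.from_proto(text_frame.to_proto()).text == text_frame.text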
9 changes: 6 additions & 3 deletions src/dailyai/pipeline/pipeline.py
@@ -24,7 +24,7 @@ def __init__(
queues. If this pipeline is run by a transport, its sink and source queues
will be overridden.
"""
self.processors: List[FrameProcessor] = processors
self._processors: List[FrameProcessor] = processors

self.source: asyncio.Queue[Frame] = source or asyncio.Queue()
self.sink: asyncio.Queue[Frame] = sink or asyncio.Queue()
@@ -40,6 +40,9 @@ def set_sink(self, sink: asyncio.Queue[Frame]):
has processed a frame, its output will be placed on this queue."""
self.sink = sink

def add_processor(self, processor: FrameProcessor):
self._processors.append(processor)

async def get_next_source_frame(self) -> AsyncGenerator[Frame, None]:
"""Convenience function to get the next frame from the source queue. This
lets us consistently have an AsyncGenerator yield frames, from either the
@@ -80,7 +83,7 @@ async def run_pipeline(self):
while True:
initial_frame = await self.source.get()
async for frame in self._run_pipeline_recursively(
initial_frame, self.processors
initial_frame, self._processors
):
await self.sink.put(frame)

@@ -91,7 +94,7 @@ async def run_pipeline(self):
except asyncio.CancelledError:
# this means there's been an interruption, do any cleanup necessary
# here.
for processor in self.processors:
for processor in self._processors:
await processor.interrupted()
pass

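The new add_processor hook is what lets WebsocketTransport.run (below) append its aggregating processor to a pipeline after construction. A minimal sketch; PassThrough is illustrative, not part of this commit:

from dailyai.pipeline.frame_processor import FrameProcessor
from dailyai.pipeline.pipeline import Pipeline


class PassThrough(FrameProcessor):
    # No-op processor: forwards every frame unchanged.
    async def process_frame(self, frame):
        yield frame


pipeline = Pipeline([])                 # build the pipeline first...
pipeline.add_processor(PassThrough())   # ...then append a processor to the end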
22 changes: 13 additions & 9 deletions src/dailyai/services/ai_services.py
@@ -10,6 +10,8 @@
EndPipeFrame,
ImageFrame,
Frame,
TTSEndFrame,
TTSStartFrame,
TextFrame,
TranscriptionQueueFrame,
)
@@ -47,12 +49,18 @@ async def run_tts(self, text) -> AsyncGenerator[bytes, None]:
# yield empty bytes here, so linting can infer what this method does
yield bytes()

async def wrap_tts(self, text) -> AsyncGenerator[Frame, None]:
yield TTSStartFrame()
async for audio_chunk in self.run_tts(text):
yield AudioFrame(audio_chunk)
yield TTSEndFrame()
yield TextFrame(text)

async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
if isinstance(frame, EndFrame) or isinstance(frame, EndPipeFrame):
if self.current_sentence:
async for audio_chunk in self.run_tts(self.current_sentence):
yield AudioFrame(audio_chunk)
yield TextFrame(self.current_sentence)
async for frame in self.wrap_tts(self.current_sentence):
yield frame

if not isinstance(frame, TextFrame):
yield frame
@@ -68,12 +76,8 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
self.current_sentence = ""

if text:
async for audio_chunk in self.run_tts(text):
yield AudioFrame(audio_chunk)

# note we pass along the text frame *after* the audio, so the text
# frame is completed after the audio is processed.
yield TextFrame(text)
async for frame in self.wrap_tts(text):
yield frame


class ImageGenService(AIService):
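wrap_tts pins down the ordering for one TTS response: TTSStartFrame, the AudioFrames, TTSEndFrame, then the TextFrame, so a transport can aggregate audio without the TTS service's involvement. A sketch that prints that order; it assumes the surrounding class is named TTSService, and StubTTS is an illustrative stand-in for a real provider:

import asyncio

from dailyai.services.ai_services import TTSService


class StubTTS(TTSService):
    # Stand-in run_tts: yields one fake chunk instead of calling a provider.
    async def run_tts(self, text):
        yield b"\x00" * 3200


async def main():
    async for frame in StubTTS().wrap_tts("Hello there!"):
        print(type(frame).__name__)
    # TTSStartFrame, AudioFrame, TTSEndFrame, TextFrame

asyncio.run(main())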
55 changes: 49 additions & 6 deletions src/dailyai/services/websocket_transport_service.py
@@ -1,21 +1,58 @@
import asyncio
import time
import wave
from typing import AsyncGenerator, List
import websockets

from dailyai.pipeline.frames import AudioFrame, EndFrame, FrameFromProto
from dailyai.pipeline.frame_processor import FrameProcessor
from dailyai.pipeline.frames import AudioFrame, EndFrame, Frame, FrameFromProto, TTSEndFrame, TTSStartFrame, TextFrame
from dailyai.pipeline.pipeline import Pipeline
from dailyai.services.base_transport_service import BaseTransportService
import dailyai.pipeline.protobufs.frames_pb2 as frame_protos


class WebSocketFrameProcessor(FrameProcessor):
"""This FrameProcessor filters and mutates frames before they're sent over the websocket.
This is necessary to aggregate audio frames into sizes that are cleanly playable by the client."""

def __init__(self, audio_frame_size=16000, sendable_frames: List[Frame] | None = None):
super().__init__()
self._audio_frame_size = audio_frame_size
self._sendable_frames = sendable_frames or [
TextFrame, AudioFrame]
self._audio_buffer = bytes()
self._in_tts_audio = False

async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
print(f"processing frame {frame}")
if isinstance(frame, TTSStartFrame):
self._in_tts_audio = True
elif isinstance(frame, AudioFrame):
if self._in_tts_audio:
self._audio_buffer += frame.data
if len(self._audio_buffer) >= self._audio_frame_size:
yield AudioFrame(
self._audio_buffer[:self._audio_frame_size])
self._audio_buffer = self._audio_buffer[self._audio_frame_size:]
elif isinstance(frame, TTSEndFrame):
self._in_tts_audio = False
if self._audio_buffer:
yield AudioFrame(
self._audio_buffer)
self._audio_buffer = bytes()
elif type(frame) in self._sendable_frames:
yield frame


class WebsocketTransport(BaseTransportService):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._sample_width = kwargs.get("sample_width") or 2
self._n_channels = kwargs.get("n_channels") or 1
self._port = kwargs.get("port") or 8765
self._host = kwargs.get("host") or "localhost"
self._audio_frame_size = kwargs.get("audio_frame_size") or 16000
self._sendable_frames = kwargs.get("sendable_frames") or [
TextFrame, AudioFrame, TTSEndFrame, TTSStartFrame]

if self._camera_enabled:
raise ValueError(
@@ -35,6 +72,10 @@ async def run(self, pipeline: Pipeline, override_pipeline_source_queue=True):
if override_pipeline_source_queue:
pipeline.set_source(self.receive_queue)

pipeline.add_processor(WebSocketFrameProcessor(
audio_frame_size=self._audio_frame_size,
sendable_frames=self._sendable_frames))

async def timeout_task():
sleep_time = self._expiration - time.time()
await asyncio.sleep(sleep_time)
@@ -46,15 +87,17 @@ async def send_task():
if isinstance(frame, EndFrame):
self._stop_server_event.set()
break
if self._websocket:
await self._websocket.send(frame.to_proto().SerializeToString())
if self._websocket and frame:
print(f"sending frame {frame}")
proto = frame.to_proto().SerializeToString()
await self._websocket.send(proto)

async def start_server() -> None:
async with websockets.serve(
self._websocket_handler, self._host, self._port) as server:
print("Server started")
self._logger.debug("Websocket server started.")
await self._stop_server_event.wait()
print("Server stopped")
self._logger.debug("Websocket server stopped.")
await self.receive_queue.put(EndFrame())

await asyncio.gather(start_server(), timeout_task(), send_task(), pipeline.run_pipeline())
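With aggregation moved into the transport, sample.py above no longer needs its AudioChunker. A usage sketch; the 32000-byte override (one second of 16 kHz, 16-bit mono audio, per sample.py's wave settings) is illustrative, the default being 16000:

import asyncio

from dailyai.pipeline.pipeline import Pipeline
from dailyai.services.websocket_transport_service import WebsocketTransport
from dailyai.services.whisper_ai_services import WhisperSTTService


async def main():
    transport = WebsocketTransport(
        mic_enabled=True,
        speaker_enabled=True,
        audio_frame_size=32000,  # aggregate outgoing TTS audio into 1 s chunks
    )
    pipeline = Pipeline([WhisperSTTService()])
    await transport.run(pipeline)

asyncio.run(main())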
