From 0b6f3ab044b43b6042121374594806d539320a3d Mon Sep 17 00:00:00 2001 From: Moishe Lettvin Date: Sun, 24 Mar 2024 11:23:13 -0400 Subject: [PATCH] add audio frame aggregation to websocket transport --- .../foundational/websocket-server/sample.py | 74 +++++-------------- src/dailyai/pipeline/frames.py | 42 +++++------ src/dailyai/pipeline/pipeline.py | 9 ++- src/dailyai/services/ai_services.py | 22 +++--- .../services/websocket_transport_service.py | 55 ++++++++++++-- 5 files changed, 108 insertions(+), 94 deletions(-) diff --git a/examples/foundational/websocket-server/sample.py b/examples/foundational/websocket-server/sample.py index 70ad30b1d..24a93182d 100644 --- a/examples/foundational/websocket-server/sample.py +++ b/examples/foundational/websocket-server/sample.py @@ -3,85 +3,49 @@ import logging import os import wave - from dailyai.pipeline.frame_processor import FrameProcessor -from dailyai.pipeline.frames import AudioFrame, EndFrame, EndPipeFrame, TextFrame, TranscriptionQueueFrame +from dailyai.pipeline.frames import AudioFrame, EndFrame, EndPipeFrame, TTSEndFrame, TextFrame, TranscriptionQueueFrame from dailyai.pipeline.pipeline import Pipeline from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService from dailyai.services.websocket_transport_service import WebsocketTransport from dailyai.services.whisper_ai_services import WhisperSTTService -logging.basicConfig(format=f"%(levelno)s %(asctime)s %(message)s") +logging.basicConfig(format="%(levelno)s %(asctime)s %(message)s") logger = logging.getLogger("dailyai") logger.setLevel(logging.DEBUG) +class WhisperTranscriber(FrameProcessor): + async def process_frame(self, frame): + if isinstance(frame, TranscriptionQueueFrame): + print(f"Transcribed: {frame.text}") + else: + yield frame + + async def main(): async with aiohttp.ClientSession() as session: transport = WebsocketTransport( - mic_enabled=True, speaker_enabled=True, duration_minutes=120) - + mic_enabled=True, + speaker_enabled=True, + ) tts = ElevenLabsTTSService( aiohttp_session=session, api_key=os.getenv("ELEVENLABS_API_KEY"), voice_id=os.getenv("ELEVENLABS_VOICE_ID"), ) - class AudioWriter(FrameProcessor): - SIZE = 160000 - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._buffer = bytes() - self._counter = 0 - - async def process_frame(self, frame): - if isinstance(frame, AudioFrame): - self._buffer += frame.data - if len(self._buffer) >= AudioWriter.SIZE: - with wave.open(f"output-{self._counter}.wav", "wb") as f: - f.setnchannels(1) - f.setsampwidth(2) - f.setframerate(16000) - f.writeframes(self._buffer) - self._counter += 1 - self._buffer = self._buffer[AudioWriter.SIZE:] - yield frame - else: - yield frame - - class AudioChunker(FrameProcessor): - def __init__(self): - super().__init__() - self._buffer = bytes() - - async def process_frame(self, frame): - if isinstance(frame, AudioFrame): - self._buffer += frame.data - while len(self._buffer) >= 16000: - yield AudioFrame(self._buffer[:16000]) - self._buffer = self._buffer[16000:] - else: - yield frame - - class WhisperTranscriber(FrameProcessor): - async def process_frame(self, frame): - if isinstance(frame, TranscriptionQueueFrame): - print(f"Transcribed: {frame.text}") - else: - yield frame - - # pipeline = Pipeline([WriteToWav(), WhisperSTTService(), tts]) - pipeline = Pipeline( - [AudioWriter(), WhisperSTTService(), WhisperTranscriber(), tts, AudioChunker()]) + pipeline = Pipeline([ + WhisperSTTService(), + WhisperTranscriber(), + tts, + ]) @transport.on_connection async def queue_frame(): await pipeline.queue_frames([TextFrame("Hello there!")]) - await asyncio.gather(transport.run(pipeline)) - del tts - + await transport.run(pipeline) if __name__ == "__main__": asyncio.run(main()) diff --git a/src/dailyai/pipeline/frames.py b/src/dailyai/pipeline/frames.py index 511025d9b..b2721a05c 100644 --- a/src/dailyai/pipeline/frames.py +++ b/src/dailyai/pipeline/frames.py @@ -10,10 +10,18 @@ def __str__(self): return f"{self.__class__.__name__}" def to_proto(self): + """Converts a Frame object to a Frame protobuf. Used to send frames via the + websocket transport, or other clients that send Frames to their clients. + If you see this exception, you may need to implement a to_proto method on your + Frame subclass.""" raise NotImplementedError @staticmethod def from_proto(proto): + """Converts a Frame protobuf to a Frame object. Used to receive frames via the + websocket transport, or other clients that receive Frames from their clients. + If you see this exception, you may need to implement a from_proto method on your + Frame subclass, and add your class to FrameFromProto below.""" raise NotImplementedError @@ -191,6 +199,19 @@ def from_proto(proto: frame_protos.Frame): ) +class TTSStartFrame(ControlFrame): + """Used to indicate the beginning of a TTS response. Following AudioFrames + are part of the TTS response until an TTEndFrame. These frames can be used + for aggregating audio frames in a transport to optimize the size of frames + sent to the session, without needing to control this in the TTS service.""" + pass + + +class TTSEndFrame(ControlFrame): + """Indicates the end of a TTS response.""" + pass + + @dataclass() class LLMMessagesQueueFrame(Frame): """A frame containing a list of LLM messages. Used to signal that an LLM @@ -287,24 +308,3 @@ def FrameFromProto(proto: frame_protos.Frame) -> Frame: else: raise ValueError( "Proto does not contain a valid frame. You may need to add a new case to FrameFromProto.") - - -if __name__ == "__main__": - audio_frame = AudioFrame(data=b'1234567890') - print(audio_frame) - print(audio_frame.to_proto().SerializeToString()) - print(AudioFrame.from_proto(audio_frame.to_proto())) - - text_frame = TextFrame(text="Hello there!") - print(text_frame) - print(text_frame.to_proto().SerializeToString()) - serialized = text_frame.to_proto().SerializeToString() - print(type(serialized)) - print(frame_protos.Frame.FromString(serialized)) - print(TextFrame.from_proto(text_frame.to_proto())) - - transcripton_frame = TranscriptionQueueFrame( - text="Hello there!", participantId="123", timestamp="2021-01-01") - print(transcripton_frame) - print(transcripton_frame.to_proto().SerializeToString()) - print(TranscriptionQueueFrame.from_proto(transcripton_frame.to_proto())) diff --git a/src/dailyai/pipeline/pipeline.py b/src/dailyai/pipeline/pipeline.py index b055f3d81..fbb4488db 100644 --- a/src/dailyai/pipeline/pipeline.py +++ b/src/dailyai/pipeline/pipeline.py @@ -24,7 +24,7 @@ def __init__( queues. If this pipeline is run by a transport, its sink and source queues will be overridden. """ - self.processors: List[FrameProcessor] = processors + self._processors: List[FrameProcessor] = processors self.source: asyncio.Queue[Frame] = source or asyncio.Queue() self.sink: asyncio.Queue[Frame] = sink or asyncio.Queue() @@ -40,6 +40,9 @@ def set_sink(self, sink: asyncio.Queue[Frame]): has processed a frame, its output will be placed on this queue.""" self.sink = sink + def add_processor(self, processor: FrameProcessor): + self._processors.append(processor) + async def get_next_source_frame(self) -> AsyncGenerator[Frame, None]: """Convenience function to get the next frame from the source queue. This lets us consistently have an AsyncGenerator yield frames, from either the @@ -80,7 +83,7 @@ async def run_pipeline(self): while True: initial_frame = await self.source.get() async for frame in self._run_pipeline_recursively( - initial_frame, self.processors + initial_frame, self._processors ): await self.sink.put(frame) @@ -91,7 +94,7 @@ async def run_pipeline(self): except asyncio.CancelledError: # this means there's been an interruption, do any cleanup necessary # here. - for processor in self.processors: + for processor in self._processors: await processor.interrupted() pass diff --git a/src/dailyai/services/ai_services.py b/src/dailyai/services/ai_services.py index 5db115e9a..b91e2e7e4 100644 --- a/src/dailyai/services/ai_services.py +++ b/src/dailyai/services/ai_services.py @@ -10,6 +10,8 @@ EndPipeFrame, ImageFrame, Frame, + TTSEndFrame, + TTSStartFrame, TextFrame, TranscriptionQueueFrame, ) @@ -47,12 +49,18 @@ async def run_tts(self, text) -> AsyncGenerator[bytes, None]: # yield empty bytes here, so linting can infer what this method does yield bytes() + async def wrap_tts(self, text) -> AsyncGenerator[Frame, None]: + yield TTSStartFrame() + async for audio_chunk in self.run_tts(text): + yield AudioFrame(audio_chunk) + yield TTSEndFrame() + yield TextFrame(text) + async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: if isinstance(frame, EndFrame) or isinstance(frame, EndPipeFrame): if self.current_sentence: - async for audio_chunk in self.run_tts(self.current_sentence): - yield AudioFrame(audio_chunk) - yield TextFrame(self.current_sentence) + async for frame in self.wrap_tts(self.current_sentence): + yield frame if not isinstance(frame, TextFrame): yield frame @@ -68,12 +76,8 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: self.current_sentence = "" if text: - async for audio_chunk in self.run_tts(text): - yield AudioFrame(audio_chunk) - - # note we pass along the text frame *after* the audio, so the text - # frame is completed after the audio is processed. - yield TextFrame(text) + async for frame in self.wrap_tts(text): + yield frame class ImageGenService(AIService): diff --git a/src/dailyai/services/websocket_transport_service.py b/src/dailyai/services/websocket_transport_service.py index b4273ce0b..131aef796 100644 --- a/src/dailyai/services/websocket_transport_service.py +++ b/src/dailyai/services/websocket_transport_service.py @@ -1,14 +1,48 @@ import asyncio import time -import wave +from typing import AsyncGenerator, List import websockets -from dailyai.pipeline.frames import AudioFrame, EndFrame, FrameFromProto +from dailyai.pipeline.frame_processor import FrameProcessor +from dailyai.pipeline.frames import AudioFrame, EndFrame, Frame, FrameFromProto, TTSEndFrame, TTSStartFrame, TextFrame from dailyai.pipeline.pipeline import Pipeline from dailyai.services.base_transport_service import BaseTransportService import dailyai.pipeline.protobufs.frames_pb2 as frame_protos +class WebSocketFrameProcessor(FrameProcessor): + """This FrameProcessor filters and mutates frames before they're sent over the websocket. + This is necessary to aggregate audio frames into sizes that are cleanly playable by the client """ + + def __init__(self, audio_frame_size=16000, sendable_frames: List[Frame] | None = None): + super().__init__() + self._audio_frame_size = audio_frame_size + self._sendable_frames = sendable_frames or [ + TextFrame, AudioFrame] + self._audio_buffer = bytes() + self._in_tts_audio = False + + async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: + print(f"processing frame {frame}") + if isinstance(frame, TTSStartFrame): + self._in_tts_audio = True + elif isinstance(frame, AudioFrame): + if self._in_tts_audio: + self._audio_buffer += frame.data + if len(self._audio_buffer) >= self._audio_frame_size: + yield AudioFrame( + self._audio_buffer[:self._audio_frame_size]) + self._audio_buffer = self._audio_buffer[self._audio_frame_size:] + elif isinstance(frame, TTSEndFrame): + self._in_tts_audio = False + if self._audio_buffer: + yield AudioFrame( + self._audio_buffer) + self._audio_buffer = bytes() + elif type(frame) in self._sendable_frames: + yield frame + + class WebsocketTransport(BaseTransportService): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -16,6 +50,9 @@ def __init__(self, **kwargs): self._n_channels = kwargs.get("n_channels") or 1 self._port = kwargs.get("port") or 8765 self._host = kwargs.get("host") or "localhost" + self._audio_frame_size = kwargs.get("audio_frame_size") or 16000 + self._sendable_frames = kwargs.get("sendable_frames") or [ + TextFrame, AudioFrame, TTSEndFrame, TTSStartFrame] if self._camera_enabled: raise ValueError( @@ -35,6 +72,10 @@ async def run(self, pipeline: Pipeline, override_pipeline_source_queue=True): if override_pipeline_source_queue: pipeline.set_source(self.receive_queue) + pipeline.add_processor(WebSocketFrameProcessor( + audio_frame_size=self._audio_frame_size, + sendable_frames=self._sendable_frames)) + async def timeout_task(): sleep_time = self._expiration - time.time() await asyncio.sleep(sleep_time) @@ -46,15 +87,17 @@ async def send_task(): if isinstance(frame, EndFrame): self._stop_server_event.set() break - if self._websocket: - await self._websocket.send(frame.to_proto().SerializeToString()) + if self._websocket and frame: + print(f"sending frame {frame}") + proto = frame.to_proto().SerializeToString() + await self._websocket.send(proto) async def start_server() -> None: async with websockets.serve( self._websocket_handler, self._host, self._port) as server: - print("Server started") + self._logger.debug("Websocket server started.") await self._stop_server_event.wait() - print("Server stopped") + self._logger.debug("Websocket server stopped.") await self.receive_queue.put(EndFrame()) await asyncio.gather(start_server(), timeout_task(), send_task(), pipeline.run_pipeline())