From 35a2b3a040a005fa2b1356c8204cf5f43aff8899 Mon Sep 17 00:00:00 2001
From: Moishe Lettvin
Date: Wed, 20 Mar 2024 12:12:26 -0400
Subject: [PATCH] getting started on websocket transport

---
 .github/workflows/lint.yaml                   |  18 ++-
 .../foundational/websocket-server/index.html  | 142 ++++++++++++++++++
 .../foundational/websocket-server/sample.py   |  70 +++++++++
 pyproject.toml                                |   3 +-
 src/dailyai/pipeline/frames.proto             |  25 +++
 src/dailyai/pipeline/frames.py                | 132 +++++++++++++++-
 src/dailyai/pipeline/protobufs/frames_pb2.py  |  30 ++++
 src/dailyai/services/local_stt_service.py     |   1 +
 .../services/websocket_transport_service.py   |  87 +++++++++++
 src/dailyai/services/whisper_ai_services.py   |   1 +
 10 files changed, 502 insertions(+), 7 deletions(-)
 create mode 100644 examples/foundational/websocket-server/index.html
 create mode 100644 examples/foundational/websocket-server/sample.py
 create mode 100644 src/dailyai/pipeline/frames.proto
 create mode 100644 src/dailyai/pipeline/protobufs/frames_pb2.py
 create mode 100644 src/dailyai/services/websocket_transport_service.py

diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index e678623cb..652df99d2 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -22,11 +22,23 @@ jobs:
     steps:
       - name: Checkout repo
        uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Setup virtual environment
+        run: |
+          python -m venv .venv
+      - name: Install basic Python dependencies
+        run: |
+          source .venv/bin/activate
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
       - name: autopep8
         id: autopep8
-        uses: peter-evans/autopep8@v2
-        with:
-          args: --exit-code -r -d -a -a src/
+        run: |
+          source .venv/bin/activate
+          autopep8 --exit-code -r -d --exclude "*_pb2.py" -a -a src/
       - name: Fail if autopep8 requires changes
         if: steps.autopep8.outputs.exit-code == 2
         run: exit 1
diff --git a/examples/foundational/websocket-server/index.html b/examples/foundational/websocket-server/index.html
new file mode 100644
index 000000000..51b409573
--- /dev/null
+++ b/examples/foundational/websocket-server/index.html
@@ -0,0 +1,142 @@
[The 142 lines of HTML markup for this page are not recoverable from this copy
of the patch; only the page's text content survives. The page title and its
top-level heading are both "WebSocket Audio Stream".]
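The HTML page above is the browser-side peer of the websocket transport added
in this patch. As a rough, hedged stand-in (this sketch is not part of the
patch: it assumes the sample.py server below is running on ws://localhost:8765
and that the generated frames_pb2 module is importable), a minimal Python
client for the same wire protocol looks like this:

    import asyncio

    import websockets

    import dailyai.pipeline.protobufs.frames_pb2 as frame_protos


    async def receive_frames():
        # Connect to the WebsocketTransport server started by sample.py.
        async with websockets.connect("ws://localhost:8765") as ws:
            async for message in ws:
                # Every message on the wire is a serialized top-level Frame proto.
                frame = frame_protos.Frame()
                frame.ParseFromString(message)
                which = frame.WhichOneof("frame")
                if which == "text":
                    print("text:", frame.text.text)
                elif which == "audio":
                    # Raw PCM bytes; sample.py writes 16-bit mono at 16 kHz.
                    with open("received.pcm", "ab") as f:
                        f.write(frame.audio.audio)

    asyncio.run(receive_frames())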
diff --git a/examples/foundational/websocket-server/sample.py b/examples/foundational/websocket-server/sample.py
new file mode 100644
index 000000000..a8d10d996
--- /dev/null
+++ b/examples/foundational/websocket-server/sample.py
@@ -0,0 +1,70 @@
+import asyncio
+import aiohttp
+import logging
+import os
+import wave
+
+from dailyai.pipeline.frame_processor import FrameProcessor
+from dailyai.pipeline.frames import AudioFrame, EndFrame, EndPipeFrame, TextFrame
+from dailyai.pipeline.pipeline import Pipeline
+from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
+from dailyai.services.websocket_transport_service import WebsocketTransport
+from dailyai.services.whisper_ai_services import WhisperSTTService
+
+logging.basicConfig(format=f"%(levelno)s %(asctime)s %(message)s")
+logger = logging.getLogger("dailyai")
+logger.setLevel(logging.DEBUG)
+
+
+async def main():
+    async with aiohttp.ClientSession() as session:
+        transport = WebsocketTransport(
+            mic_enabled=True, speaker_enabled=True)
+
+        tts = ElevenLabsTTSService(
+            aiohttp_session=session,
+            api_key=os.getenv("ELEVENLABS_API_KEY"),
+            voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
+        )
+
+        class WriteToWav(FrameProcessor):
+            async def process_frame(self, frame):
+                if isinstance(frame, AudioFrame):
+                    with wave.open("output.wav", "wb") as f:
+                        f.setnchannels(1)
+                        f.setsampwidth(2)
+                        f.setframerate(16000)
+                        f.writeframes(frame.data)
+                yield frame
+
+        class AudioChunker(FrameProcessor):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+                self._buffer = bytes()
+
+            async def process_frame(self, frame):
+                if isinstance(frame, AudioFrame):
+                    self._buffer += frame.data
+                    if len(self._buffer) >= 1600:
+                        yield AudioFrame(self._buffer[:1600])
+                        self._buffer = self._buffer[1600:]
+                elif isinstance(frame, EndFrame) or isinstance(frame, EndPipeFrame):
+                    if self._buffer:
+                        yield AudioFrame(self._buffer)
+                        self._buffer = bytes()
+                else:
+                    yield frame
+
+        # pipeline = Pipeline([WriteToWav(), WhisperSTTService(), tts])
+        pipeline = Pipeline([tts, AudioChunker()])
+
+        @transport.on_connection
+        async def queue_frame():
+            await pipeline.queue_frames([TextFrame("Hello there!")])
+
+        await asyncio.gather(transport.run(pipeline))
+        del tts
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
index e1eac802b..a559076f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,8 @@ dependencies = [
     "torch",
     "torchaudio",
     "pyaudio",
-    "typing-extensions"
+    "typing-extensions",
+    "websockets"
 ]
 
 [project.urls]
diff --git a/src/dailyai/pipeline/frames.proto b/src/dailyai/pipeline/frames.proto
new file mode 100644
index 000000000..7ecea6d25
--- /dev/null
+++ b/src/dailyai/pipeline/frames.proto
@@ -0,0 +1,25 @@
+syntax = "proto3";
+
+package dailyai_proto;
+
+message TextFrame {
+  string text = 1;
+}
+
+message AudioFrame {
+  bytes audio = 1;
+}
+
+message TranscriptionFrame {
+  string text = 1;
+  string participant_id = 2;
+  string timestamp = 3;
+}
+
+message Frame {
+  oneof frame {
+    TextFrame text = 1;
+    AudioFrame audio = 2;
+    TranscriptionFrame transcription = 3;
+  }
+}
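The frames_pb2.py module checked in further below is the protoc output for
this schema. As a hedged sketch of how to regenerate it (not part of the
patch: it assumes `pip install grpcio-tools` and running from the repo root;
the checked-in file was generated with Protobuf Python 4.25.3, so the exact
output may differ slightly):

    # Regenerate frames_pb2.py from frames.proto.
    from grpc_tools import protoc

    protoc.main([
        "protoc",  # argv[0] placeholder expected by protoc.main
        "--proto_path=src/dailyai/pipeline",
        "--python_out=src/dailyai/pipeline/protobufs",
        "frames.proto",
    ])

The oneof envelope is what lets heterogeneous frames (text, audio,
transcription) share a single wire format: every message sent over the socket
is a top-level Frame, and receivers dispatch on WhichOneof("frame").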
f"{self.__class__.__name__}" + def to_proto(self): + raise NotImplementedError + + @staticmethod + def from_proto(proto): + raise NotImplementedError + class ControlFrame(Frame): # Control frames should contain no instance data, so @@ -61,12 +69,36 @@ class LLMResponseEndFrame(ControlFrame): @dataclass() class AudioFrame(Frame): """A chunk of audio. Will be played by the transport if the transport's mic - has been enabled.""" + has been enabled. + + >>> str(AudioFrame(data=b'1234567890')) + 'AudioFrame, size: 10 B' + + >>> AudioFrame.from_proto(AudioFrame(data=b'1234567890').to_proto()) + AudioFrame(data=b'1234567890') + + The to_proto() function will always return a top-level Frame protobuf, so + it can be sent across the wire with other frames of various types. + + >>> type(AudioFrame(data=b'1234567890').to_proto()) + + """ data: bytes def __str__(self): return f"{self.__class__.__name__}, size: {len(self.data)} B" + def to_proto(self) -> frame_protos.Frame: + frame = frame_protos.Frame() + frame.audio.audio = self.data + return frame + + @staticmethod + def from_proto(proto: frame_protos.Frame): + if proto.WhichOneof("frame") != "audio": + raise ValueError("Proto does not contain an audio frame") + return AudioFrame(data=proto.audio.audio) + @dataclass() class ImageFrame(Frame): @@ -93,20 +125,71 @@ def __str__(self): @dataclass() class TextFrame(Frame): """A chunk of text. Emitted by LLM services, consumed by TTS services, can - be used to send text through pipelines.""" + be used to send text through pipelines. + + >>> str(TextFrame.from_proto(TextFrame(text='hello world').to_proto())) + 'TextFrame: "hello world"' + + The to_proto() function will always return a top-level Frame protobuf, so + it can be sent across the wire with other frames of various types. + + >>> type(TextFrame(text='hello world').to_proto()) + + """ text: str def __str__(self): return f'{self.__class__.__name__}: "{self.text}"' + def to_proto(self) -> frame_protos.Frame: + proto_frame = frame_protos.Frame() + proto_frame.text.text = self.text + return proto_frame + + @staticmethod + def from_proto(proto: frame_protos.TextFrame): + return TextFrame(text=proto.text.text) + @dataclass() class TranscriptionQueueFrame(TextFrame): """A text frame with transcription-specific data. Will be placed in the - transport's receive queue when a participant speaks.""" + transport's receive queue when a participant speaks. + + >>> transcription_frame = TranscriptionQueueFrame(text="Hello there!", participantId="123", timestamp="2021-01-01") + >>> transcription_frame + TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01') + + >>> TranscriptionQueueFrame.from_proto(transcription_frame.to_proto()) + TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01') + + The to_proto() function will always return a top-level Frame protobuf, so + it can be sent across the wire with other frames of various types. 
+ + >>> type(transcription_frame.to_proto()) + + """ participantId: str timestamp: str + def __str__(self): + return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}" + + def to_proto(self) -> frame_protos.Frame: + frame = frame_protos.Frame() + frame.transcription.text = self.text + frame.transcription.participant_id = self.participantId + frame.transcription.timestamp = self.timestamp + return frame + + @staticmethod + def from_proto(proto: frame_protos.Frame): + return TranscriptionQueueFrame( + text=proto.transcription.text, + participantId=proto.transcription.participant_id, + timestamp=proto.transcription.timestamp + ) + @dataclass() class LLMMessagesQueueFrame(Frame): @@ -179,3 +262,46 @@ class LLMFunctionCallFrame(Frame): """Emitted when the LLM has received an entire function call completion.""" function_name: str arguments: str + + +def FrameFromProto(proto: frame_protos.Frame) -> Frame: + """Returns a Frame object from a Frame protobuf. Used to convert frames + passed over the wire as protobufs to Frame objects used in pipelines + and frame processors. + + >>> FrameFromProto(AudioFrame(data=b'1234567890').to_proto()) + AudioFrame(data=b'1234567890') + + >>> FrameFromProto(TextFrame(text='hello world').to_proto()) + TextFrame(text='hello world') + + >>> FrameFromProto(TranscriptionQueueFrame(text="Hello there!", participantId="123", timestamp="2021-01-01").to_proto()) + TranscriptionQueueFrame(text='Hello there!', participantId='123', timestamp='2021-01-01') + """ + if proto.WhichOneof("frame") == "audio": + return AudioFrame.from_proto(proto) + elif proto.WhichOneof("frame") == "text": + return TextFrame.from_proto(proto) + elif proto.WhichOneof("frame") == "transcription": + return TranscriptionQueueFrame.from_proto(proto) + else: + raise ValueError( + "Proto does not contain a valid frame. You may need to add a new case to FrameFromProto.") + + +if __name__ == "__main__": + audio_frame = AudioFrame(data=b'1234567890') + print(audio_frame) + print(audio_frame.to_proto().SerializeToString()) + print(AudioFrame.from_proto(audio_frame.to_proto())) + + text_frame = TextFrame(text="Hello there!") + print(text_frame) + print(text_frame.to_proto().SerializeToString()) + print(TextFrame.from_proto(text_frame.to_proto())) + + transcripton_frame = TranscriptionQueueFrame( + text="Hello there!", participantId="123", timestamp="2021-01-01") + print(transcripton_frame) + print(transcripton_frame.to_proto().SerializeToString()) + print(TranscriptionQueueFrame.from_proto(transcripton_frame.to_proto())) diff --git a/src/dailyai/pipeline/protobufs/frames_pb2.py b/src/dailyai/pipeline/protobufs/frames_pb2.py new file mode 100644 index 000000000..b923f8fe3 --- /dev/null +++ b/src/dailyai/pipeline/protobufs/frames_pb2.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
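The websocket transport below still has its receive path commented out. As a
hedged sketch of how the pieces are meant to compose (using only the API this
patch adds plus standard protobuf calls), bytes pulled off a websocket can be
turned back into pipeline frames like this:

    import dailyai.pipeline.protobufs.frames_pb2 as frame_protos
    from dailyai.pipeline.frames import FrameFromProto, TextFrame

    # Sender side: wrap a pipeline frame in the Frame envelope and serialize it.
    wire_bytes = TextFrame(text="hello world").to_proto().SerializeToString()

    # Receiver side: ParseFromString is an instance method, so parse into a
    # freshly constructed envelope, then dispatch on the oneof to rebuild the
    # pipeline frame.
    envelope = frame_protos.Frame()
    envelope.ParseFromString(wire_bytes)
    frame = FrameFromProto(envelope)
    assert frame == TextFrame(text="hello world")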
diff --git a/src/dailyai/pipeline/protobufs/frames_pb2.py b/src/dailyai/pipeline/protobufs/frames_pb2.py
new file mode 100644
index 000000000..b923f8fe3
--- /dev/null
+++ b/src/dailyai/pipeline/protobufs/frames_pb2.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler.  DO NOT EDIT!
+# source: frames.proto
+# Protobuf Python Version: 4.25.3
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\rdailyai_proto\"\x19\n\tTextFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\"\x1b\n\nAudioFrame\x12\r\n\x05\x61udio\x18\x01 \x01(\x0c\"M\n\x12TranscriptionFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x16\n\x0eparticipant_id\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"\xa2\x01\n\x05\x46rame\x12(\n\x04text\x18\x01 \x01(\x0b\x32\x18.dailyai_proto.TextFrameH\x00\x12*\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x19.dailyai_proto.AudioFrameH\x00\x12:\n\rtranscription\x18\x03 \x01(\x0b\x32!.dailyai_proto.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+    DESCRIPTOR._options = None
+    _globals['_TEXTFRAME']._serialized_start = 31
+    _globals['_TEXTFRAME']._serialized_end = 56
+    _globals['_AUDIOFRAME']._serialized_start = 58
+    _globals['_AUDIOFRAME']._serialized_end = 85
+    _globals['_TRANSCRIPTIONFRAME']._serialized_start = 87
+    _globals['_TRANSCRIPTIONFRAME']._serialized_end = 164
+    _globals['_FRAME']._serialized_start = 167
+    _globals['_FRAME']._serialized_end = 329
+# @@protoc_insertion_point(module_scope)
diff --git a/src/dailyai/services/local_stt_service.py b/src/dailyai/services/local_stt_service.py
index 727bafbc6..98d37dc10 100644
--- a/src/dailyai/services/local_stt_service.py
+++ b/src/dailyai/services/local_stt_service.py
@@ -42,6 +42,7 @@ def _new_wave(self):
     async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
         """Processes a frame of audio data, either buffering or transcribing it."""
         if not isinstance(frame, AudioFrame):
+            yield frame
             return
 
         data = frame.data
diff --git a/src/dailyai/services/websocket_transport_service.py b/src/dailyai/services/websocket_transport_service.py
new file mode 100644
index 000000000..53bd40958
--- /dev/null
+++ b/src/dailyai/services/websocket_transport_service.py
@@ -0,0 +1,87 @@
+import asyncio
+import time
+import wave
+import websockets
+
+from dailyai.pipeline.frames import AudioFrame, EndFrame, FrameFromProto
+from dailyai.pipeline.pipeline import Pipeline
+from dailyai.services.base_transport_service import BaseTransportService
+import dailyai.pipeline.protobufs.frames_pb2 as frame_protos
+
+
+class WebsocketTransport(BaseTransportService):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._sample_width = kwargs.get("sample_width") or 2
+        self._n_channels = kwargs.get("n_channels") or 1
+        self._port = kwargs.get("port") or 8765
+        self._host = kwargs.get("host") or "localhost"
+
+        if self._camera_enabled:
+            raise ValueError(
+                "Camera is not supported in WebsocketTransportService")
+
+        if self._speaker_enabled:
+            self._speaker_buffer_pending = bytearray()
+
+        self._server: websockets.WebSocketServer | None = None
+        self._websocket: websockets.WebSocketServerProtocol | None = None
+
+        self._connection_handlers = []
+
+    async def run(self, pipeline: Pipeline, override_pipeline_source_queue=True):
+        self._stop_server_event = asyncio.Event()
+        pipeline.set_sink(self.send_queue)
+        if override_pipeline_source_queue:
+            pipeline.set_source(self.receive_queue)
+
+        async def timeout_task():
+            sleep_time = self._expiration - time.time()
+            await asyncio.sleep(sleep_time)
+            self._stop_server_event.set()
+
+        async def send_task():
+            while not self._stop_server_event.is_set():
+                frame = await self.send_queue.get()
+                if isinstance(frame, EndFrame):
+                    self._stop_server_event.set()
+                    break
+                if self._websocket:
+                    await self._websocket.send(frame.to_proto().SerializeToString())
+
+        async def start_server() -> None:
+            async with websockets.serve(
+                    self._websocket_handler, self._host, self._port) as server:
+                print("Server started")
+                await self._stop_server_event.wait()
+                print("Server stopped")
+                await self.receive_queue.put(EndFrame)
+
+        await asyncio.gather(start_server(), timeout_task(), send_task(), pipeline.run_pipeline())
+
+    def on_connection(self, handler):
+        self._connection_handlers.append(handler)
+
+    async def _websocket_handler(self, websocket: websockets.WebSocketServerProtocol, path):
+        if self._websocket:
+            await self._websocket.close()
+            self._logger.warning(
+                "Got another websocket connection; closing first.")
+
+        for handler in self._connection_handlers:
+            await handler()
+
+        self._websocket = websocket
+        async for message in websocket:
+            """
+            generic_frame = frame_protos.Frame.ParseFromString(message)
+            frame = FrameFromProto(generic_frame)
+            if isinstance(frame, AudioFrame):
+                with wave.open("output.wav", "wb") as f:
+                    f.setnchannels(1)
+                    f.setsampwidth(2)
+                    f.setframerate(16000)
+                    f.writeframes(frame.data)
+            await self.receive_queue.put(frame)
+            """
+            pass
diff --git a/src/dailyai/services/whisper_ai_services.py b/src/dailyai/services/whisper_ai_services.py
index cc657e6cf..2ed1c3563 100644
--- a/src/dailyai/services/whisper_ai_services.py
+++ b/src/dailyai/services/whisper_ai_services.py
@@ -52,4 +52,5 @@ async def run_stt(self, audio: BinaryIO) -> str:
         res: str = ""
         for segment in segments:
             res += f"{segment.text} "
+            print("Transcription: ", segment.text)
         return res
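One quick way to sanity-check the serialization this patch adds (not part of
the patch itself; it assumes the package is importable, e.g. installed with
`pip install -e .`) is to run the doctests embedded in the new
to_proto()/from_proto() docstrings:

    # Hypothetical smoke test: exercise the doctests added to frames.py
    # (AudioFrame, TextFrame, TranscriptionQueueFrame, FrameFromProto).
    import doctest

    import dailyai.pipeline.frames as frames

    results = doctest.testmod(frames)
    print(f"ran {results.attempted} doctest examples, {results.failed} failed")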