diff --git a/CHANGELOG.md b/CHANGELOG.md index 0686d95b9..432eaf9bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added WebsocketServerTransport. This will create a websocket server and + read messages coming from a client. The messages are serialized/deserialized + with protobufs. See `examples/websocket-server` for a detailed example. + - Added function calling (LLMService.register_function()). This will allow the LLM to call functions you have registered when needed. For example, if you register a function to get the weather in Los Angeles and ask the LLM about @@ -24,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed an issue where `camera_out_enabled` would cause high CPU usage if no image was provided. +### Performance + +- Removed unnecessary audio input tasks. ## [0.0.24] - 2024-05-29 diff --git a/LICENSE b/LICENSE index b60f5327e..cd6220df2 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 2-Clause License -Copyright (c) 2024, Kwindla Hultman Kramer +Copyright (c) 2024, Daily Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/dev-requirements.txt b/dev-requirements.txt index 9e0d93cbe..2d7da9ca6 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,6 @@ autopep8~=2.1.0 build~=1.2.1 +grpcio-tools~=1.62.2 pip-tools~=7.4.1 pytest~=8.2.0 setuptools~=69.5.1 diff --git a/examples/foundational/02-llm-say-one-thing.py b/examples/foundational/02-llm-say-one-thing.py index 7e12263dc..20756dcb6 100644 --- a/examples/foundational/02-llm-say-one-thing.py +++ b/examples/foundational/02-llm-say-one-thing.py @@ -44,7 +44,7 @@ async def main(room_url): llm = OpenAILLMService( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4-turbo-preview") + model="gpt-4o") messages = [ { diff --git a/examples/foundational/05-sync-speech-and-image.py b/examples/foundational/05-sync-speech-and-image.py index 60dd50d07..f057c847c 100644 --- a/examples/foundational/05-sync-speech-and-image.py +++ b/examples/foundational/05-sync-speech-and-image.py @@ -93,7 +93,7 @@ async def main(room_url): llm = OpenAILLMService( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4-turbo-preview") + model="gpt-4o") imagegen = FalImageGenService( params=FalImageGenService.InputParams( diff --git a/examples/foundational/05a-local-sync-speech-and-image.py b/examples/foundational/05a-local-sync-speech-and-image.py index bfbd453e2..d476754fb 100644 --- a/examples/foundational/05a-local-sync-speech-and-image.py +++ b/examples/foundational/05a-local-sync-speech-and-image.py @@ -76,7 +76,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): llm = OpenAILLMService( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4-turbo-preview") + model="gpt-4o") tts = ElevenLabsTTSService( aiohttp_session=session, diff --git a/examples/foundational/06a-image-sync.py b/examples/foundational/06a-image-sync.py index 3ec2752b4..2f5528ee4 100644 --- a/examples/foundational/06a-image-sync.py +++ b/examples/foundational/06a-image-sync.py @@ -81,7 +81,7 @@ async def main(room_url: str, token): llm = OpenAILLMService( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4-turbo-preview") + model="gpt-4o") messages = [ { diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py index ce37344f0..9ed146774 100644
--- a/examples/foundational/07-interruptible.py +++ b/examples/foundational/07-interruptible.py @@ -53,7 +53,7 @@ async def main(room_url: str, token): llm = OpenAILLMService( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4-turbo-preview") + model="gpt-4o") messages = [ { diff --git a/examples/foundational/07c-interruptible-deepgram.py b/examples/foundational/07c-interruptible-deepgram.py index 27245b02b..818c8bc93 100644 --- a/examples/foundational/07c-interruptible-deepgram.py +++ b/examples/foundational/07c-interruptible-deepgram.py @@ -53,7 +53,7 @@ async def main(room_url: str, token): llm = OpenAILLMService( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4-turbo-preview") + model="gpt-4o") messages = [ { diff --git a/examples/foundational/11-sound-effects.py b/examples/foundational/11-sound-effects.py index 1ca568bf0..2a3e8effc 100644 --- a/examples/foundational/11-sound-effects.py +++ b/examples/foundational/11-sound-effects.py @@ -95,7 +95,7 @@ async def main(room_url: str, token): llm = OpenAILLMService( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4-turbo-preview") + model="gpt-4o") tts = ElevenLabsTTSService( aiohttp_session=session, diff --git a/examples/foundational/14-function-calling.py b/examples/foundational/14-function-calling.py index 14f834fe1..4a3a8b515 100644 --- a/examples/foundational/14-function-calling.py +++ b/examples/foundational/14-function-calling.py @@ -66,7 +66,7 @@ async def main(room_url: str, token): llm = OpenAILLMService( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4-turbo-preview") + model="gpt-4o") llm.register_function( "get_current_weather", fetch_weather_from_api, diff --git a/examples/foundational/websocket-server/frames.proto b/examples/foundational/websocket-server/frames.proto deleted file mode 100644 index 830e3062c..000000000 --- a/examples/foundational/websocket-server/frames.proto +++ /dev/null @@ -1,25 +0,0 @@ -syntax = "proto3"; - -package pipecat_proto; - -message TextFrame { - string text = 1; -} - -message AudioFrame { - bytes audio = 1; -} - -message TranscriptionFrame { - string text = 1; - string participant_id = 2; - string timestamp = 3; -} - -message Frame { - oneof frame { - TextFrame text = 1; - AudioFrame audio = 2; - TranscriptionFrame transcription = 3; - } -} diff --git a/examples/foundational/websocket-server/index.html b/examples/foundational/websocket-server/index.html deleted file mode 100644 index a38e1e78b..000000000 --- a/examples/foundational/websocket-server/index.html +++ /dev/null @@ -1,134 +0,0 @@ - - - - - - - - WebSocket Audio Stream - - - -

[134 lines of HTML markup for the deleted index.html omitted; the page was titled "WebSocket Audio Stream"]
- - - - - - diff --git a/examples/foundational/websocket-server/sample.py b/examples/foundational/websocket-server/sample.py deleted file mode 100644 index b3a4a731d..000000000 --- a/examples/foundational/websocket-server/sample.py +++ /dev/null @@ -1,50 +0,0 @@ -import asyncio -import aiohttp -import logging -import os -from pipecat.pipeline.frame_processor import FrameProcessor -from pipecat.pipeline.frames import TextFrame, TranscriptionFrame -from pipecat.pipeline.pipeline import Pipeline -from pipecat.services.elevenlabs_ai_services import ElevenLabsTTSService -from pipecat.transports.websocket_transport import WebsocketTransport -from pipecat.services.whisper_ai_services import WhisperSTTService - -logging.basicConfig(format="%(levelno)s %(asctime)s %(message)s") -logger = logging.getLogger("pipecat") -logger.setLevel(logging.DEBUG) - - -class WhisperTranscriber(FrameProcessor): - async def process_frame(self, frame): - if isinstance(frame, TranscriptionFrame): - print(f"Transcribed: {frame.text}") - else: - yield frame - - -async def main(): - async with aiohttp.ClientSession() as session: - transport = WebsocketTransport( - mic_enabled=True, - speaker_enabled=True, - ) - tts = ElevenLabsTTSService( - aiohttp_session=session, - api_key=os.getenv("ELEVENLABS_API_KEY"), - voice_id=os.getenv("ELEVENLABS_VOICE_ID"), - ) - - pipeline = Pipeline([ - WhisperSTTService(), - WhisperTranscriber(), - tts, - ]) - - @transport.on_connection - async def queue_frame(): - await pipeline.queue_frames([TextFrame("Hello there!")]) - - await transport.run(pipeline) - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/moondream-chatbot/bot.py b/examples/moondream-chatbot/bot.py index 7830cf46a..3e9ced259 100644 --- a/examples/moondream-chatbot/bot.py +++ b/examples/moondream-chatbot/bot.py @@ -145,7 +145,7 @@ async def main(room_url: str, token): llm = OpenAILLMService( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4-turbo-preview") + model="gpt-4o") ta = TalkingAnimation() diff --git a/examples/simple-chatbot/bot.py b/examples/simple-chatbot/bot.py index a63b215ab..80b60833f 100644 --- a/examples/simple-chatbot/bot.py +++ b/examples/simple-chatbot/bot.py @@ -117,7 +117,7 @@ async def main(room_url: str, token): llm = OpenAILLMService( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4-turbo-preview") + model="gpt-4o") messages = [ { diff --git a/examples/storytelling-chatbot/src/bot.py b/examples/storytelling-chatbot/src/bot.py index c5a75e949..96bb7626a 100644 --- a/examples/storytelling-chatbot/src/bot.py +++ b/examples/storytelling-chatbot/src/bot.py @@ -56,7 +56,7 @@ async def main(room_url, token=None): llm_service = OpenAILLMService( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4-turbo" + model="gpt-4o" ) tts_service = ElevenLabsTTSService( diff --git a/examples/translation-chatbot/bot.py b/examples/translation-chatbot/bot.py index 89ca461b1..38667d897 100644 --- a/examples/translation-chatbot/bot.py +++ b/examples/translation-chatbot/bot.py @@ -97,7 +97,8 @@ async def main(room_url: str, token): ) llm = OpenAILLMService( - api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4-turbo-preview" + api_key=os.getenv("OPENAI_API_KEY"), + model="gpt-4o" ) sa = SentenceAggregator() diff --git a/examples/websocket-server/README.md b/examples/websocket-server/README.md new file mode 100644 index 000000000..8417f5a9d --- /dev/null +++ b/examples/websocket-server/README.md @@ -0,0 +1,27 @@ +# Websocket Server + +This is an example that shows how to use `WebsocketServerTransport` to 
communicate with a web client. + +## Get started + +```bash +python3 -m venv venv +source venv/bin/activate +pip install -r requirements.txt +``` + +## Run the bot + +```bash +python bot.py +``` + +## Run the HTTP server + +This will host the static web client: + +```bash +python -m http.server +``` + +Then, visit `http://localhost:8000` in your browser to start a session. diff --git a/examples/websocket-server/bot.py new file mode 100644 index 000000000..ba61084de --- /dev/null +++ b/examples/websocket-server/bot.py @@ -0,0 +1,94 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import aiohttp +import asyncio +import os +import sys + +from pipecat.frames.frames import LLMMessagesFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantResponseAggregator, + LLMUserResponseAggregator +) +from pipecat.services.elevenlabs import ElevenLabsTTSService +from pipecat.services.openai import OpenAILLMService +from pipecat.services.whisper import WhisperSTTService +from pipecat.transports.network.websocket_server import WebsocketServerParams, WebsocketServerTransport +from pipecat.vad.silero import SileroVADAnalyzer + +from loguru import logger + +from dotenv import load_dotenv +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +async def main(): + async with aiohttp.ClientSession() as session: + transport = WebsocketServerTransport( + params=WebsocketServerParams( + audio_in_enabled=True, + audio_out_enabled=True, + add_wav_header=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + vad_audio_passthrough=True + ) + ) + + llm = OpenAILLMService( + api_key=os.getenv("OPENAI_API_KEY"), + model="gpt-4o") + + stt = WhisperSTTService() + + tts = ElevenLabsTTSService( + aiohttp_session=session, + api_key=os.getenv("ELEVENLABS_API_KEY"), + voice_id=os.getenv("ELEVENLABS_VOICE_ID"), + ) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", + }, + ] + + tma_in = LLMUserResponseAggregator(messages) + tma_out = LLMAssistantResponseAggregator(messages) + + pipeline = Pipeline([ + transport.input(), # Websocket input from client + stt, # Speech-To-Text + tma_in, # User responses + llm, # LLM + tts, # Text-To-Speech + transport.output(), # Websocket output to client + tma_out # LLM responses + ]) + + task = PipelineTask(pipeline) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + # Kick off the conversation.
+ messages.append( + {"role": "system", "content": "Please introduce yourself to the user."}) + await task.queue_frames([LLMMessagesFrame(messages)]) + + runner = PipelineRunner() + + await runner.run(task) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/websocket-server/frames.proto b/examples/websocket-server/frames.proto new file mode 100644 index 000000000..5c5d81d4d --- /dev/null +++ b/examples/websocket-server/frames.proto @@ -0,0 +1,43 @@ +// +// Copyright (c) 2024, Daily +// +// SPDX-License-Identifier: BSD 2-Clause License +// + +// Generate frames_pb2.py with: +// +// python -m grpc_tools.protoc --proto_path=./ --python_out=./protobufs frames.proto + +syntax = "proto3"; + +package pipecat; + +message TextFrame { + uint64 id = 1; + string name = 2; + string text = 3; +} + +message AudioRawFrame { + uint64 id = 1; + string name = 2; + bytes audio = 3; + uint32 sample_rate = 4; + uint32 num_channels = 5; +} + +message TranscriptionFrame { + uint64 id = 1; + string name = 2; + string text = 3; + string user_id = 4; + string timestamp = 5; +} + +message Frame { + oneof frame { + TextFrame text = 1; + AudioRawFrame audio = 2; + TranscriptionFrame transcription = 3; + } +} diff --git a/examples/websocket-server/index.html b/examples/websocket-server/index.html new file mode 100644 index 000000000..514a4a821 --- /dev/null +++ b/examples/websocket-server/index.html @@ -0,0 +1,205 @@ + + + + + + + + Pipecat WebSocket Client Example + + + +

[205 lines of HTML markup for the new index.html omitted; the page is titled "Pipecat WebSocket Client Example" and shows a "Loading, wait..." status message while the client loads]
+ + + + + + diff --git a/examples/websocket-server/requirements.txt b/examples/websocket-server/requirements.txt new file mode 100644 index 000000000..77e5b9e91 --- /dev/null +++ b/examples/websocket-server/requirements.txt @@ -0,0 +1,2 @@ +python-dotenv +pipecat-ai[openai,silero,websocket,whisper] diff --git a/linux-py3.10-requirements.txt b/linux-py3.10-requirements.txt index 204a0e932..d3687c17e 100644 --- a/linux-py3.10-requirements.txt +++ b/linux-py3.10-requirements.txt @@ -42,7 +42,7 @@ coloredlogs==15.0.1 # via onnxruntime ctranslate2==4.2.1 # via faster-whisper -daily-python==0.9.0 +daily-python==0.9.1 # via pipecat-ai (pyproject.toml) distro==1.9.0 # via @@ -226,6 +226,7 @@ protobuf==4.25.3 # googleapis-common-protos # grpcio-status # onnxruntime + # pipecat-ai (pyproject.toml) # proto-plus # pyht pyasn1==0.6.0 @@ -259,7 +260,7 @@ pyyaml==6.0.1 # transformers regex==2024.5.15 # via transformers -requests==2.32.2 +requests==2.32.3 # via # google-api-core # huggingface-hub diff --git a/macos-py3.10-requirements.txt b/macos-py3.10-requirements.txt index 35ddcd8b6..e6f91d0f9 100644 --- a/macos-py3.10-requirements.txt +++ b/macos-py3.10-requirements.txt @@ -208,6 +208,7 @@ protobuf==4.25.3 # googleapis-common-protos # grpcio-status # onnxruntime + # pipecat-ai (pyproject.toml) # proto-plus # pyht pyasn1==0.6.0 diff --git a/pyproject.toml b/pyproject.toml index 90363e34b..aa3558f87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "numpy~=1.26.4", "loguru~=0.7.0", "Pillow~=10.3.0", + "protobuf~=4.25.3", "pyloudnorm~=0.1.1", "typing-extensions~=4.11.0", ] diff --git a/src/pipecat/frames/frames.proto b/src/pipecat/frames/frames.proto index 18e59e492..5c5d81d4d 100644 --- a/src/pipecat/frames/frames.proto +++ b/src/pipecat/frames/frames.proto @@ -4,28 +4,40 @@ // SPDX-License-Identifier: BSD 2-Clause License // +// Generate frames_pb2.py with: +// +// python -m grpc_tools.protoc --proto_path=./ --python_out=./protobufs frames.proto + syntax = "proto3"; -package pipecat_proto; +package pipecat; message TextFrame { - string text = 1; + uint64 id = 1; + string name = 2; + string text = 3; } -message AudioFrame { - bytes data = 1; +message AudioRawFrame { + uint64 id = 1; + string name = 2; + bytes audio = 3; + uint32 sample_rate = 4; + uint32 num_channels = 5; } message TranscriptionFrame { - string text = 1; - string participantId = 2; - string timestamp = 3; + uint64 id = 1; + string name = 2; + string text = 3; + string user_id = 4; + string timestamp = 5; } message Frame { - oneof frame { - TextFrame text = 1; - AudioFrame audio = 2; - TranscriptionFrame transcription = 3; - } + oneof frame { + TextFrame text = 1; + AudioRawFrame audio = 2; + TranscriptionFrame transcription = 3; + } } diff --git a/src/pipecat/frames/protobufs/frames_pb2.py b/src/pipecat/frames/protobufs/frames_pb2.py index bdc34d385..5040efc97 100644 --- a/src/pipecat/frames/protobufs/frames_pb2.py +++ b/src/pipecat/frames/protobufs/frames_pb2.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: frames.proto -# Protobuf Python Version: 4.25.3 +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -14,19 +14,19 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\rpipecat_proto\"\x19\n\tTextFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\"\x1a\n\nAudioFrame\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"L\n\x12TranscriptionFrame\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x15\n\rparticipantId\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"\xa2\x01\n\x05\x46rame\x12(\n\x04text\x18\x01 \x01(\x0b\x32\x18.pipecat_proto.TextFrameH\x00\x12*\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x19.pipecat_proto.AudioFrameH\x00\x12:\n\rtranscription\x18\x03 \x01(\x0b\x32!.pipecat_proto.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\x07pipecat\"3\n\tTextFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\"c\n\rAudioRawFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05\x61udio\x18\x03 \x01(\x0c\x12\x13\n\x0bsample_rate\x18\x04 \x01(\r\x12\x14\n\x0cnum_channels\x18\x05 \x01(\r\"`\n\x12TranscriptionFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0f\n\x07user_id\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\t\"\x93\x01\n\x05\x46rame\x12\"\n\x04text\x18\x01 \x01(\x0b\x32\x12.pipecat.TextFrameH\x00\x12\'\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x16.pipecat.AudioRawFrameH\x00\x12\x34\n\rtranscription\x18\x03 \x01(\x0b\x32\x1b.pipecat.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_TEXTFRAME']._serialized_start=31 - _globals['_TEXTFRAME']._serialized_end=56 - _globals['_AUDIOFRAME']._serialized_start=58 - _globals['_AUDIOFRAME']._serialized_end=84 - _globals['_TRANSCRIPTIONFRAME']._serialized_start=86 - _globals['_TRANSCRIPTIONFRAME']._serialized_end=162 - _globals['_FRAME']._serialized_start=165 - _globals['_FRAME']._serialized_end=327 + _globals['_TEXTFRAME']._serialized_start=25 + _globals['_TEXTFRAME']._serialized_end=76 + _globals['_AUDIORAWFRAME']._serialized_start=78 + _globals['_AUDIORAWFRAME']._serialized_end=177 + _globals['_TRANSCRIPTIONFRAME']._serialized_start=179 + _globals['_TRANSCRIPTIONFRAME']._serialized_end=275 + _globals['_FRAME']._serialized_start=278 + _globals['_FRAME']._serialized_end=425 # @@protoc_insertion_point(module_scope) diff --git a/src/pipecat/pipeline/pipeline.py b/src/pipecat/pipeline/pipeline.py index 5d1f9d3cf..2ba061a37 100644 --- a/src/pipecat/pipeline/pipeline.py +++ b/src/pipecat/pipeline/pipeline.py @@ -67,7 +67,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self._sink.process_frame(frame, FrameDirection.UPSTREAM) async def _cleanup_processors(self): - await asyncio.gather(*[p.cleanup() for p in self._processors]) + for p in self._processors: + await p.cleanup() def _link_processors(self): prev = self._processors[0] diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index 3bb750218..793aa8eb1 100644 --- 
a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -5,7 +5,7 @@ # import asyncio -from asyncio import AbstractEventLoop + from enum import Enum from pipecat.frames.frames import ErrorFrame, Frame @@ -21,12 +21,12 @@ class FrameDirection(Enum): class FrameProcessor: - def __init__(self): + def __init__(self, loop: asyncio.AbstractEventLoop | None = None): self.id: int = obj_id() self.name = f"{self.__class__.__name__}#{obj_count(self)}" self._prev: "FrameProcessor" | None = None self._next: "FrameProcessor" | None = None - self._loop: AbstractEventLoop = asyncio.get_running_loop() + self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_running_loop() async def cleanup(self): pass @@ -36,7 +36,7 @@ def link(self, processor: 'FrameProcessor'): processor._prev = self logger.debug(f"Linking {self} -> {self._next}") - def get_event_loop(self) -> AbstractEventLoop: + def get_event_loop(self) -> asyncio.AbstractEventLoop: return self._loop async def process_frame(self, frame: Frame, direction: FrameDirection): diff --git a/src/pipecat/serializers/abstract_frame_serializer.py b/src/pipecat/serializers/abstract_frame_serializer.py deleted file mode 100644 index 8da0bd11d..000000000 --- a/src/pipecat/serializers/abstract_frame_serializer.py +++ /dev/null @@ -1,16 +0,0 @@ -from abc import abstractmethod - -from pipecat.pipeline.frames import Frame - - -class FrameSerializer: - def __init__(self): - pass - - @abstractmethod - def serialize(self, frame: Frame) -> bytes: - raise NotImplementedError - - @abstractmethod - def deserialize(self, data: bytes) -> Frame: - raise NotImplementedError diff --git a/src/pipecat/serializers/base_serializer.py b/src/pipecat/serializers/base_serializer.py new file mode 100644 index 000000000..c137f873d --- /dev/null +++ b/src/pipecat/serializers/base_serializer.py @@ -0,0 +1,20 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from abc import ABC, abstractmethod + +from pipecat.frames.frames import Frame + + +class FrameSerializer(ABC): + + @abstractmethod + def serialize(self, frame: Frame) -> bytes: + pass + + @abstractmethod + def deserialize(self, data: bytes) -> Frame: + pass diff --git a/src/pipecat/serializers/protobuf_serializer.py b/src/pipecat/serializers/protobuf.py similarity index 72% rename from src/pipecat/serializers/protobuf_serializer.py rename to src/pipecat/serializers/protobuf.py index 04b348b86..50692a51b 100644 --- a/src/pipecat/serializers/protobuf_serializer.py +++ b/src/pipecat/serializers/protobuf.py @@ -1,14 +1,21 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + import dataclasses -from typing import Text -from pipecat.pipeline.frames import AudioFrame, Frame, TextFrame, TranscriptionFrame -import pipecat.pipeline.protobufs.frames_pb2 as frame_protos -from pipecat.serializers.abstract_frame_serializer import FrameSerializer + +import pipecat.frames.protobufs.frames_pb2 as frame_protos + +from pipecat.frames.frames import AudioRawFrame, Frame, TextFrame, TranscriptionFrame +from pipecat.serializers.base_serializer import FrameSerializer class ProtobufFrameSerializer(FrameSerializer): SERIALIZABLE_TYPES = { TextFrame: "text", - AudioFrame: "audio", + AudioRawFrame: "audio", TranscriptionFrame: "transcription" } @@ -29,7 +36,8 @@ def serialize(self, frame: Frame) -> bytes: setattr(getattr(proto_frame, proto_optional_name), field.name, getattr(frame, field.name)) - return proto_frame.SerializeToString() + result 
= proto_frame.SerializeToString() + return result def deserialize(self, data: bytes) -> Frame: """Returns a Frame object from a Frame protobuf. Used to convert frames @@ -61,4 +69,22 @@ def deserialize(self, data: bytes) -> Frame: args_dict = {} for field in proto.DESCRIPTOR.fields_by_name[which].message_type.fields: args_dict[field.name] = getattr(args, field.name) - return class_name(**args_dict) + + # Remove special fields if needed + id = getattr(args, "id") + name = getattr(args, "name") + if not id: + del args_dict["id"] + if not name: + del args_dict["name"] + + # Create the instance + instance = class_name(**args_dict) + + # Set special fields + if id: + setattr(instance, "id", getattr(args, "id")) + if name: + setattr(instance, "name", getattr(args, "name")) + + return instance diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index a7f74ccc3..fc879d887 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -196,7 +196,7 @@ def __init__(self): super().__init__() # Renders the image. Returns an Image object. - @ abstractmethod + @abstractmethod async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: pass @@ -215,7 +215,7 @@ def __init__(self): super().__init__() self._describe_text = None - @ abstractmethod + @abstractmethod async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]: pass diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index 96c855fa4..ac5fa1d99 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -229,7 +229,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): class OpenAILLMService(BaseOpenAILLMService): - def __init__(self, model="gpt-4", **kwargs): + def __init__(self, model="gpt-4o", **kwargs): super().__init__(model, **kwargs) diff --git a/src/pipecat/storage/__init__.py b/src/pipecat/storage/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/pipecat/storage/search.py b/src/pipecat/storage/search.py deleted file mode 100644 index 5740d1cae..000000000 --- a/src/pipecat/storage/search.py +++ /dev/null @@ -1,9 +0,0 @@ -class SearchIndexer(): - def __init__(self, story_id): - pass - - def index_text(self, text): - pass - - def index_image(self, text): - pass diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 71733c59c..ba4f8ae4e 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -21,7 +21,7 @@ UserStartedSpeakingFrame, UserStoppedSpeakingFrame) from pipecat.transports.base_transport import TransportParams -from pipecat.vad.vad_analyzer import VADState +from pipecat.vad.vad_analyzer import VADAnalyzer, VADState from loguru import logger @@ -59,10 +59,7 @@ async def start(self, frame: StartFrame): if self._params.audio_in_enabled or self._params.vad_enabled: loop = self.get_event_loop() - self._audio_in_thread = loop.run_in_executor( - self._in_executor, self._audio_in_thread_handler) - self._audio_out_thread = loop.run_in_executor( - self._in_executor, self._audio_out_thread_handler) + self._audio_thread = loop.run_in_executor(self._in_executor, self._audio_thread_handler) async def stop(self): if not self._running: @@ -73,15 +70,14 @@ async def stop(self): # Wait for the threads to finish. 
if self._params.audio_in_enabled or self._params.vad_enabled: - await self._audio_in_thread - await self._audio_out_thread + await self._audio_thread self._push_frame_task.cancel() - def vad_analyze(self, audio_frames: bytes) -> VADState: - pass + def vad_analyzer(self) -> VADAnalyzer | None: + return self._params.vad_analyzer - def read_raw_audio_frames(self, frame_count: int) -> bytes: + def read_next_audio_frame(self) -> AudioRawFrame | None: pass # @@ -150,8 +146,15 @@ async def _handle_interruptions(self, frame: Frame): # Audio input # + def _vad_analyze(self, audio_frames: bytes) -> VADState: + state = VADState.QUIET + vad_analyzer = self.vad_analyzer() + if vad_analyzer: + state = vad_analyzer.analyze_audio(audio_frames) + return state + def _handle_vad(self, audio_frames: bytes, vad_state: VADState): - new_vad_state = self.vad_analyze(audio_frames) + new_vad_state = self._vad_analyze(audio_frames) if new_vad_state != vad_state and new_vad_state != VADState.STARTING and new_vad_state != VADState.STOPPING: frame = None if new_vad_state == VADState.SPEAKING: @@ -167,44 +170,25 @@ def _handle_vad(self, audio_frames: bytes, vad_state: VADState): vad_state = new_vad_state return vad_state - def _audio_in_thread_handler(self): - sample_rate = self._params.audio_in_sample_rate - num_channels = self._params.audio_in_channels - num_frames = int(sample_rate / 100) # 10ms of audio - while self._running: - try: - audio_frames = self.read_raw_audio_frames(num_frames) - if len(audio_frames) > 0: - frame = AudioRawFrame( - audio=audio_frames, - sample_rate=sample_rate, - num_channels=num_channels) - self._audio_in_queue.put(frame) - except BaseException as e: - logger.error(f"Error reading audio frames: {e}") - - def _audio_out_thread_handler(self): + def _audio_thread_handler(self): vad_state: VADState = VADState.QUIET while self._running: try: - frame = self._audio_in_queue.get(timeout=1) - - audio_passthrough = True - - # Check VAD and push event if necessary. We just care about changes - # from QUIET to SPEAKING and vice versa. - if self._params.vad_enabled: - vad_state = self._handle_vad(frame.audio, vad_state) - audio_passthrough = self._params.vad_audio_passthrough - - # Push audio downstream if passthrough. - if audio_passthrough: - future = asyncio.run_coroutine_threadsafe( - self._internal_push_frame(frame), self.get_event_loop()) - future.result() - - self._audio_in_queue.task_done() - except queue.Empty: - pass + frame = self.read_next_audio_frame() + + if frame: + audio_passthrough = True + + # Check VAD and push event if necessary. We just care about + # changes from QUIET to SPEAKING and vice versa. + if self._params.vad_enabled: + vad_state = self._handle_vad(frame.audio, vad_state) + audio_passthrough = self._params.vad_audio_passthrough + + # Push audio downstream if passthrough. 
+ if audio_passthrough: + future = asyncio.run_coroutine_threadsafe( + self._internal_push_frame(frame), self.get_event_loop()) + future.result() except BaseException as e: - logger.error(f"Error pushing audio frames: {e}") + logger.error(f"Error reading audio frames: {e}") diff --git a/src/pipecat/transports/base_transport.py b/src/pipecat/transports/base_transport.py index 7f22d2c2c..7034b81eb 100644 --- a/src/pipecat/transports/base_transport.py +++ b/src/pipecat/transports/base_transport.py @@ -4,6 +4,9 @@ # SPDX-License-Identifier: BSD 2-Clause License # +import asyncio +import inspect + from abc import ABC, abstractmethod from pydantic import ConfigDict @@ -12,6 +15,8 @@ from pipecat.processors.frame_processor import FrameProcessor from pipecat.vad.vad_analyzer import VADAnalyzer +from loguru import logger + class TransportParams(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -36,6 +41,10 @@ class TransportParams(BaseModel): class BaseTransport(ABC): + def __init__(self, loop: asyncio.AbstractEventLoop | None): + self._loop = loop or asyncio.get_running_loop() + self._event_handlers: dict = {} + @abstractmethod def input(self) -> FrameProcessor: raise NotImplementedError @@ -43,3 +52,30 @@ def input(self) -> FrameProcessor: @abstractmethod def output(self) -> FrameProcessor: raise NotImplementedError + + def event_handler(self, event_name: str): + def decorator(handler): + self._add_event_handler(event_name, handler) + return handler + return decorator + + def _register_event_handler(self, event_name: str): + if event_name in self._event_handlers: + raise Exception(f"Event handler {event_name} already registered") + self._event_handlers[event_name] = [] + + def _add_event_handler(self, event_name: str, handler): + if event_name not in self._event_handlers: + raise Exception(f"Event handler {event_name} not registered") + self._event_handlers[event_name].append(handler) + + async def _call_event_handler(self, event_name: str, *args, **kwargs): + try: + for handler in self._event_handlers[event_name]: + if inspect.iscoroutinefunction(handler): + await handler(self, *args, **kwargs) + else: + handler(self, *args, **kwargs) + except Exception as e: + logger.error(f"Exception in event handler {event_name}: {e}") + raise e diff --git a/src/pipecat/transports/local/audio.py b/src/pipecat/transports/local/audio.py index 771715111..14b8bd5d3 100644 --- a/src/pipecat/transports/local/audio.py +++ b/src/pipecat/transports/local/audio.py @@ -6,7 +6,7 @@ import asyncio -from pipecat.frames.frames import StartFrame +from pipecat.frames.frames import AudioRawFrame, StartFrame from pipecat.processors.frame_processor import FrameProcessor from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport @@ -35,8 +35,14 @@ def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams): frames_per_buffer=params.audio_in_sample_rate, input=True) - def read_raw_audio_frames(self, frame_count: int) -> bytes: - return self._in_stream.read(frame_count, exception_on_overflow=False) + def read_next_audio_frame(self) -> AudioRawFrame | None: + sample_rate = self._params.audio_in_sample_rate + num_channels = self._params.audio_in_channels + num_frames = int(sample_rate / 100) # 10ms of audio + + audio = self._in_stream.read(num_frames, exception_on_overflow=False) + + return AudioRawFrame(audio=audio, sample_rate=sample_rate, num_channels=num_channels) async def start(self, frame: StartFrame): await super().start(frame) 
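For context on the input refactor above (and the matching `tk.py` change below): transports no longer implement `read_raw_audio_frames(frame_count)` returning raw bytes; they implement `read_next_audio_frame()` returning a complete `AudioRawFrame` (or `None`), and `BaseInputTransport` now runs a single audio thread that polls it, applies VAD, and pushes frames downstream. A minimal sketch of a hypothetical custom input transport written against that contract — the `SilenceInputTransport` name and the silence generation are illustrative assumptions, not part of this diff:

```python
import time

from pipecat.frames.frames import AudioRawFrame
from pipecat.transports.base_input import BaseInputTransport


class SilenceInputTransport(BaseInputTransport):
    """Hypothetical input transport that feeds 10ms chunks of 16-bit PCM silence."""

    def read_next_audio_frame(self) -> AudioRawFrame | None:
        sample_rate = self._params.audio_in_sample_rate
        num_channels = self._params.audio_in_channels
        num_frames = int(sample_rate / 100)  # 10ms of audio, as in the transports above

        # A real transport would block on its audio source; here we just pace the
        # polling loop, the same way DailyTransportClient sleeps when no one has joined.
        time.sleep(0.01)

        # Two zero bytes per sample per channel is silence in signed 16-bit PCM.
        audio = b"\x00" * (num_frames * num_channels * 2)
        return AudioRawFrame(audio=audio, sample_rate=sample_rate, num_channels=num_channels)
```

Returning `None` (as `DailyTransportClient` now does before any participant joins) simply makes the base class poll again instead of pushing an empty frame.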
diff --git a/src/pipecat/transports/local/tk.py b/src/pipecat/transports/local/tk.py index 782c01dae..808837998 100644 --- a/src/pipecat/transports/local/tk.py +++ b/src/pipecat/transports/local/tk.py @@ -9,7 +9,7 @@ import numpy as np import tkinter as tk -from pipecat.frames.frames import ImageRawFrame, StartFrame +from pipecat.frames.frames import AudioRawFrame, ImageRawFrame, StartFrame from pipecat.processors.frame_processor import FrameProcessor from pipecat.transports.base_input import BaseInputTransport from pipecat.transports.base_output import BaseOutputTransport @@ -45,8 +45,14 @@ def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams): frames_per_buffer=params.audio_in_sample_rate, input=True) - def read_raw_audio_frames(self, frame_count: int) -> bytes: - return self._in_stream.read(frame_count, exception_on_overflow=False) + def read_next_audio_frame(self) -> AudioRawFrame | None: + sample_rate = self._params.audio_in_sample_rate + num_channels = self._params.audio_in_channels + num_frames = int(sample_rate / 100) # 10ms of audio + + audio = self._in_stream.read(num_frames, exception_on_overflow=False) + + return AudioRawFrame(audio=audio, sample_rate=sample_rate, num_channels=num_channels) async def start(self, frame: StartFrame): await super().start(frame) diff --git a/src/pipecat/transports/network/websocket_server.py b/src/pipecat/transports/network/websocket_server.py new file mode 100644 index 000000000..8456c4536 --- /dev/null +++ b/src/pipecat/transports/network/websocket_server.py @@ -0,0 +1,206 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + + +import asyncio +import io +import queue +import wave +import websockets + +from typing import Awaitable, Callable +from pydantic.main import BaseModel + +from pipecat.frames.frames import AudioRawFrame, StartFrame +from pipecat.processors.frame_processor import FrameProcessor +from pipecat.serializers.base_serializer import FrameSerializer +from pipecat.serializers.protobuf import ProtobufFrameSerializer +from pipecat.transports.base_input import BaseInputTransport +from pipecat.transports.base_output import BaseOutputTransport +from pipecat.transports.base_transport import BaseTransport, TransportParams + +from loguru import logger + + +class WebsocketServerParams(TransportParams): + add_wav_header: bool = False + audio_frame_size: int = 6400 # 200ms + serializer: FrameSerializer = ProtobufFrameSerializer() + + +class WebsocketServerCallbacks(BaseModel): + on_client_connected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]] + on_client_disconnected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]] + + +class WebsocketServerInputTransport(BaseInputTransport): + + def __init__( + self, + host: str, + port: int, + params: WebsocketServerParams, + callbacks: WebsocketServerCallbacks): + super().__init__(params) + + self._host = host + self._port = port + self._params = params + self._callbacks = callbacks + + self._websocket: websockets.WebSocketServerProtocol | None = None + + self._client_audio_queue = queue.Queue() + self._stop_server_event = asyncio.Event() + + async def start(self, frame: StartFrame): + self._server_task = self.get_event_loop().create_task(self._server_task_handler()) + await super().start(frame) + + async def stop(self): + self._stop_server_event.set() + await self._server_task + await super().stop() + + def read_next_audio_frame(self) -> AudioRawFrame | None: + try: + return self._client_audio_queue.get(timeout=1) + 
except queue.Empty: + return None + + async def _server_task_handler(self): + logger.info(f"Starting websocket server on {self._host}:{self._port}") + async with websockets.serve(self._client_handler, self._host, self._port) as server: + await self._stop_server_event.wait() + + async def _client_handler(self, websocket: websockets.WebSocketServerProtocol, path): + logger.info(f"New client connection from {websocket.remote_address}") + if self._websocket: + await self._websocket.close() + logger.warning("Only one client connected, using new connection") + + self._websocket = websocket + + # Notify + await self._callbacks.on_client_connected(websocket) + + # Handle incoming messages + async for message in websocket: + frame = self._params.serializer.deserialize(message) + if isinstance(frame, AudioRawFrame) and self._params.audio_in_enabled: + self._client_audio_queue.put_nowait(frame) + else: + await self._internal_push_frame(frame) + + # Notify disconnection + await self._callbacks.on_client_disconnected(websocket) + + await self._websocket.close() + self._websocket = None + + logger.info(f"Client {websocket.remote_address} disconnected") + + +class WebsocketServerOutputTransport(BaseOutputTransport): + + def __init__(self, params: WebsocketServerParams): + super().__init__(params) + + self._params = params + + self._websocket: websockets.WebSocketServerProtocol | None = None + + self._audio_buffer = bytes() + + async def set_client_connection(self, websocket: websockets.WebSocketServerProtocol | None): + if self._websocket: + await self._websocket.close() + logger.warning("Only one client allowed, using new connection") + self._websocket = websocket + + def write_raw_audio_frames(self, frames: bytes): + self._audio_buffer += frames + while len(self._audio_buffer) >= self._params.audio_frame_size: + frame = AudioRawFrame( + audio=self._audio_buffer[:self._params.audio_frame_size], + sample_rate=self._params.audio_out_sample_rate, + num_channels=self._params.audio_out_channels + ) + + if self._params.add_wav_header: + content = io.BytesIO() + ww = wave.open(content, "wb") + ww.setsampwidth(2) + ww.setnchannels(frame.num_channels) + ww.setframerate(frame.sample_rate) + ww.writeframes(frame.audio) + ww.close() + content.seek(0) + wav_frame = AudioRawFrame( + content.read(), + sample_rate=frame.sample_rate, + num_channels=frame.num_channels) + frame = wav_frame + + proto = self._params.serializer.serialize(frame) + + future = asyncio.run_coroutine_threadsafe( + self._websocket.send(proto), self.get_event_loop()) + future.result() + + self._audio_buffer = self._audio_buffer[self._params.audio_frame_size:] + + +class WebsocketServerTransport(BaseTransport): + + def __init__( + self, + host: str = "localhost", + port: int = 8765, + params: WebsocketServerParams = WebsocketServerParams(), + loop: asyncio.AbstractEventLoop | None = None): + super().__init__(loop) + self._host = host + self._port = port + self._params = params + + self._callbacks = WebsocketServerCallbacks( + on_client_connected=self._on_client_connected, + on_client_disconnected=self._on_client_disconnected + ) + self._input: WebsocketServerInputTransport | None = None + self._output: WebsocketServerOutputTransport | None = None + self._websocket: websockets.WebSocketServerProtocol | None = None + + # Register supported handlers. The user will only be able to register + # these handlers. 
+ self._register_event_handler("on_client_connected") + self._register_event_handler("on_client_disconnected") + + def input(self) -> FrameProcessor: + if not self._input: + self._input = WebsocketServerInputTransport( + self._host, self._port, self._params, self._callbacks) + return self._input + + def output(self) -> FrameProcessor: + if not self._output: + self._output = WebsocketServerOutputTransport(self._params) + return self._output + + async def _on_client_connected(self, websocket): + if self._output: + await self._output.set_client_connection(websocket) + await self._call_event_handler("on_client_connected", websocket) + else: + logger.error("A WebsocketServerTransport output is missing in the pipeline") + + async def _on_client_disconnected(self, websocket): + if self._output: + await self._output.set_client_connection(None) + await self._call_event_handler("on_client_disconnected", websocket) + else: + logger.error("A WebsocketServerTransport output is missing in the pipeline") diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index e4c0a9762..f20eaa4dc 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -6,15 +6,12 @@ import aiohttp import asyncio -from concurrent.futures import ThreadPoolExecutor -import inspect import queue import time -import types from dataclasses import dataclass -from functools import partial from typing import Any, Callable, Mapping +from concurrent.futures import ThreadPoolExecutor from daily import ( CallClient, @@ -139,7 +136,8 @@ def __init__( token: str | None, bot_name: str, params: DailyParams, - callbacks: DailyCallbacks): + callbacks: DailyCallbacks, + loop: asyncio.AbstractEventLoop): super().__init__() if not self._daily_initialized: @@ -151,6 +149,7 @@ def __init__( self._bot_name: str = bot_name self._params: DailyParams = params self._callbacks = callbacks + self._loop = loop self._participant_id: str = "" self._video_renderers = {} @@ -189,15 +188,22 @@ def set_callbacks(self, callbacks: DailyCallbacks): def send_message(self, frame: DailyTransportMessageFrame): self._client.send_app_message(frame.message, frame.participant_id) - def read_raw_audio_frames(self, frame_count: int) -> bytes: + def read_next_audio_frame(self) -> AudioRawFrame | None: + sample_rate = self._params.audio_in_sample_rate + num_channels = self._params.audio_in_channels + if self._other_participant_has_joined: - return self._speaker.read_frames(frame_count) + num_frames = int(sample_rate / 100) # 10ms of audio + + audio = self._speaker.read_frames(num_frames) + + return AudioRawFrame(audio=audio, sample_rate=sample_rate, num_channels=num_channels) else: # If no one has ever joined the meeting `read_frames()` would block, # instead we just wait a bit. daily-python should probably return # silence instead. 
time.sleep(0.01) - return b'' + return None def write_raw_audio_frames(self, frames: bytes): self._mic.write_frames(frames) @@ -212,8 +218,7 @@ async def join(self): self._joining = True - loop = asyncio.get_running_loop() - await loop.run_in_executor(self._executor, self._join) + await self._loop.run_in_executor(self._executor, self._join) def _join(self): logger.info(f"Joining {self._room_url}") @@ -304,8 +309,7 @@ async def leave(self): self._joined = False self._leaving = True - loop = asyncio.get_running_loop() - await loop.run_in_executor(self._executor, self._leave) + await self._loop.run_in_executor(self._executor, self._leave) def _leave(self): logger.info(f"Leaving {self._room_url}") @@ -335,8 +339,7 @@ def _handle_leave_response(self): self._callbacks.on_error(error_msg) async def cleanup(self): - loop = asyncio.get_running_loop() - await loop.run_in_executor(self._executor, self._cleanup) + await self._loop.run_in_executor(self._executor, self._cleanup) def _cleanup(self): if self._client: @@ -471,7 +474,7 @@ def __init__(self, client: DailyTransportClient, params: DailyParams): self._video_renderers = {} self._camera_in_queue = queue.Queue() - self._vad_analyzer = params.vad_analyzer + self._vad_analyzer: VADAnalyzer | None = params.vad_analyzer if params.vad_enabled and not params.vad_analyzer: self._vad_analyzer = WebRTCVADAnalyzer( sample_rate=self._params.audio_in_sample_rate, @@ -485,8 +488,7 @@ async def start(self, frame: StartFrame): # This will set _running=True await super().start(frame) # Create camera in thread (runs if _running is true). - loop = asyncio.get_running_loop() - self._camera_in_thread = loop.run_in_executor( + self._camera_in_thread = self._loop.run_in_executor( self._in_executor, self._camera_in_thread_handler) async def stop(self): @@ -503,14 +505,11 @@ async def cleanup(self): await super().cleanup() await self._client.cleanup() - def vad_analyze(self, audio_frames: bytes) -> VADState: - state = VADState.QUIET - if self._vad_analyzer: - state = self._vad_analyzer.analyze_audio(audio_frames) - return state + def vad_analyzer(self) -> VADAnalyzer | None: + return self._vad_analyzer - def read_raw_audio_frames(self, frame_count: int) -> bytes: - return self._client.read_raw_audio_frames(frame_count) + def read_next_audio_frame(self) -> AudioRawFrame | None: + return self._client.read_next_audio_frame() # # FrameProcessor @@ -642,7 +641,15 @@ def write_frame_to_camera(self, frame: ImageRawFrame): class DailyTransport(BaseTransport): - def __init__(self, room_url: str, token: str | None, bot_name: str, params: DailyParams): + def __init__( + self, + room_url: str, + token: str | None, + bot_name: str, + params: DailyParams, + loop: asyncio.AbstractEventLoop | None = None): + super().__init__(loop) + callbacks = DailyCallbacks( on_joined=self._on_joined, on_left=self._on_left, @@ -660,12 +667,10 @@ def __init__(self, room_url: str, token: str | None, bot_name: str, params: Dail ) self._params = params - self._client = DailyTransportClient(room_url, token, bot_name, params, callbacks) + self._client = DailyTransportClient( + room_url, token, bot_name, params, callbacks, self._loop) self._input: DailyInputTransport | None = None self._output: DailyOutputTransport | None = None - self._loop = asyncio.get_running_loop() - - self._event_handlers: dict = {} # Register supported handlers. The user will only be able to register # these handlers. 
@@ -741,10 +746,10 @@ def capture_participant_video( participant_id, framerate, video_source, color_format) def _on_joined(self, participant): - self.on_joined(participant) + self._call_async_event_handler("on_joined", participant) def _on_left(self): - self.on_left() + self._call_async_event_handler("on_left") def _on_error(self, error): # TODO(aleix): Report error to input/output transports. The one managing @@ -754,10 +759,10 @@ def _on_error(self, error): def _on_app_message(self, message: Any, sender: str): if self._input: self._input.push_app_message(message, sender) - self.on_app_message(message, sender) + self._call_async_event_handler("on_app_message", message, sender) def _on_call_state_updated(self, state: str): - self.on_call_state_updated(state) + self._call_async_event_handler("on_call_state_updated", state) async def _handle_dialin_ready(self, sip_endpoint: str): if not self._params.dialin_settings: @@ -793,28 +798,28 @@ async def _handle_dialin_ready(self, sip_endpoint: str): def _on_dialin_ready(self, sip_endpoint): if self._params.dialin_settings: asyncio.run_coroutine_threadsafe(self._handle_dialin_ready(sip_endpoint), self._loop) - self.on_dialin_ready(sip_endpoint) + self._call_async_event_handler("on_dialin_ready", sip_endpoint) def _on_dialout_connected(self, data): - self.on_dialout_connected(data) + self._call_async_event_handler("on_dialout_connected", data) def _on_dialout_stopped(self, data): - self.on_dialout_stopped(data) + self._call_async_event_handler("on_dialout_stopped", data) def _on_dialout_error(self, data): - self.on_dialout_error(data) + self._call_async_event_handler("on_dialout_error", data) def _on_dialout_warning(self, data): - self.on_dialout_warning(data) + self._call_async_event_handler("on_dialout_warning", data) def _on_participant_joined(self, participant): - self.on_participant_joined(participant) + self._call_async_event_handler("on_participant_joined", participant) def _on_participant_left(self, participant, reason): - self.on_participant_left(participant, reason) + self._call_async_event_handler("on_participant_left", participant, reason) def _on_first_participant_joined(self, participant): - self.on_first_participant_joined(participant) + self._call_async_event_handler("on_first_participant_joined", participant) def _on_transcription_message(self, participant_id, message): text = message["text"] @@ -829,84 +834,7 @@ def _on_transcription_message(self, participant_id, message): if self._input: self._input.push_transcription_frame(frame) - # - # Decorators (event handlers) - # - - def on_joined(self, participant): - pass - - def on_left(self): - pass - - def on_app_message(self, message, sender): - pass - - def on_call_state_updated(self, state): - pass - - def on_dialin_ready(self, sip_endpoint): - pass - - def on_dialout_connected(self, data): - pass - - def on_dialout_stopped(self, data): - pass - - def on_dialout_error(self, data): - pass - - def on_dialout_warning(self, data): - pass - - def on_first_participant_joined(self, participant): - pass - - def on_participant_joined(self, participant): - pass - - def on_participant_left(self, participant, reason): - pass - - def event_handler(self, event_name: str): - def decorator(handler): - self._add_event_handler(event_name, handler) - return handler - return decorator - - def _register_event_handler(self, event_name: str): - methods = inspect.getmembers(self, predicate=inspect.ismethod) - if event_name not in [method[0] for method in methods]: - raise Exception(f"Event handler 
{event_name} not found") - - self._event_handlers[event_name] = [getattr(self, event_name)] - - patch_method = types.MethodType(partial(self._patch_method, event_name), self) - setattr(self, event_name, patch_method) - - def _add_event_handler(self, event_name: str, handler): - if event_name not in self._event_handlers: - raise Exception(f"Event handler {event_name} not registered") - self._event_handlers[event_name].append(types.MethodType(handler, self)) - - def _patch_method(self, event_name, *args, **kwargs): - try: - for handler in self._event_handlers[event_name]: - if inspect.iscoroutinefunction(handler): - # Beware, if handler() calls another event handler it - # will deadlock. You shouldn't do that anyways. - future = asyncio.run_coroutine_threadsafe( - handler(*args[1:], **kwargs), self._loop) - - # wait for the coroutine to finish. This will also - # raise any exceptions raised by the coroutine. - future.result() - else: - handler(*args[1:], **kwargs) - except Exception as e: - logger.error(f"Exception in event handler {event_name}: {e}") - raise e - - # def start_recording(self): - # self.client.start_recording() + def _call_async_event_handler(self, event_name: str, *args, **kwargs): + future = asyncio.run_coroutine_threadsafe( + self._call_event_handler(event_name, *args, **kwargs), self._loop) + future.result()