From 319b8e781667e766ec54e59c490d1cd647c67554 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Mon, 8 Apr 2024 23:10:10 -0700
Subject: [PATCH] updated ImageFrame and added URLImageFrame and UserImageFrame

---
 .../05a-local-sync-speech-and-text.py         |  5 ++-
 examples/foundational/06a-image-sync.py       |  4 +-
 examples/foundational/08-bots-arguing.py      |  4 +-
 examples/foundational/10-wake-word.py         |  2 +-
 examples/starter-apps/chatbot.py              |  2 +-
 examples/starter-apps/storybot.py             | 10 ++---
 src/dailyai/pipeline/aggregators.py           |  2 +-
 src/dailyai/pipeline/frames.py                | 32 ++++++++++++++-
 src/dailyai/services/ai_services.py           |  7 ++--
 src/dailyai/services/azure_ai_services.py     |  4 +-
 src/dailyai/services/fal_ai_services.py       |  4 +-
 src/dailyai/services/open_ai_services.py      |  4 +-
 .../services/to_be_updated/mock_ai_service.py |  2 +-
 src/dailyai/transports/daily_transport.py     | 39 ++++++++++++++++++-
 tests/test_aggregators.py                     |  4 +-
 tests/test_daily_transport_service.py         |  2 +-
 16 files changed, 97 insertions(+), 30 deletions(-)

diff --git a/examples/foundational/05a-local-sync-speech-and-text.py b/examples/foundational/05a-local-sync-speech-and-text.py
index 7083c367f..5e2116496 100644
--- a/examples/foundational/05a-local-sync-speech-and-text.py
+++ b/examples/foundational/05a-local-sync-speech-and-text.py
@@ -5,7 +5,7 @@ import os
 
 from dailyai.pipeline.aggregators import LLMFullResponseAggregator
 
-from dailyai.pipeline.frames import AudioFrame, ImageFrame, LLMMessagesFrame, TextFrame
+from dailyai.pipeline.frames import AudioFrame, URLImageFrame, LLMMessagesFrame, TextFrame
 from dailyai.services.open_ai_services import OpenAILLMService
 from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
 from dailyai.services.fal_ai_services import FalImageGenService
@@ -94,6 +94,7 @@ async def get_month_data(month):
                 "text": image_description,
                 "image_url": image_data[0],
                 "image": image_data[1],
+                "image_size": image_data[2],
                 "audio": audio,
             }
 
@@ -117,7 +118,7 @@ async def show_images():
             if data:
                 await transport.send_queue.put(
                     [
-                        ImageFrame(data["image_url"], data["image"]),
+                        URLImageFrame(data["image_url"], data["image"], data["image_size"]),
                         AudioFrame(data["audio"]),
                     ]
                 )
diff --git a/examples/foundational/06a-image-sync.py b/examples/foundational/06a-image-sync.py
index ad0b67d4f..fbc2aaaa7 100644
--- a/examples/foundational/06a-image-sync.py
+++ b/examples/foundational/06a-image-sync.py
@@ -35,9 +35,9 @@ def __init__(self, speaking_path: str, waiting_path: str):
         self._waiting_image_bytes = self._waiting_image.tobytes()
 
     async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
-        yield ImageFrame(None, self._speaking_image_bytes)
+        yield ImageFrame(self._speaking_image_bytes, (1024, 1024))
         yield frame
-        yield ImageFrame(None, self._waiting_image_bytes)
+        yield ImageFrame(self._waiting_image_bytes, (1024, 1024))
 
 
 async def main(room_url: str, token):
diff --git a/examples/foundational/08-bots-arguing.py b/examples/foundational/08-bots-arguing.py
index 942965ac5..dabd25117 100644
--- a/examples/foundational/08-bots-arguing.py
+++ b/examples/foundational/08-bots-arguing.py
@@ -122,7 +122,7 @@ async def argue():
             )
             await transport.send_queue.put(
                 [
-                    ImageFrame(None, image_data1[1]),
+                    ImageFrame(image_data1[1], image_data1[2]),
                     AudioFrame(audio1),
                 ]
             )
@@ -134,7 +134,7 @@ async def argue():
             )
             await transport.send_queue.put(
                 [
-                    ImageFrame(None, image_data2[1]),
+                    ImageFrame(image_data2[1], image_data2[2]),
                     AudioFrame(audio2),
                 ]
             )
diff --git a/examples/foundational/10-wake-word.py b/examples/foundational/10-wake-word.py
index 590b09a56..0546b5305 100644
--- a/examples/foundational/10-wake-word.py
+++ b/examples/foundational/10-wake-word.py
@@ -55,7 +55,7 @@
     sprites[file] = img.tobytes()
 
 # When the bot isn't talking, show a static image of the cat listening
-quiet_frame = ImageFrame("", sprites["sc-listen-1.png"])
+quiet_frame = ImageFrame(sprites["sc-listen-1.png"], (720, 1280))
 # When the bot is talking, build an animation from two sprites
 talking_list = [sprites["sc-default.png"], sprites["sc-talk.png"]]
 talking = [random.choice(talking_list) for x in range(30)]
diff --git a/examples/starter-apps/chatbot.py b/examples/starter-apps/chatbot.py
index 79a28d88e..166dab6fb 100644
--- a/examples/starter-apps/chatbot.py
+++ b/examples/starter-apps/chatbot.py
@@ -48,7 +48,7 @@
 flipped = sprites[::-1]
 sprites.extend(flipped)
 # When the bot isn't talking, show a static image of the cat listening
-quiet_frame = ImageFrame("", sprites[0])
+quiet_frame = ImageFrame(sprites[0], (1024, 576))
 talking_frame = SpriteFrame(images=sprites)
 
diff --git a/examples/starter-apps/storybot.py b/examples/starter-apps/storybot.py
index 73102657a..8e40e296b 100644
--- a/examples/starter-apps/storybot.py
+++ b/examples/starter-apps/storybot.py
@@ -99,7 +99,7 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
         1. Catch the frames that are generated by the LLM service
         """
         if isinstance(frame, UserStoppedSpeakingFrame):
-            yield ImageFrame(None, images["grandma-writing.png"])
+            yield ImageFrame(images["grandma-writing.png"], (1024, 1024))
             yield AudioFrame(sounds["talking.wav"])
 
         elif isinstance(frame, TextFrame):
@@ -112,7 +112,7 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
             self._text = self._text.replace("\n", " ")
 
             if len(self._text) > 2:
-                yield ImageFrame(None, images["grandma-writing.png"])
+                yield ImageFrame(images["grandma-writing.png"], (1024, 1024))
                 yield StoryStartFrame(self._text)
                 yield AudioFrame(sounds["ding3.wav"])
                 self._text = ""
@@ -146,11 +146,11 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
             # last bit
             pass
         elif isinstance(frame, LLMResponseEndFrame):
-            yield ImageFrame(None, images["grandma-writing.png"])
+            yield ImageFrame(images["grandma-writing.png"], (1024, 1024))
             yield StoryPromptFrame(self._text)
             self._text = ""
             yield frame
-            yield ImageFrame(None, images["grandma-listening.png"])
+            yield ImageFrame(images["grandma-listening.png"], (1024, 1024))
             yield AudioFrame(sounds["listening.wav"])
 
         else:
@@ -252,7 +252,7 @@ async def storytime():
             [llm, lca, tts], sink=transport.send_queue)
         await local_pipeline.queue_frames(
             [
-                ImageFrame(None, images["grandma-listening.png"]),
+                ImageFrame(images["grandma-listening.png"], (1024, 1024)),
                 LLMMessagesFrame(intro_messages),
                 AudioFrame(sounds["listening.wav"]),
                 EndPipeFrame(),
diff --git a/src/dailyai/pipeline/aggregators.py b/src/dailyai/pipeline/aggregators.py
index bbed2fbc3..887f00a97 100644
--- a/src/dailyai/pipeline/aggregators.py
+++ b/src/dailyai/pipeline/aggregators.py
@@ -360,7 +360,7 @@ class GatedAggregator(FrameProcessor):
     ...     start_open=False)
     >>> asyncio.run(print_frames(aggregator, TextFrame("Hello")))
     >>> asyncio.run(print_frames(aggregator, TextFrame("Hello again.")))
-    >>> asyncio.run(print_frames(aggregator, ImageFrame(url='', image=bytes([]))))
+    >>> asyncio.run(print_frames(aggregator, ImageFrame(image=bytes([]), size=(0, 0))))
     ImageFrame
     Hello
     Hello again.
diff --git a/src/dailyai/pipeline/frames.py b/src/dailyai/pipeline/frames.py
index 4942747b8..2505e203f 100644
--- a/src/dailyai/pipeline/frames.py
+++ b/src/dailyai/pipeline/frames.py
@@ -70,11 +70,39 @@ def __str__(self):
 class ImageFrame(Frame):
     """An image. Will be shown by the transport if the transport's camera is enabled."""
-    url: str | None
     image: bytes
+    size: tuple[int, int]
+
+    def __str__(self):
+        return f"{self.__class__.__name__}, image size: {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B"
+
+
+@dataclass()
+class URLImageFrame(ImageFrame):
+    """An image with the URL it was fetched from. Will be shown by the
+    transport if the transport's camera is enabled."""
+    url: str | None
+
+    def __init__(self, url, image, size):
+        super().__init__(image, size)
+        self.url = url
+
+    def __str__(self):
+        return f"{self.__class__.__name__}, url: {self.url}, image size: {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B"
+
+
+@dataclass()
+class UserImageFrame(ImageFrame):
+    """An image associated with a user. Will be shown by the transport if the
+    transport's camera is enabled."""
+    user_id: str
+
+    def __init__(self, user_id, image, size):
+        super().__init__(image, size)
+        self.user_id = user_id
 
     def __str__(self):
-        return f"{self.__class__.__name__}, url: {self.url}, image size: {len(self.image)} B"
+        return f"{self.__class__.__name__}, user: {self.user_id}, image size: {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B"
 
 
 @dataclass()
diff --git a/src/dailyai/services/ai_services.py b/src/dailyai/services/ai_services.py
index 815cf9390..d4ecebee4 100644
--- a/src/dailyai/services/ai_services.py
+++ b/src/dailyai/services/ai_services.py
@@ -14,6 +14,7 @@
     TTSStartFrame,
     TextFrame,
     TranscriptionFrame,
+    URLImageFrame,
 )
 
 from abc import abstractmethod
@@ -87,7 +88,7 @@ def __init__(self, image_size, **kwargs):
 
     # Renders the image. Returns an Image object.
     @abstractmethod
-    async def run_image_gen(self, sentence: str) -> tuple[str, bytes]:
+    async def run_image_gen(self, sentence: str) -> tuple[str, bytes, tuple[int, int]]:
         pass
 
     async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
@@ -95,8 +96,8 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
             yield frame
             return
 
-        (url, image_data) = await self.run_image_gen(frame.text)
-        yield ImageFrame(url, image_data)
+        (url, image_data, image_size) = await self.run_image_gen(frame.text)
+        yield URLImageFrame(url, image_data, image_size)
 
 
 class STTService(AIService):
diff --git a/src/dailyai/services/azure_ai_services.py b/src/dailyai/services/azure_ai_services.py
index e9b15ec30..de068fbb4 100644
--- a/src/dailyai/services/azure_ai_services.py
+++ b/src/dailyai/services/azure_ai_services.py
@@ -105,7 +105,7 @@ def __init__(
         self._model = model
         self._aiohttp_session = aiohttp_session
 
-    async def run_image_gen(self, sentence) -> tuple[str, bytes]:
+    async def run_image_gen(self, sentence) -> tuple[str, bytes, tuple[int, int]]:
         url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}"
         headers = {
             "api-key": self._api_key,
@@ -146,4 +146,4 @@ async def run_image_gen(self, sentence) -> tuple[str, bytes]:
         async with self._aiohttp_session.get(image_url) as response:
             image_stream = io.BytesIO(await response.content.read())
             image = Image.open(image_stream)
-            return (image_url, image.tobytes())
+            return (image_url, image.tobytes(), image.size)
diff --git a/src/dailyai/services/fal_ai_services.py b/src/dailyai/services/fal_ai_services.py
index 1f97db598..4a7016011 100644
--- a/src/dailyai/services/fal_ai_services.py
+++ b/src/dailyai/services/fal_ai_services.py
@@ -31,7 +31,7 @@ def __init__(
         if key_secret:
             os.environ["FAL_KEY_SECRET"] = key_secret
 
-    async def run_image_gen(self, sentence) -> tuple[str, bytes]:
+    async def run_image_gen(self, sentence) -> tuple[str, bytes, tuple[int, int]]:
         def get_image_url(sentence, size):
             handler = fal.apps.submit(
                 "110602490-fast-sdxl",
@@ -55,4 +55,4 @@ def get_image_url(sentence, size):
         async with self._aiohttp_session.get(image_url) as response:
             image_stream = io.BytesIO(await response.content.read())
             image = Image.open(image_stream)
-            return (image_url, image.tobytes())
+            return (image_url, image.tobytes(), image.size)
diff --git a/src/dailyai/services/open_ai_services.py b/src/dailyai/services/open_ai_services.py
index 9d177f8b7..95045f6dc 100644
--- a/src/dailyai/services/open_ai_services.py
+++ b/src/dailyai/services/open_ai_services.py
@@ -36,7 +36,7 @@ def __init__(
         self._client = AsyncOpenAI(api_key=api_key)
         self._aiohttp_session = aiohttp_session
 
-    async def run_image_gen(self, sentence) -> tuple[str, bytes]:
+    async def run_image_gen(self, sentence) -> tuple[str, bytes, tuple[int, int]]:
         self.logger.info("Generating OpenAI image", sentence)
 
         image = await self._client.images.generate(
@@ -53,4 +53,4 @@ async def run_image_gen(self, sentence) -> tuple[str, bytes]:
         async with self._aiohttp_session.get(image_url) as response:
             image_stream = io.BytesIO(await response.content.read())
             image = Image.open(image_stream)
-            return (image_url, image.tobytes())
+            return (image_url, image.tobytes(), image.size)
diff --git a/src/dailyai/services/to_be_updated/mock_ai_service.py b/src/dailyai/services/to_be_updated/mock_ai_service.py
index be608c9f5..dc200f622 100644
--- a/src/dailyai/services/to_be_updated/mock_ai_service.py
+++ b/src/dailyai/services/to_be_updated/mock_ai_service.py
@@ -19,7 +19,7 @@ def run_image_gen(self, sentence):
         image_stream = io.BytesIO(response.content)
         image = Image.open(image_stream)
         time.sleep(1)
-        return (image_url, image)
+        return (image_url, image.tobytes(), image.size)
 
     def run_llm(self, messages, latest_user_message=None, stream=True):
         for i in range(5):
diff --git a/src/dailyai/transports/daily_transport.py b/src/dailyai/transports/daily_transport.py
index 3a77dc534..d48f9d8ed 100644
--- a/src/dailyai/transports/daily_transport.py
+++ b/src/dailyai/transports/daily_transport.py
@@ -2,6 +2,7 @@
 import inspect
 import logging
 import signal
+import time
 import threading
 import types
 
@@ -11,6 +12,7 @@
 from dailyai.pipeline.frames import (
     ReceivedAppMessageFrame,
     TranscriptionFrame,
+    UserImageFrame,
 )
 
 from threading import Event
@@ -58,6 +60,7 @@ def __init__(
         bot_name: str,
         min_others_count: int = 1,
         start_transcription: bool = False,
+        video_rendering_enabled: bool = False,
         **kwargs,
     ):
         kwargs['has_webrtc_vad'] = True
@@ -69,6 +72,7 @@ def __init__(
         self._token: str | None = token
         self._min_others_count = min_others_count
         self._start_transcription = start_transcription
+        self._video_rendering_enabled = video_rendering_enabled
 
         self._is_interrupted = Event()
         self._stop_threads = Event()
@@ -76,6 +80,8 @@ def __init__(
         self._other_participant_has_joined = False
         self._my_participant_id = None
 
+        self._video_renderers = {}
+
         self.transcription_settings = {
             "language": "en",
             "tier": "nova",
@@ -236,7 +242,7 @@ def _prerun(self):
 
         self.client.update_subscription_profiles({
             "base": {
-                "camera": "unsubscribed",
+                "camera": "subscribed" if self._video_rendering_enabled else "unsubscribed",
             }
         })
 
@@ -268,6 +274,37 @@ def dialout(self, number):
     def start_recording(self):
         self.client.start_recording()
 
+    def render_participant_video(self,
+                                 participant_id,
+                                 framerate=10,
+                                 video_source="camera",
+                                 color_format="RGB") -> None:
+        if not self._video_rendering_enabled:
+            self._logger.warning("Video rendering is not enabled")
+
+        self._video_renderers[participant_id] = {
+            "framerate": framerate,
+            "timestamp": 0,
+        }
+        self.client.set_video_renderer(
+            participant_id,
+            self.on_participant_video_frame,
+            video_source=video_source,
+            color_format=color_format)
+
+    def on_participant_video_frame(self, participant_id, video_frame):
+        curr_time = time.time()
+        prev_time = self._video_renderers[participant_id]["timestamp"]
+        diff_time = curr_time - prev_time
+        period = 1 / self._video_renderers[participant_id]["framerate"]
+        if diff_time > period and self._loop:
+            self._video_renderers[participant_id]["timestamp"] = curr_time
+            frame = UserImageFrame(participant_id, video_frame.buffer,
+                                   (video_frame.width, video_frame.height))
+            asyncio.run_coroutine_threadsafe(
+                self.receive_queue.put(frame), self._loop
+            )
+
     def on_error(self, error):
         self._logger.error(f"on_error: {error}")
 
diff --git a/tests/test_aggregators.py b/tests/test_aggregators.py
index d232349e3..5c522f787 100644
--- a/tests/test_aggregators.py
+++ b/tests/test_aggregators.py
@@ -54,13 +54,13 @@ async def test_gated_accumulator(self):
             TextFrame("Hello, "),
             TextFrame("world."),
             AudioFrame(b"hello"),
-            ImageFrame("image", b"image"),
+            ImageFrame(b"image", (0, 0)),
             AudioFrame(b"world"),
             LLMResponseEndFrame(),
         ]
 
         expected_output_frames = [
-            ImageFrame("image", b"image"),
+            ImageFrame(b"image", (0, 0)),
             LLMResponseStartFrame(),
             TextFrame("Hello, "),
             TextFrame("world."),
diff --git a/tests/test_daily_transport_service.py b/tests/test_daily_transport_service.py
index b620acf75..9d02cd14b 100644
--- a/tests/test_daily_transport_service.py
+++ b/tests/test_daily_transport_service.py
@@ -68,7 +68,7 @@ async def send_audio_frame():
             await transport.send_queue.put(AudioFrame(bytes([0] * 3300)))
 
         async def send_video_frame():
-            await transport.send_queue.put(ImageFrame(None, b"test"))
+            await transport.send_queue.put(ImageFrame(b"test", (0, 0)))
 
         await asyncio.gather(transport.run(), send_audio_frame(), send_video_frame())
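
Usage notes (illustration only, not part of the diff):

The ImageFrame constructor changes from positional (url, image) to (image,
size); the URL moves into the new URLImageFrame subclass, and UserImageFrame
carries a participant id instead. A minimal sketch of the new constructors;
the buffer contents, URL and user id below are made-up placeholders:

    from dailyai.pipeline.frames import ImageFrame, URLImageFrame, UserImageFrame

    # A 2x2 RGB image occupies 2 * 2 * 3 bytes; ImageFrame now takes the raw
    # buffer plus an explicit (width, height) size instead of a URL.
    image_bytes = bytes(2 * 2 * 3)
    frame = ImageFrame(image_bytes, (2, 2))

    # URLImageFrame keeps the source URL alongside the buffer and size.
    url_frame = URLImageFrame("https://example.com/a.png", image_bytes, (2, 2))

    # UserImageFrame ties a rendered video frame to the participant it came from.
    user_frame = UserImageFrame("participant-42", image_bytes, (2, 2))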
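Every run_image_gen implementation now returns a three-element tuple of
(url, raw bytes, size) instead of (url, bytes), and
ImageGenService.process_frame wraps the result in a URLImageFrame. A sketch
of a custom service written against that contract; the class name, color and
"local://" URL are invented for illustration:

    from PIL import Image

    from dailyai.services.ai_services import ImageGenService

    class SolidColorImageGenService(ImageGenService):
        # Hypothetical service: ignores the prompt and "generates" a solid
        # purple square, returning the (url, buffer, (width, height)) tuple
        # that the new abstract signature requires.
        async def run_image_gen(self, sentence: str) -> tuple[str, bytes, tuple[int, int]]:
            image = Image.new("RGB", (64, 64), "purple")
            return ("local://solid-purple", image.tobytes(), image.size)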
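On the transport side, video rendering is opt-in: the new constructor flag
switches the base subscription profile to "subscribed", and
render_participant_video registers a throttled callback that turns incoming
daily-python video frames into UserImageFrame objects on the receive queue.
A rough usage sketch, assuming the (room_url, token, bot_name) argument
order used by the examples; room_url, token and participant_id are
placeholders:

    from dailyai.transports.daily_transport import DailyTransport

    transport = DailyTransport(
        room_url,                      # Daily room URL (placeholder)
        token,                         # meeting token, or None (placeholder)
        "RenderBot",
        video_rendering_enabled=True,  # otherwise cameras stay unsubscribed
    )

    # Once a participant id is known (e.g. from a join event), start rendering
    # that participant's camera; frames are rate-limited to `framerate` and
    # arrive as UserImageFrame(participant_id, buffer, (width, height)).
    transport.render_participant_video(participant_id, framerate=10)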