From b117a185e3e47abddc2eba46ef56ce75c368b14c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Tue, 9 Apr 2024 22:14:54 -0700
Subject: [PATCH 1/8] frames: added UserImageRequestFrame

---
 src/dailyai/pipeline/frames.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/dailyai/pipeline/frames.py b/src/dailyai/pipeline/frames.py
index 2505e203f..7117e8f88 100644
--- a/src/dailyai/pipeline/frames.py
+++ b/src/dailyai/pipeline/frames.py
@@ -105,6 +105,15 @@ def __str__(self):
         return f"{self.__class__.__name__}, user: {self.user_id}, image size: {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B"
 
 
+@dataclass()
+class UserImageRequestFrame(Frame):
+    """A frame used to request an image from the given user."""
+    user_id: str
+
+    def __str__(self):
+        return f"{self.__class__.__name__}, user: {self.user_id}"
+
+
 @dataclass()
 class SpriteFrame(Frame):
     """An animated sprite. Will be shown by the transport if the transport's
@@ -172,10 +181,10 @@ def __str__(self):
 @dataclass()
 class SendAppMessageFrame(Frame):
     message: Any
-    participantId: str | None
+    participant_id: str | None
 
     def __str__(self):
-        return f"SendAppMessageFrame: participantId: {self.participantId}, message: {self.message}"
+        return f"SendAppMessageFrame: participant: {self.participant_id}, message: {self.message}"
 
 
 class UserStartedSpeakingFrame(Frame):
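A quick sketch of how the new frame behaves (illustrative only; just the import
path and the dataclass above come from this patch, and the user id is made up):

    from dailyai.pipeline.frames import UserImageRequestFrame

    # Any processor that wants a fresh image from a user pushes this frame
    # downstream; a transport can then react to it (see PATCH 2/8).
    frame = UserImageRequestFrame(user_id="participant-42")
    print(frame)  # UserImageRequestFrame, user: participant-42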
From a5eba0106bea2a6157b0f99f25bab04b5e82e985 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Tue, 9 Apr 2024 22:16:05 -0700
Subject: [PATCH 2/8] transport: allow requesting a user frame

---
 src/dailyai/transports/daily_transport.py    | 36 ++++++++++++++------
 src/dailyai/transports/threaded_transport.py | 12 +++++--
 2 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/src/dailyai/transports/daily_transport.py b/src/dailyai/transports/daily_transport.py
index 3adbf57a2..9d32d501b 100644
--- a/src/dailyai/transports/daily_transport.py
+++ b/src/dailyai/transports/daily_transport.py
@@ -169,8 +169,12 @@ def write_frame_to_mic(self, frame: bytes):
         if self._mic_enabled:
             self.mic.write_frames(frame)
 
-    def send_app_message(self, message: Any, participantId: str | None):
-        self.client.send_app_message(message, participantId)
+    def request_participant_image(self, participant_id: str):
+        if participant_id in self._video_renderers:
+            self._video_renderers[participant_id]["render_next_frame"] = True
+
+    def send_app_message(self, message: Any, participant_id: str | None):
+        self.client.send_app_message(message, participant_id)
 
     def read_audio_frames(self, desired_frame_count):
         bytes = b""
@@ -302,6 +306,7 @@ def render_participant_video(self,
         self._video_renderers[participant_id] = {
             "framerate": framerate,
            "timestamp": 0,
+            "render_next_frame": False,
        }
         self.client.set_video_renderer(
             participant_id,
@@ -310,17 +315,28 @@ def render_participant_video(self,
             color_format=color_format)
 
     def on_participant_video_frame(self, participant_id, video_frame):
+        if not self._loop:
+            return
+
+        render_frame = False
+
         curr_time = time.time()
-        prev_time = self._video_renderers[participant_id]["timestamp"]
-        diff_time = curr_time - prev_time
-        period = 1 / self._video_renderers[participant_id]["framerate"]
-        if diff_time > period and self._loop:
-            self._video_renderers[participant_id]["timestamp"] = curr_time
+        framerate = self._video_renderers[participant_id]["framerate"]
+
+        if framerate > 0:
+            prev_time = self._video_renderers[participant_id]["timestamp"]
+            next_time = prev_time + 1 / framerate
+            render_frame = curr_time > next_time
+        elif self._video_renderers[participant_id]["render_next_frame"]:
+            self._video_renderers[participant_id]["render_next_frame"] = False
+            render_frame = True
+
+        if render_frame:
             frame = UserImageFrame(participant_id, video_frame.buffer,
                                    (video_frame.width, video_frame.height))
-            asyncio.run_coroutine_threadsafe(
-                self.receive_queue.put(frame), self._loop
-            )
+            asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self._loop)
+
+            self._video_renderers[participant_id]["timestamp"] = curr_time
 
     def on_error(self, error):
         self._logger.error(f"on_error: {error}")
diff --git a/src/dailyai/transports/threaded_transport.py b/src/dailyai/transports/threaded_transport.py
index 334bffce3..52736a008 100644
--- a/src/dailyai/transports/threaded_transport.py
+++ b/src/dailyai/transports/threaded_transport.py
@@ -20,6 +20,7 @@
     SpriteFrame,
     StartFrame,
     TextFrame,
+    UserImageRequestFrame,
     UserStartedSpeakingFrame,
     UserStoppedSpeakingFrame,
 )
@@ -382,7 +383,11 @@ def _set_image(self, image: bytes):
     def _set_images(self, images: list[bytes], start_frame=0):
         self._images = itertools.cycle(images)
 
-    def send_app_message(self, message: Any, participantId: str | None):
+    def request_participant_image(self, participant_id: str):
+        """ Child classes should override this to force an image from a user. """
+        pass
+
+    def send_app_message(self, message: Any, participant_id: str | None):
         """ Child classes should override this to send a custom message to the room. """
         pass
 
@@ -458,9 +463,10 @@ def _frame_consumer(self):
                                 self._set_image(frame.image)
                             elif isinstance(frame, SpriteFrame):
                                 self._set_images(frame.images)
+                            elif isinstance(frame, UserImageRequestFrame):
+                                self.request_participant_image(frame.user_id)
                             elif isinstance(frame, SendAppMessageFrame):
-                                self.send_app_message(
-                                    frame.message, frame.participantId)
+                                self.send_app_message(frame.message, frame.participant_id)
                             elif len(b):
                                 self.write_frame_to_mic(bytes(b))
                                 b = bytearray()

From 84cfa7cc9528d154a4def131c575f9df8871f929 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Tue, 9 Apr 2024 22:16:25 -0700
Subject: [PATCH 3/8] services: added VisionService

---
 src/dailyai/services/ai_services.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/src/dailyai/services/ai_services.py b/src/dailyai/services/ai_services.py
index d4ecebee4..d620be3d2 100644
--- a/src/dailyai/services/ai_services.py
+++ b/src/dailyai/services/ai_services.py
@@ -100,6 +100,31 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
             yield URLImageFrame(url, image_data, image_size)
 
 
+class VisionService(AIService):
+    """VisionService is a base class for vision services."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._describe_text = None
+
+    @abstractmethod
+    async def run_vision(self, describe_text: str, frame: ImageFrame) -> str:
+        pass
+
+    async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
+        if isinstance(frame, TextFrame):
+            self._describe_text = frame.text
+        elif isinstance(frame, ImageFrame):
+            if self._describe_text:
+                description = await self.run_vision(self._describe_text, frame)
+                self._describe_text = None
+                yield TextFrame(description)
+            else:
+                yield frame
+        else:
+            yield frame
+
+
 class STTService(AIService):
     """STTService is a base class for speech-to-text services."""
 
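Before Moondream arrives in the next patch, a minimal sketch of what a
VisionService subclass looks like under this signature (the stub answer is
invented for illustration; PATCH 8/8 later reworks run_vision to take a single
VisionImageFrame):

    from dailyai.pipeline.frames import ImageFrame
    from dailyai.services.ai_services import VisionService

    class StubVisionService(VisionService):
        # Answers every question with the image dimensions (illustrative only).
        async def run_vision(self, describe_text: str, frame: ImageFrame) -> str:
            width, height = frame.size
            return f"You asked {describe_text!r} about a {width}x{height} image."

The inherited process_frame() stores the latest TextFrame and, when an
ImageFrame follows, yields a TextFrame carrying run_vision()'s description.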
From 18bf09c704de1ccd249f28d84f24e745a0c2b661 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Tue, 9 Apr 2024 22:18:26 -0700
Subject: [PATCH 4/8] services: added MoondreamService

---
 README.md                                    |  4 +-
 pyproject.toml                               |  1 +
 src/dailyai/services/moondream_ai_service.py | 52 ++++++++++++++++++++
 3 files changed, 56 insertions(+), 1 deletion(-)
 create mode 100644 src/dailyai/services/moondream_ai_service.py

diff --git a/README.md b/README.md
index 1f625ef45..eeeb1d4f1 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,8 @@ Currently implemented services:
 - Transport
   - Daily
   - Local (in progress, intended as a quick start example service)
+- Vision
+  - Moondream
 
 If you'd like to [implement a service](https://github.com/daily-co/daily-ai-sdk/tree/main/src/dailyai/services), we welcome PRs! Our goal is to support lots of services in all of the above categories, plus new categories (like real-time video) as they emerge.
 
@@ -63,7 +65,7 @@ pip install "dailyai[option,...]"
 
 Your project may or may not need these, so they're made available as optional requirements. Here is a list:
 
-- **AI services**: `anthropic`, `azure`, `fal`, `openai`, `playht`, `silero`, `whisper`
+- **AI services**: `anthropic`, `azure`, `fal`, `moondream`, `openai`, `playht`, `silero`, `whisper`
 - **Transports**: `daily`, `local`, `websocket`
 
 ## Code examples
diff --git a/pyproject.toml b/pyproject.toml
index 8baf090ae..c42452d4d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,6 +37,7 @@ daily = [ "daily-python~=0.7.0" ]
 examples = [ "python-dotenv~=1.0.0", "flask~=3.0.0", "flask_cors~=4.0.0" ]
 fal = [ "fal~=0.12.0" ]
 local = [ "pyaudio~=0.2.0" ]
+moondream = [ "einops~=0.7.0", "timm~=0.9.0", "transformers~=4.39.0" ]
 openai = [ "openai~=1.14.0" ]
 playht = [ "pyht~=0.0.26" ]
 silero = [ "torch~=2.2.0", "torchaudio~=2.2.0" ]
diff --git a/src/dailyai/services/moondream_ai_service.py b/src/dailyai/services/moondream_ai_service.py
new file mode 100644
index 000000000..2784a669e
--- /dev/null
+++ b/src/dailyai/services/moondream_ai_service.py
@@ -0,0 +1,52 @@
+from dailyai.pipeline.frames import ImageFrame
+from dailyai.services.ai_services import VisionService
+
+from PIL import Image
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import torch
+
+
+def detect_device():
+    """
+    Detects the appropriate device to run on, and returns the device and dtype.
+    """
+    if torch.cuda.is_available():
+        return torch.device("cuda"), torch.float16
+    elif torch.backends.mps.is_available():
+        return torch.device("mps"), torch.float16
+    else:
+        return torch.device("cpu"), torch.float32
+
+
+class MoondreamService(VisionService):
+    def __init__(
+            self,
+            model_id="vikhyatk/moondream2",
+            revision="2024-04-02",
+            device=None
+    ):
+        super().__init__()
+
+        if not device:
+            device, dtype = detect_device()
+        else:
+            device = torch.device(device)
+            dtype = torch.float16 if device.type in ("cuda", "mps") else torch.float32
+
+        self._tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
+
+        self._model = AutoModelForCausalLM.from_pretrained(
+            model_id, trust_remote_code=True, revision=revision
+        ).to(device=device, dtype=dtype)
+        self._model.eval()
+
+    async def run_vision(self, describe_text: str, frame: ImageFrame) -> str:
+        image = Image.frombytes("RGB", (frame.size[0], frame.size[1]), frame.image)
+        image_embeds = self._model.encode_image(image)
+        description = self._model.answer_question(
+            image_embeds=image_embeds,
+            question=describe_text,
+            tokenizer=self._tokenizer)
+        return description
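A minimal smoke test of the new service outside any transport (a sketch; it
assumes the moondream extras are installed, and the gray test image and the
question are made up):

    import asyncio

    from PIL import Image

    from dailyai.pipeline.frames import ImageFrame
    from dailyai.services.moondream_ai_service import MoondreamService

    async def main():
        service = MoondreamService()  # downloads vikhyatk/moondream2 on first use
        # A flat gray 64x64 RGB image stands in for a real video frame.
        image = Image.new("RGB", (64, 64), (128, 128, 128))
        frame = ImageFrame(image.tobytes(), image.size)
        print(await service.run_vision("What color is this image?", frame))

    asyncio.run(main())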
+ """ + if torch.cuda.is_available(): + return torch.device("cuda"), torch.float16 + elif torch.backends.mps.is_available(): + return torch.device("mps"), torch.float16 + else: + return torch.device("cpu"), torch.float32 + + +class MoondreamService(VisionService): + def __init__( + self, + model_id="vikhyatk/moondream2", + revision="2024-04-02", + device=None + ): + super().__init__() + + if not device: + device, dtype = detect_device() + else: + device = torch.device("cpu") + dtype = torch.float32 + + self._tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision) + + self._model = AutoModelForCausalLM.from_pretrained( + model_id, trust_remote_code=True, revision=revision + ).to(device=device, dtype=dtype) + self._model.eval() + + async def run_vision(self, describe_text: str, frame: ImageFrame) -> str: + image = Image.frombytes("RGB", (frame.size[0], frame.size[1]), frame.image) + image_embeds = self._model.encode_image(image) + description = self._model.answer_question( + image_embeds=image_embeds, + question=describe_text, + tokenizer=self._tokenizer) + return description From 34a6c5691b161458e9fcd668ead0234e5a6ceb46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Tue, 9 Apr 2024 22:18:55 -0700 Subject: [PATCH 5/8] examples: added 12-describe-video --- examples/foundational/12-describe-video.py | 82 +++++++++++++++++++ .../14-render-remote-participant.py | 1 - .../14a-local-render-remote-participant.py | 1 - 3 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 examples/foundational/12-describe-video.py diff --git a/examples/foundational/12-describe-video.py b/examples/foundational/12-describe-video.py new file mode 100644 index 000000000..f7cf851b6 --- /dev/null +++ b/examples/foundational/12-describe-video.py @@ -0,0 +1,82 @@ +import asyncio +import aiohttp +import logging +import os + +from typing import AsyncGenerator + +from dailyai.pipeline.aggregators import FrameProcessor, UserResponseAggregator + +from dailyai.pipeline.frames import Frame, TextFrame, UserImageRequestFrame +from dailyai.pipeline.pipeline import Pipeline +from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService +from dailyai.services.moondream_ai_service import MoondreamService +from dailyai.transports.daily_transport import DailyTransport + +from runner import configure + +from dotenv import load_dotenv +load_dotenv(override=True) + +logging.basicConfig(format=f"%(levelno)s %(asctime)s %(message)s") +logger = logging.getLogger("dailyai") +logger.setLevel(logging.DEBUG) + + +class UserImageRequester(FrameProcessor): + participant_id: str + + def set_participant_id(self, participant_id: str): + self.participant_id = participant_id + + async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: + if self.participant_id and isinstance(frame, TextFrame): + yield UserImageRequestFrame(self.participant_id) + yield frame + + +async def main(room_url: str, token): + async with aiohttp.ClientSession() as session: + transport = DailyTransport( + room_url, + token, + "Describe participant video", + duration_minutes=5, + mic_enabled=True, + mic_sample_rate=16000, + vad_enabled=True, + start_transcription=True, + video_rendering_enabled=True + ) + + tts = ElevenLabsTTSService( + aiohttp_session=session, + api_key=os.getenv("ELEVENLABS_API_KEY"), + voice_id=os.getenv("ELEVENLABS_VOICE_ID"), + ) + + user_response = UserResponseAggregator() + + image_requester = UserImageRequester() + + moondream = MoondreamService() + + tts = 
From 5ef5cf30f4bef11afcd23edc860d4a4533aa7ecd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Tue, 9 Apr 2024 22:19:19 -0700
Subject: [PATCH 6/8] update linux-py3.10 requirements

---
 linux-py3.10-requirements.txt | 32 ++++++++++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

diff --git a/linux-py3.10-requirements.txt b/linux-py3.10-requirements.txt
index 928d03ce6..80f010a9c 100644
--- a/linux-py3.10-requirements.txt
+++ b/linux-py3.10-requirements.txt
@@ -50,7 +50,7 @@ cryptography==42.0.5
     # via pyjwt
 ctranslate2==4.1.0
     # via faster-whisper
-daily-python==0.7.2
+daily-python==0.7.3
     # via dailyai (pyproject.toml)
 deprecated==1.2.14
     # via opentelemetry-api
@@ -62,6 +62,8 @@ distro==1.9.0
     # via
     #   anthropic
     #   openai
+einops==0.7.0
+    # via dailyai (pyproject.toml)
 exceptiongroup==1.2.0
     # via anyio
 fal==0.12.7
     # via dailyai (pyproject.toml)
 fastapi==0.99.1
     # via fal
 faster-whisper==1.0.1
     # via dailyai (pyproject.toml)
-filelock==3.13.3
+filelock==3.13.4
     # via
     #   huggingface-hub
     #   pyht
     #   torch
+    #   transformers
     #   triton
     #   virtualenv
 flask==3.0.3
     # via
 httpx==0.27.0
 huggingface-hub==0.22.2
     # via
     #   faster-whisper
+    #   timm
     #   tokenizers
+    #   transformers
 humanfriendly==10.0
     # via coloredlogs
 idna==3.6
 numpy==1.26.4
     # via
     #   ctranslate2
     #   dailyai (pyproject.toml)
     #   onnxruntime
+    #   torchvision
+    #   transformers
 nvidia-cublas-cu12==12.1.3.1
     # via
     #   nvidia-cudnn-cu12
     #   nvidia-cusolver-cu12
     #   torch
 packaging==24.0
     # via
     #   fal
     #   huggingface-hub
     #   onnxruntime
+    #   transformers
 pathspec==0.11.2
     # via fal
 pillow==10.2.0
     # via
     #   dailyai (pyproject.toml)
     #   fal
+    #   torchvision
 platformdirs==4.2.0
     # via
     #   isolate
 pyyaml==6.0.1
     # via
     #   ctranslate2
     #   huggingface-hub
     #   isolate
+    #   timm
+    #   transformers
+regex==2023.12.25
+    # via transformers
 requests==2.31.0
     # via
     #   huggingface-hub
     #   pyht
+    #   transformers
 rich==13.7.1
     # via
     #   fal
     #   rich-click
 rich-click==1.7.4
     # via fal
+safetensors==0.4.2
+    # via
+    #   timm
+    #   transformers
 six==1.16.0
     # via python-dateutil
 sniffio==1.3.1
 sympy==1.12
     #   torch
 tblib==3.0.0
     # via isolate
+timm==0.9.16
+    # via dailyai (pyproject.toml)
 tokenizers==0.15.2
     # via
     #   anthropic
     #   faster-whisper
+    #   transformers
 torch==2.2.2
     # via
     #   dailyai (pyproject.toml)
+    #   timm
     #   torchaudio
+    #   torchvision
 torchaudio==2.2.2
     # via dailyai (pyproject.toml)
+torchvision==0.17.2
+    # via timm
 tqdm==4.66.2
     # via
     #   huggingface-hub
     #   openai
+    #   transformers
+transformers==4.39.3
+    # via dailyai (pyproject.toml)
 triton==2.2.0
     # via torch
 types-python-dateutil==2.9.0.20240316
     # via fal

From 2f9899af5ab0e05f9c502a8f295c36714b3bbc94 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Tue, 9 Apr 2024 22:39:04 -0700
Subject: [PATCH 7/8] update macos-py3.10 requirements

---
 macos-py3.10-requirements.txt | 32 ++++++++++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

diff --git a/macos-py3.10-requirements.txt b/macos-py3.10-requirements.txt
index f151f5200..ae0a24e0a 100644
--- a/macos-py3.10-requirements.txt
+++ b/macos-py3.10-requirements.txt
@@ -50,7 +50,7 @@ cryptography==42.0.5
     # via pyjwt
 ctranslate2==4.1.0
     # via faster-whisper
-daily-python==0.7.2
+daily-python==0.7.3
     # via dailyai (pyproject.toml)
 deprecated==1.2.14
     # via opentelemetry-api
@@ -62,6 +62,8 @@ distro==1.9.0
     # via
     #   anthropic
     #   openai
+einops==0.7.0
+    # via dailyai (pyproject.toml)
 exceptiongroup==1.2.0
     # via anyio
 fal==0.12.7
     # via dailyai (pyproject.toml)
 fastapi==0.99.1
     # via fal
 faster-whisper==1.0.1
     # via dailyai (pyproject.toml)
-filelock==3.13.3
+filelock==3.13.4
     # via
     #   huggingface-hub
     #   pyht
     #   torch
+    #   transformers
     #   virtualenv
 flask==3.0.3
     # via
 httpx==0.27.0
 huggingface-hub==0.22.2
     # via
     #   faster-whisper
+    #   timm
     #   tokenizers
+    #   transformers
 humanfriendly==10.0
     # via coloredlogs
 idna==3.6
 numpy==1.26.4
     # via
     #   ctranslate2
     #   dailyai (pyproject.toml)
     #   onnxruntime
+    #   torchvision
+    #   transformers
 onnxruntime==1.17.1
     # via faster-whisper
 openai==1.14.3
     # via dailyai (pyproject.toml)
 packaging==24.0
     # via
     #   fal
     #   huggingface-hub
     #   onnxruntime
+    #   transformers
 pathspec==0.11.2
     # via fal
 pillow==10.2.0
     # via
     #   dailyai (pyproject.toml)
     #   fal
+    #   torchvision
 platformdirs==4.2.0
     # via
     #   isolate
 pyyaml==6.0.1
     # via
     #   ctranslate2
     #   huggingface-hub
     #   isolate
+    #   timm
+    #   transformers
+regex==2023.12.25
+    # via transformers
 requests==2.31.0
     # via
     #   huggingface-hub
     #   pyht
+    #   transformers
 rich==13.7.1
     # via
     #   fal
     #   rich-click
 rich-click==1.7.4
     # via fal
+safetensors==0.4.2
+    # via
+    #   timm
+    #   transformers
 six==1.16.0
     # via python-dateutil
 sniffio==1.3.1
 sympy==1.12
     #   torch
 tblib==3.0.0
     # via isolate
+timm==0.9.16
+    # via dailyai (pyproject.toml)
 tokenizers==0.15.2
     # via
     #   anthropic
     #   faster-whisper
+    #   transformers
 torch==2.2.2
     # via
     #   dailyai (pyproject.toml)
+    #   timm
     #   torchaudio
+    #   torchvision
 torchaudio==2.2.2
     # via dailyai (pyproject.toml)
+torchvision==0.17.2
+    # via timm
 tqdm==4.66.2
     # via
     #   huggingface-hub
     #   openai
+    #   transformers
+transformers==4.39.3
+    # via dailyai (pyproject.toml)
 types-python-dateutil==2.9.0.20240316
     # via fal
 typing-extensions==4.10.0

From 3c20f9153d7be72749640a7cb7ffe4ed80faec3a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Wed, 10 Apr 2024 09:18:54 -0700
Subject: [PATCH 8/8] added VisionImageFrame and VisionImageFrameAggregator

---
 examples/foundational/12-describe-video.py   |  6 ++--
 src/dailyai/pipeline/aggregators.py          | 36 ++++++++++++++++++++
 src/dailyai/pipeline/frames.py               | 22 ++++++++++--
 src/dailyai/services/ai_services.py          | 16 +++-------
 src/dailyai/services/moondream_ai_service.py |  6 ++--
 5 files changed, 68 insertions(+), 18 deletions(-)

diff --git a/examples/foundational/12-describe-video.py b/examples/foundational/12-describe-video.py
index f7cf851b6..8e20d6533 100644
--- a/examples/foundational/12-describe-video.py
+++ b/examples/foundational/12-describe-video.py
@@ -5,7 +5,7 @@
 
 from typing import AsyncGenerator
 
-from dailyai.pipeline.aggregators import FrameProcessor, UserResponseAggregator
+from dailyai.pipeline.aggregators import FrameProcessor, UserResponseAggregator, VisionImageFrameAggregator
 
 from dailyai.pipeline.frames import Frame, TextFrame, UserImageRequestFrame
 from dailyai.pipeline.pipeline import Pipeline
@@ -53,6 +53,8 @@ async def main(room_url: str, token):
 
         image_requester = UserImageRequester()
 
+        vision_aggregator = VisionImageFrameAggregator()
+
         moondream = MoondreamService()
 
         tts = ElevenLabsTTSService(
@@ -68,7 +70,7 @@ async def main(room_url: str, token):
             transport.render_participant_video(participant["id"], framerate=0)
             image_requester.set_participant_id(participant["id"])
 
-        pipeline = Pipeline([user_response, image_requester, moondream, tts])
+        pipeline = Pipeline([user_response, image_requester, vision_aggregator, moondream, tts])
 
         await transport.run(pipeline)
 
diff --git a/src/dailyai/pipeline/aggregators.py b/src/dailyai/pipeline/aggregators.py
index 12da8b6dc..9edc87384 100644
--- a/src/dailyai/pipeline/aggregators.py
+++ b/src/dailyai/pipeline/aggregators.py
@@ -7,6 +7,7 @@
     EndFrame,
     EndPipeFrame,
     Frame,
+    ImageFrame,
     LLMMessagesFrame,
     LLMResponseEndFrame,
     LLMResponseStartFrame,
@@ -14,6 +15,7 @@
     TranscriptionFrame,
     UserStartedSpeakingFrame,
     UserStoppedSpeakingFrame,
+    VisionImageFrame,
 )
 from dailyai.pipeline.pipeline import Pipeline
 from dailyai.services.ai_services import AIService
@@ -463,3 +465,37 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
             self.accumulator = []
         else:
             self.accumulator.append(frame)
+
+
+class VisionImageFrameAggregator(FrameProcessor):
+    """This aggregator waits for a consecutive TextFrame and an
+    ImageFrame. After the ImageFrame arrives it will output a VisionImageFrame.
+
+    >>> from dailyai.pipeline.frames import ImageFrame
+
+    >>> async def print_frames(aggregator, frame):
+    ...     async for frame in aggregator.process_frame(frame):
+    ...         print(frame)
+
+    >>> aggregator = VisionImageFrameAggregator()
+    >>> asyncio.run(print_frames(aggregator, TextFrame("What do you see?")))
+    >>> asyncio.run(print_frames(aggregator, ImageFrame(image=bytes([]), size=(0, 0))))
+    VisionImageFrame, text: What do you see?, image size: 0x0, buffer size: 0 B
+
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._describe_text = None
+
+    async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
+        if isinstance(frame, TextFrame):
+            self._describe_text = frame.text
+        elif isinstance(frame, ImageFrame):
+            if self._describe_text:
+                yield VisionImageFrame(self._describe_text, frame.image, frame.size)
+                self._describe_text = None
+            else:
+                yield frame
+        else:
+            yield frame
diff --git a/src/dailyai/pipeline/frames.py b/src/dailyai/pipeline/frames.py
index 7117e8f88..4322e65c8 100644
--- a/src/dailyai/pipeline/frames.py
+++ b/src/dailyai/pipeline/frames.py
@@ -79,8 +79,10 @@ def __str__(self):
 
 @dataclass()
 class URLImageFrame(ImageFrame):
-    """An image. Will be shown by the transport if the transport's camera is
-    enabled."""
+    """An image with an associated URL. Will be shown by the transport if the
+    transport's camera is enabled.
+
+    """
     url: str | None
 
     def __init__(self, url, image, size):
@@ -91,6 +93,22 @@ def __str__(self):
         return f"{self.__class__.__name__}, url: {self.url}, image size: {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B"
 
 
+@dataclass()
+class VisionImageFrame(ImageFrame):
+    """An image with associated text, used to ask a vision service for a
+    description of it.
+
+    """
+    text: str | None
+
+    def __init__(self, text, image, size):
+        super().__init__(image, size)
+        self.text = text
+
+    def __str__(self):
+        return f"{self.__class__.__name__}, text: {self.text}, image size: {self.size[0]}x{self.size[1]}, buffer size: {len(self.image)} B"
+
+
 @dataclass()
 class UserImageFrame(ImageFrame):
     """An image associated to a user. Will be shown by the transport if the transport's camera is
diff --git a/src/dailyai/services/ai_services.py b/src/dailyai/services/ai_services.py
index d620be3d2..babf6be75 100644
--- a/src/dailyai/services/ai_services.py
+++ b/src/dailyai/services/ai_services.py
@@ -15,6 +15,7 @@
     TextFrame,
     TranscriptionFrame,
     URLImageFrame,
+    VisionImageFrame,
 )
 from abc import abstractmethod
 
@@ -108,19 +109,12 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self._describe_text = None
 
     @abstractmethod
-    async def run_vision(self, describe_text: str, frame: ImageFrame) -> str:
+    async def run_vision(self, frame: VisionImageFrame) -> str:
         pass
 
     async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
-        if isinstance(frame, TextFrame):
-            self._describe_text = frame.text
-        elif isinstance(frame, ImageFrame):
-            if self._describe_text:
-                description = await self.run_vision(self._describe_text, frame)
-                self._describe_text = None
-                yield TextFrame(description)
-            else:
-                yield frame
+        if isinstance(frame, VisionImageFrame):
+            description = await self.run_vision(frame)
+            yield TextFrame(description)
         else:
             yield frame
diff --git a/src/dailyai/services/moondream_ai_service.py b/src/dailyai/services/moondream_ai_service.py
index 2784a669e..07ff9e534 100644
--- a/src/dailyai/services/moondream_ai_service.py
+++ b/src/dailyai/services/moondream_ai_service.py
@@ -1,4 +1,4 @@
-from dailyai.pipeline.frames import ImageFrame
+from dailyai.pipeline.frames import VisionImageFrame
 from dailyai.services.ai_services import VisionService
 
 from PIL import Image
@@ -42,11 +42,11 @@ def __init__(
         ).to(device=device, dtype=dtype)
         self._model.eval()
 
-    async def run_vision(self, describe_text: str, frame: ImageFrame) -> str:
+    async def run_vision(self, frame: VisionImageFrame) -> str:
         image = Image.frombytes("RGB", (frame.size[0], frame.size[1]), frame.image)
         image_embeds = self._model.encode_image(image)
         description = self._model.answer_question(
             image_embeds=image_embeds,
-            question=describe_text,
+            question=frame.text,
             tokenizer=self._tokenizer)
         return description
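With PATCH 8/8 applied, the vision path can be exercised without Daily or a
GPU. A minimal sketch under the final signatures (StubVisionService and the
sample frames are invented for illustration):

    import asyncio

    from dailyai.pipeline.aggregators import VisionImageFrameAggregator
    from dailyai.pipeline.frames import ImageFrame, TextFrame, VisionImageFrame
    from dailyai.services.ai_services import VisionService

    class StubVisionService(VisionService):
        async def run_vision(self, frame: VisionImageFrame) -> str:
            return f"A {frame.size[0]}x{frame.size[1]} image, asked: {frame.text}"

    async def demo():
        aggregator = VisionImageFrameAggregator()
        vision = StubVisionService()
        # The aggregator holds the question until the image arrives...
        async for _ in aggregator.process_frame(TextFrame("What do you see?")):
            pass  # nothing is yielded yet
        # ...then pairs them into a VisionImageFrame, which the service
        # turns back into a TextFrame description.
        image = ImageFrame(bytes(2 * 2 * 3), (2, 2))  # 2x2 black RGB image
        async for paired in aggregator.process_frame(image):
            async for description in vision.process_frame(paired):
                print(description)

    asyncio.run(demo())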