From 4589033c7913c7f33273c6fc1184d4babea4c737 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 12 Dec 2024 21:23:04 +0530 Subject: [PATCH] allow web widget to interact with sources while the lipsync is being rendered --- daras_ai_v2/bots.py | 9 +++++++-- daras_ai_v2/slack_bot.py | 7 ++++--- routers/bots_api.py | 18 ++++++++++++++++-- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index b7222cb44..eb53d9796 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -24,6 +24,7 @@ from daras_ai_v2.asr import run_google_translate, should_translate_lang from daras_ai_v2.base import BasePage, RecipeRunState, StateKeys from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT +from daras_ai_v2.search_ref import SearchReference from daras_ai_v2.vector_search import doc_url_to_file_metadata from gooeysite.bg_db_conn import db_middleware from recipes.VideoBots import VideoBotsPage, ReplyButton @@ -205,7 +206,9 @@ def get_interactive_msg_info(self) -> ButtonPressed: def on_run_created(self, sr: "SavedRun"): pass - def send_run_status(self, update_msg_id: str | None) -> str | None: + def send_run_status( + self, update_msg_id: str | None, references: list[SearchReference] | None = None + ) -> str | None: pass def nice_filename(self, mime_type: str) -> str: @@ -411,7 +414,9 @@ def _process_and_send_msg( text = state.get("output_text") and state.get("output_text")[0] if not text: # if no text, send the run status as text - update_msg_id = bot.send_run_status(update_msg_id=update_msg_id) + update_msg_id = bot.send_run_status( + update_msg_id=update_msg_id, references=state.get("references") + ) continue # no text, wait for the next update streaming_done = state.get("finish_reason") # send the response to the user diff --git a/daras_ai_v2/slack_bot.py b/daras_ai_v2/slack_bot.py index 388b1af7d..927b55b85 100644 --- a/daras_ai_v2/slack_bot.py +++ b/daras_ai_v2/slack_bot.py @@ -11,13 +11,12 @@ from bots.models import BotIntegration, Platform, Conversation from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2.asr import ( - run_google_translate, audio_bytes_to_wav, - should_translate_lang, ) from daras_ai_v2.bots import BotInterface, SLACK_MAX_SIZE, ButtonPressed from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.functional import fetch_parallel +from daras_ai_v2.search_ref import SearchReference from daras_ai_v2.text_splitter import text_splitter from recipes.VideoBots import ReplyButton @@ -133,7 +132,9 @@ def get_interactive_msg_info(self) -> ButtonPressed: button_id=self._actions[0]["value"], context_msg_id=self._msg_ts ) - def send_run_status(self, update_msg_id: str | None) -> str | None: + def send_run_status( + self, update_msg_id: str | None, references: list[SearchReference] | None = None + ) -> str | None: if not self.run_status: return update_msg_id return self.send_msg(text=self.run_status, update_msg_id=update_msg_id) diff --git a/routers/bots_api.py b/routers/bots_api.py index c960fa33f..63a2637db 100644 --- a/routers/bots_api.py +++ b/routers/bots_api.py @@ -16,6 +16,7 @@ from daras_ai_v2.base import RecipeRunState, StateKeys from daras_ai_v2.bots import BotInterface, msg_handler, ButtonPressed from daras_ai_v2.redis_cache import get_redis_cache +from daras_ai_v2.search_ref import SearchReference from recipes.VideoBots import VideoBotsPage, ReplyButton from routers.api import ( AsyncApiResponseModelV3, @@ -141,6 +142,8 @@ class MessagePart(BaseModel): description="Details about the status of the run as a human readable string" ) + references: list[SearchReference] | None + text: str | None audio: str | None video: str | None @@ -326,9 +329,20 @@ def on_run_created(self, sr: SavedRun): self.uid = sr.uid self.queue.put(RunStart(**build_async_api_response(sr))) - def send_run_status(self, update_msg_id: str | None) -> str | None: + def send_run_status( + self, update_msg_id: str | None, references: list[SearchReference] | None = None + ) -> str | None: self.queue.put( - MessagePart(status=self.recipe_run_state, detail=self.run_status) + MessagePart( + status=self.recipe_run_state, + detail=self.run_status, + references=( + # avoid sending the entire snippet to save bandwidth + [r | dict(snippet="") for r in references] + if references + else None + ), + ) ) return None