Skip to content

Commit

Permalink
allow web widget to interact with sources while the lipsync is being …
Browse files Browse the repository at this point in the history
…rendered
  • Loading branch information
devxpy committed Dec 12, 2024
1 parent 14356c1 commit 4589033
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
9 changes: 7 additions & 2 deletions daras_ai_v2/bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from daras_ai_v2.asr import run_google_translate, should_translate_lang
from daras_ai_v2.base import BasePage, RecipeRunState, StateKeys
from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT
from daras_ai_v2.search_ref import SearchReference
from daras_ai_v2.vector_search import doc_url_to_file_metadata
from gooeysite.bg_db_conn import db_middleware
from recipes.VideoBots import VideoBotsPage, ReplyButton
Expand Down Expand Up @@ -205,7 +206,9 @@ def get_interactive_msg_info(self) -> ButtonPressed:
def on_run_created(self, sr: "SavedRun"):
pass

def send_run_status(self, update_msg_id: str | None) -> str | None:
def send_run_status(
self, update_msg_id: str | None, references: list[SearchReference] | None = None
) -> str | None:
pass

def nice_filename(self, mime_type: str) -> str:
Expand Down Expand Up @@ -411,7 +414,9 @@ def _process_and_send_msg(
text = state.get("output_text") and state.get("output_text")[0]
if not text:
# if no text, send the run status as text
update_msg_id = bot.send_run_status(update_msg_id=update_msg_id)
update_msg_id = bot.send_run_status(
update_msg_id=update_msg_id, references=state.get("references")
)
continue # no text, wait for the next update
streaming_done = state.get("finish_reason")
# send the response to the user
Expand Down
7 changes: 4 additions & 3 deletions daras_ai_v2/slack_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
from bots.models import BotIntegration, Platform, Conversation
from daras_ai.image_input import upload_file_from_bytes
from daras_ai_v2.asr import (
run_google_translate,
audio_bytes_to_wav,
should_translate_lang,
)
from daras_ai_v2.bots import BotInterface, SLACK_MAX_SIZE, ButtonPressed
from daras_ai_v2.exceptions import raise_for_status
from daras_ai_v2.functional import fetch_parallel
from daras_ai_v2.search_ref import SearchReference
from daras_ai_v2.text_splitter import text_splitter
from recipes.VideoBots import ReplyButton

Expand Down Expand Up @@ -133,7 +132,9 @@ def get_interactive_msg_info(self) -> ButtonPressed:
button_id=self._actions[0]["value"], context_msg_id=self._msg_ts
)

def send_run_status(self, update_msg_id: str | None) -> str | None:
def send_run_status(
self, update_msg_id: str | None, references: list[SearchReference] | None = None
) -> str | None:
if not self.run_status:
return update_msg_id
return self.send_msg(text=self.run_status, update_msg_id=update_msg_id)
Expand Down
18 changes: 16 additions & 2 deletions routers/bots_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from daras_ai_v2.base import RecipeRunState, StateKeys
from daras_ai_v2.bots import BotInterface, msg_handler, ButtonPressed
from daras_ai_v2.redis_cache import get_redis_cache
from daras_ai_v2.search_ref import SearchReference
from recipes.VideoBots import VideoBotsPage, ReplyButton
from routers.api import (
AsyncApiResponseModelV3,
Expand Down Expand Up @@ -141,6 +142,8 @@ class MessagePart(BaseModel):
description="Details about the status of the run as a human readable string"
)

references: list[SearchReference] | None

text: str | None
audio: str | None
video: str | None
Expand Down Expand Up @@ -326,9 +329,20 @@ def on_run_created(self, sr: SavedRun):
self.uid = sr.uid
self.queue.put(RunStart(**build_async_api_response(sr)))

def send_run_status(self, update_msg_id: str | None) -> str | None:
def send_run_status(
self, update_msg_id: str | None, references: list[SearchReference] | None = None
) -> str | None:
self.queue.put(
MessagePart(status=self.recipe_run_state, detail=self.run_status)
MessagePart(
status=self.recipe_run_state,
detail=self.run_status,
references=(
# avoid sending the entire snippet to save bandwidth
[r | dict(snippet="") for r in references]
if references
else None
),
)
)
return None

Expand Down

0 comments on commit 4589033

Please sign in to comment.