Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into copilot_stats
Browse files Browse the repository at this point in the history
# Conflicts:
#	daras_ai_v2/base.py
  • Loading branch information
devxpy committed Jan 24, 2024
2 parents e3a2461 + d97ffe4 commit 0f91b82
Show file tree
Hide file tree
Showing 11 changed files with 340 additions and 122 deletions.
16 changes: 11 additions & 5 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from django.contrib.auth import get_user_model
from django.db import models, transaction
from django.db.models import Q
from django.utils.text import Truncator
from django.utils.text import Truncator, slugify
from furl import furl
from phonenumber_field.modelfields import PhoneNumberField

Expand Down Expand Up @@ -93,7 +93,7 @@ class Workflow(models.IntegerChoices):
def short_slug(self):
return min(self.page_cls.slug_versions, key=len)

def get_app_url(self, example_id: str, run_id: str, uid: str):
def get_app_url(self, example_id: str, run_id: str, uid: str, run_slug: str = ""):
"""return the url to the gooey app"""
query_params = {}
if run_id and uid:
Expand All @@ -102,7 +102,8 @@ def get_app_url(self, example_id: str, run_id: str, uid: str):
query_params |= dict(example_id=example_id)
return str(
furl(settings.APP_BASE_URL, query_params=query_params)
/ self.short_slug
/ self.page_cls.slug_versions[-1]
/ run_slug
/ "/"
)

Expand Down Expand Up @@ -903,7 +904,9 @@ def as_llm_context(self, limit: int = 100) -> list["ConversationEntry"]:
entries[i] = format_chat_entry(
role=msg.role,
content=msg.content,
images=msg.attachments.values_list("url", flat=True),
images=msg.attachments.filter(
metadata__mime_type__startswith="image/"
).values_list("url", flat=True),
)
return entries

Expand Down Expand Up @@ -1288,7 +1291,10 @@ def duplicate(

def get_app_url(self):
return Workflow(self.workflow).get_app_url(
example_id=self.published_run_id, run_id="", uid=""
example_id=self.published_run_id,
run_id="",
uid="",
run_slug=self.title and slugify(self.title),
)

def add_version(
Expand Down
116 changes: 70 additions & 46 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import requests
import sentry_sdk
from django.utils import timezone
from django.utils.text import slugify
from fastapi import HTTPException
from firebase_admin import auth
from furl import furl
Expand Down Expand Up @@ -136,11 +137,17 @@ def app_url(
query_params = cls.clean_query_params(
example_id=example_id, run_id=run_id, uid=uid
) | (query_params or {})
f = furl(settings.APP_BASE_URL, query_params=query_params) / (
cls.slug_versions[-1] + "/"
f = (
furl(settings.APP_BASE_URL, query_params=query_params)
/ cls.slug_versions[-1]
)
if example_id := query_params.get("example_id"):
pr = cls.get_published_run(published_run_id=example_id)
if pr and pr.title:
f /= slugify(pr.title)
if tab_name:
f /= tab_name + "/"
f /= tab_name
f /= "/" # keep trailing slash
return str(f)

@classmethod
Expand Down Expand Up @@ -180,12 +187,18 @@ def setup_render(self):
"/" + self.slug_versions[0], source=TRANSACTION_SOURCE_ROUTE
)

def refresh_state(self):
_, run_id, uid = extract_query_params(gooey_get_query_params())
channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}"
output = realtime_pull([channel])[0]
if output:
st.session_state.update(output)

def render(self):
self.setup_render()

if self.get_run_state() == RecipeRunState.running:
_, run_id, uid = extract_query_params(gooey_get_query_params())
channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}"
output = realtime_pull([channel])[0]
if output:
st.session_state.update(output)
self.refresh_state()
else:
realtime_clear_subs()

Expand All @@ -196,9 +209,22 @@ def setup_render(self):
self.render_report_form()
return

def render(self):
self.setup_render()
self._render_header()

self._render_tab_menu(selected_tab=self.tab)
with st.nav_tab_content():
self.render_selected_tab(self.tab)

def _render_tab_menu(self, selected_tab: str):
assert selected_tab in MenuTabs.paths

with st.nav_tabs():
for name in self.get_tabs():
url = self.get_tab_url(name)
with st.nav_item(url, active=name == selected_tab):
st.html(name)

def _render_header(self):
current_run = self.get_current_sr()
published_run = self.get_current_published_run()
is_root_example = (
Expand All @@ -207,6 +233,7 @@ def render(self):
and published_run.saved_run == current_run
)
tbreadcrumbs = get_title_breadcrumbs(self, current_run, published_run)

with st.div(className="d-flex justify-content-between mt-4"):
with st.div(className="d-lg-flex d-block align-items-center"):
if not tbreadcrumbs.has_breadcrumbs() and not self.run_user:
Expand Down Expand Up @@ -268,21 +295,6 @@ def render(self):
elif is_root_example:
st.write(self.preview_description(current_run.to_dict()))

try:
selected_tab = MenuTabs.paths_reverse[self.tab]
except KeyError:
st.error(f"## 404 - Tab {self.tab!r} Not found")
return

with st.nav_tabs():
tab_names = self.get_tabs()
for name in tab_names:
url = self.get_tab_url(name)
with st.nav_item(url, active=name == selected_tab):
st.html(name)
with st.nav_tab_content():
self.render_selected_tab(selected_tab)

def _render_title(self, title: str):
st.write(f"# {title}")

Expand Down Expand Up @@ -450,14 +462,14 @@ def _render_publish_modal(

if not pressed_save:
return
recipe_title = self.get_root_published_run().title or self.title

is_root_published_run = is_update_mode and published_run.is_root()
if (
not is_root_published_run
and published_run_title.strip() == recipe_title.strip()
):
st.error("Title can't be the same as the recipe title", icon="⚠️")
return
if not is_root_published_run:
try:
self._validate_published_run_title(published_run_title)
except TitleValidationError as e:
st.error(str(e))
return

if is_update_mode:
updates = dict(
Expand All @@ -483,6 +495,18 @@ def _render_publish_modal(
)
force_redirect(published_run.get_app_url())

def _validate_published_run_title(self, title: str):
if slugify(title) in settings.DISALLOWED_TITLE_SLUGS:
raise TitleValidationError(
"This title is not allowed. Please choose a different title."
)
elif title.strip() == self.get_recipe_title():
raise TitleValidationError(
"Please choose a different title for your published run."
)
elif title.strip() == "":
raise TitleValidationError("Title cannot be empty.")

def _has_published_run_changed(
self,
*,
Expand Down Expand Up @@ -911,7 +935,7 @@ def get_current_published_run(cls) -> PublishedRun | None:
@classmethod
def get_pr_from_query_params(
cls, example_id: str, run_id: str, uid: str
) -> PublishedRun:
) -> PublishedRun | None:
if run_id and uid:
sr = cls.get_sr_from_query_params(example_id, run_id, uid)
return (
Expand Down Expand Up @@ -1883,6 +1907,17 @@ def err_msg_for_exc(e):
return f"{type(e).__name__}: {e}"


def force_redirect(url: str):
# note: assumes sanitized URLs
st.html(
f"""
<script>
window.location = '{url}';
</script>
"""
)


class RedirectException(Exception):
def __init__(self, url, status_code=302):
self.url = url
Expand All @@ -1896,16 +1931,5 @@ def __init__(self, query_params: dict, status_code=303):
super().__init__(url, status_code)


def force_redirect(url: str):
# note: assumes sanitized URLs
st.html(
f"""
<script>
window.location = '{url}';
</script>
"""
)


def reverse_enumerate(start, iterator):
return zip(range(start, -1, -1), iterator)
class TitleValidationError(Exception):
pass
30 changes: 26 additions & 4 deletions daras_ai_v2/bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,12 @@ def get_input_audio(self) -> str | None:
def get_input_images(self) -> list[str] | None:
raise NotImplementedError

def get_input_documents(self) -> list[str] | None:
raise NotImplementedError

def nice_filename(self, mime_type: str) -> str:
ext = mimetypes.guess_extension(mime_type) or ""
return f"{self.platform}_{self.input_type}_from_{self.user_id}_to_{self.bot_id}{ext}"
return f"{self.platform.name}_{self.input_type}_from_{self.user_id}_to_{self.bot_id}{ext}"

def _unpack_bot_integration(self):
bi = self.convo.bot_integration
Expand Down Expand Up @@ -198,6 +201,7 @@ def _mock_api_output(input_text):
def _on_msg(bot: BotInterface):
speech_run = None
input_images = None
input_documents = None
if not bot.page_cls:
bot.send_msg(text=PAGE_NOT_CONNECTED_ERROR)
return
Expand Down Expand Up @@ -240,6 +244,17 @@ def _on_msg(bot: BotInterface):
status_code=400, detail="No image found in request."
)
input_text = (bot.get_input_text() or "").strip()
case "document":
input_documents = bot.get_input_documents()
if not input_documents:
raise HTTPException(
status_code=400, detail="No documents found in request."
)
filenames = ", ".join(
furl(url.strip("/")).path.segments[-1] for url in input_documents
)
input_text = (bot.get_input_text() or "").strip()
input_text = f"Files: {filenames}\n\n{input_text}"
case "text":
input_text = (bot.get_input_text() or "").strip()
if not input_text:
Expand Down Expand Up @@ -268,6 +283,7 @@ def _on_msg(bot: BotInterface):
billing_account_user=billing_account_user,
bot=bot,
input_images=input_images,
input_documents=input_documents,
input_text=input_text,
speech_run=speech_run,
)
Expand Down Expand Up @@ -306,6 +322,7 @@ def _process_and_send_msg(
billing_account_user: AppUser,
bot: BotInterface,
input_images: list[str] | None,
input_documents: list[str] | None,
input_text: str,
speech_run: str | None,
):
Expand All @@ -324,6 +341,7 @@ def _process_and_send_msg(
user_language=bot.language,
speech_run=speech_run,
input_images=input_images,
input_documents=input_documents,
)
except HTTPException as e:
traceback.print_exc()
Expand All @@ -345,6 +363,7 @@ def _process_and_send_msg(
_save_msgs(
bot=bot,
input_images=input_images,
input_documents=input_documents,
input_text=input_text,
speech_run=speech_run,
platform_msg_id=msg_id,
Expand All @@ -356,6 +375,7 @@ def _process_and_send_msg(
def _save_msgs(
bot: BotInterface,
input_images: list[str] | None,
input_documents: list[str] | None,
input_text: str,
speech_run: str | None,
platform_msg_id: str | None,
Expand All @@ -376,10 +396,10 @@ def _save_msgs(
else None,
)
attachments = []
for img in input_images or []:
metadata = doc_url_to_file_metadata(img)
for f_url in (input_images or []) + (input_documents or []):
metadata = doc_url_to_file_metadata(f_url)
attachments.append(
MessageAttachment(message=user_msg, url=img, metadata=metadata)
MessageAttachment(message=user_msg, url=f_url, metadata=metadata)
)
assistant_msg = Message(
platform_msg_id=platform_msg_id,
Expand Down Expand Up @@ -407,6 +427,7 @@ def _process_msg(
query_params: dict,
convo: Conversation,
input_images: list[str] | None,
input_documents: list[str] | None,
input_text: str,
user_language: str,
speech_run: str | None,
Expand All @@ -426,6 +447,7 @@ def _process_msg(
request_body={
"input_prompt": input_text,
"input_images": input_images,
"input_documents": input_documents,
"messages": saved_msgs,
"user_language": user_language,
},
Expand Down
7 changes: 7 additions & 0 deletions daras_ai_v2/facebook_bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def get_input_images(self) -> list[str] | None:
return None
return [self._download_wa_media(media_id)]

def get_input_documents(self) -> list[str] | None:
try:
media_id = self.input_message["document"]["id"]
except KeyError:
return None
return [self._download_wa_media(media_id)]

def _download_wa_media(self, media_id: str) -> str:
# download file from whatsapp
data, mime_type = retrieve_wa_media_by_id(media_id)
Expand Down
Loading

0 comments on commit 0f91b82

Please sign in to comment.