diff --git a/bots/models.py b/bots/models.py
index 8f358c50b..1d77b57fd 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -8,7 +8,7 @@
from django.contrib.auth import get_user_model
from django.db import models, transaction
from django.db.models import Q
-from django.utils.text import Truncator
+from django.utils.text import Truncator, slugify
from furl import furl
from phonenumber_field.modelfields import PhoneNumberField
@@ -93,7 +93,7 @@ class Workflow(models.IntegerChoices):
def short_slug(self):
return min(self.page_cls.slug_versions, key=len)
- def get_app_url(self, example_id: str, run_id: str, uid: str):
+ def get_app_url(self, example_id: str, run_id: str, uid: str, run_slug: str = ""):
"""return the url to the gooey app"""
query_params = {}
if run_id and uid:
@@ -102,7 +102,8 @@ def get_app_url(self, example_id: str, run_id: str, uid: str):
query_params |= dict(example_id=example_id)
return str(
furl(settings.APP_BASE_URL, query_params=query_params)
- / self.short_slug
+ / self.page_cls.slug_versions[-1]
+ / run_slug
/ "/"
)
@@ -903,7 +904,9 @@ def as_llm_context(self, limit: int = 100) -> list["ConversationEntry"]:
entries[i] = format_chat_entry(
role=msg.role,
content=msg.content,
- images=msg.attachments.values_list("url", flat=True),
+ images=msg.attachments.filter(
+ metadata__mime_type__startswith="image/"
+ ).values_list("url", flat=True),
)
return entries
@@ -1288,7 +1291,10 @@ def duplicate(
def get_app_url(self):
return Workflow(self.workflow).get_app_url(
- example_id=self.published_run_id, run_id="", uid=""
+ example_id=self.published_run_id,
+ run_id="",
+ uid="",
+ run_slug=self.title and slugify(self.title),
)
def add_version(
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 0a4dd2d1b..edd3d9ec8 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -16,6 +16,7 @@
import requests
import sentry_sdk
from django.utils import timezone
+from django.utils.text import slugify
from fastapi import HTTPException
from firebase_admin import auth
from furl import furl
@@ -136,11 +137,17 @@ def app_url(
query_params = cls.clean_query_params(
example_id=example_id, run_id=run_id, uid=uid
) | (query_params or {})
- f = furl(settings.APP_BASE_URL, query_params=query_params) / (
- cls.slug_versions[-1] + "/"
+ f = (
+ furl(settings.APP_BASE_URL, query_params=query_params)
+ / cls.slug_versions[-1]
)
+ if example_id := query_params.get("example_id"):
+ pr = cls.get_published_run(published_run_id=example_id)
+ if pr and pr.title:
+ f /= slugify(pr.title)
if tab_name:
- f /= tab_name + "/"
+ f /= tab_name
+ f /= "/" # keep trailing slash
return str(f)
@classmethod
@@ -180,12 +187,18 @@ def setup_render(self):
"/" + self.slug_versions[0], source=TRANSACTION_SOURCE_ROUTE
)
+ def refresh_state(self):
+ _, run_id, uid = extract_query_params(gooey_get_query_params())
+ channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}"
+ output = realtime_pull([channel])[0]
+ if output:
+ st.session_state.update(output)
+
+ def render(self):
+ self.setup_render()
+
if self.get_run_state() == RecipeRunState.running:
- _, run_id, uid = extract_query_params(gooey_get_query_params())
- channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}"
- output = realtime_pull([channel])[0]
- if output:
- st.session_state.update(output)
+ self.refresh_state()
else:
realtime_clear_subs()
@@ -196,9 +209,22 @@ def setup_render(self):
self.render_report_form()
return
- def render(self):
- self.setup_render()
+ self._render_header()
+
+ self._render_tab_menu(selected_tab=self.tab)
+ with st.nav_tab_content():
+ self.render_selected_tab(self.tab)
+
+ def _render_tab_menu(self, selected_tab: str):
+ assert selected_tab in MenuTabs.paths
+
+ with st.nav_tabs():
+ for name in self.get_tabs():
+ url = self.get_tab_url(name)
+ with st.nav_item(url, active=name == selected_tab):
+ st.html(name)
+ def _render_header(self):
current_run = self.get_current_sr()
published_run = self.get_current_published_run()
is_root_example = (
@@ -207,6 +233,7 @@ def render(self):
and published_run.saved_run == current_run
)
tbreadcrumbs = get_title_breadcrumbs(self, current_run, published_run)
+
with st.div(className="d-flex justify-content-between mt-4"):
with st.div(className="d-lg-flex d-block align-items-center"):
if not tbreadcrumbs.has_breadcrumbs() and not self.run_user:
@@ -268,21 +295,6 @@ def render(self):
elif is_root_example:
st.write(self.preview_description(current_run.to_dict()))
- try:
- selected_tab = MenuTabs.paths_reverse[self.tab]
- except KeyError:
- st.error(f"## 404 - Tab {self.tab!r} Not found")
- return
-
- with st.nav_tabs():
- tab_names = self.get_tabs()
- for name in tab_names:
- url = self.get_tab_url(name)
- with st.nav_item(url, active=name == selected_tab):
- st.html(name)
- with st.nav_tab_content():
- self.render_selected_tab(selected_tab)
-
def _render_title(self, title: str):
st.write(f"# {title}")
@@ -450,14 +462,14 @@ def _render_publish_modal(
if not pressed_save:
return
- recipe_title = self.get_root_published_run().title or self.title
+
is_root_published_run = is_update_mode and published_run.is_root()
- if (
- not is_root_published_run
- and published_run_title.strip() == recipe_title.strip()
- ):
- st.error("Title can't be the same as the recipe title", icon="⚠️")
- return
+ if not is_root_published_run:
+ try:
+ self._validate_published_run_title(published_run_title)
+ except TitleValidationError as e:
+ st.error(str(e))
+ return
if is_update_mode:
updates = dict(
@@ -483,6 +495,18 @@ def _render_publish_modal(
)
force_redirect(published_run.get_app_url())
+ def _validate_published_run_title(self, title: str):
+ if slugify(title) in settings.DISALLOWED_TITLE_SLUGS:
+ raise TitleValidationError(
+ "This title is not allowed. Please choose a different title."
+ )
+ elif title.strip() == self.get_recipe_title():
+ raise TitleValidationError(
+ "Please choose a different title for your published run."
+ )
+ elif title.strip() == "":
+ raise TitleValidationError("Title cannot be empty.")
+
def _has_published_run_changed(
self,
*,
@@ -911,7 +935,7 @@ def get_current_published_run(cls) -> PublishedRun | None:
@classmethod
def get_pr_from_query_params(
cls, example_id: str, run_id: str, uid: str
- ) -> PublishedRun:
+ ) -> PublishedRun | None:
if run_id and uid:
sr = cls.get_sr_from_query_params(example_id, run_id, uid)
return (
@@ -1883,6 +1907,17 @@ def err_msg_for_exc(e):
return f"{type(e).__name__}: {e}"
+def force_redirect(url: str):
+ # note: assumes sanitized URLs
+ st.html(
+ f"""
+
+ """
+ )
+
+
class RedirectException(Exception):
def __init__(self, url, status_code=302):
self.url = url
@@ -1896,16 +1931,5 @@ def __init__(self, query_params: dict, status_code=303):
super().__init__(url, status_code)
-def force_redirect(url: str):
- # note: assumes sanitized URLs
- st.html(
- f"""
-
- """
- )
-
-
-def reverse_enumerate(start, iterator):
- return zip(range(start, -1, -1), iterator)
+class TitleValidationError(Exception):
+ pass
diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py
index 31f64e762..dfc00c794 100644
--- a/daras_ai_v2/bots.py
+++ b/daras_ai_v2/bots.py
@@ -131,9 +131,12 @@ def get_input_audio(self) -> str | None:
def get_input_images(self) -> list[str] | None:
raise NotImplementedError
+ def get_input_documents(self) -> list[str] | None:
+ raise NotImplementedError
+
def nice_filename(self, mime_type: str) -> str:
ext = mimetypes.guess_extension(mime_type) or ""
- return f"{self.platform}_{self.input_type}_from_{self.user_id}_to_{self.bot_id}{ext}"
+ return f"{self.platform.name}_{self.input_type}_from_{self.user_id}_to_{self.bot_id}{ext}"
def _unpack_bot_integration(self):
bi = self.convo.bot_integration
@@ -198,6 +201,7 @@ def _mock_api_output(input_text):
def _on_msg(bot: BotInterface):
speech_run = None
input_images = None
+ input_documents = None
if not bot.page_cls:
bot.send_msg(text=PAGE_NOT_CONNECTED_ERROR)
return
@@ -240,6 +244,17 @@ def _on_msg(bot: BotInterface):
status_code=400, detail="No image found in request."
)
input_text = (bot.get_input_text() or "").strip()
+ case "document":
+ input_documents = bot.get_input_documents()
+ if not input_documents:
+ raise HTTPException(
+ status_code=400, detail="No documents found in request."
+ )
+ filenames = ", ".join(
+ furl(url.strip("/")).path.segments[-1] for url in input_documents
+ )
+ input_text = (bot.get_input_text() or "").strip()
+ input_text = f"Files: {filenames}\n\n{input_text}"
case "text":
input_text = (bot.get_input_text() or "").strip()
if not input_text:
@@ -268,6 +283,7 @@ def _on_msg(bot: BotInterface):
billing_account_user=billing_account_user,
bot=bot,
input_images=input_images,
+ input_documents=input_documents,
input_text=input_text,
speech_run=speech_run,
)
@@ -306,6 +322,7 @@ def _process_and_send_msg(
billing_account_user: AppUser,
bot: BotInterface,
input_images: list[str] | None,
+ input_documents: list[str] | None,
input_text: str,
speech_run: str | None,
):
@@ -324,6 +341,7 @@ def _process_and_send_msg(
user_language=bot.language,
speech_run=speech_run,
input_images=input_images,
+ input_documents=input_documents,
)
except HTTPException as e:
traceback.print_exc()
@@ -345,6 +363,7 @@ def _process_and_send_msg(
_save_msgs(
bot=bot,
input_images=input_images,
+ input_documents=input_documents,
input_text=input_text,
speech_run=speech_run,
platform_msg_id=msg_id,
@@ -356,6 +375,7 @@ def _process_and_send_msg(
def _save_msgs(
bot: BotInterface,
input_images: list[str] | None,
+ input_documents: list[str] | None,
input_text: str,
speech_run: str | None,
platform_msg_id: str | None,
@@ -376,10 +396,10 @@ def _save_msgs(
else None,
)
attachments = []
- for img in input_images or []:
- metadata = doc_url_to_file_metadata(img)
+ for f_url in (input_images or []) + (input_documents or []):
+ metadata = doc_url_to_file_metadata(f_url)
attachments.append(
- MessageAttachment(message=user_msg, url=img, metadata=metadata)
+ MessageAttachment(message=user_msg, url=f_url, metadata=metadata)
)
assistant_msg = Message(
platform_msg_id=platform_msg_id,
@@ -407,6 +427,7 @@ def _process_msg(
query_params: dict,
convo: Conversation,
input_images: list[str] | None,
+ input_documents: list[str] | None,
input_text: str,
user_language: str,
speech_run: str | None,
@@ -426,6 +447,7 @@ def _process_msg(
request_body={
"input_prompt": input_text,
"input_images": input_images,
+ "input_documents": input_documents,
"messages": saved_msgs,
"user_language": user_language,
},
diff --git a/daras_ai_v2/facebook_bots.py b/daras_ai_v2/facebook_bots.py
index 13004f8e1..2a22bbddd 100644
--- a/daras_ai_v2/facebook_bots.py
+++ b/daras_ai_v2/facebook_bots.py
@@ -71,6 +71,13 @@ def get_input_images(self) -> list[str] | None:
return None
return [self._download_wa_media(media_id)]
+ def get_input_documents(self) -> list[str] | None:
+ try:
+ media_id = self.input_message["document"]["id"]
+ except KeyError:
+ return None
+ return [self._download_wa_media(media_id)]
+
def _download_wa_media(self, media_id: str) -> str:
# download file from whatsapp
data, mime_type = retrieve_wa_media_by_id(media_id)
diff --git a/daras_ai_v2/meta_content.py b/daras_ai_v2/meta_content.py
index 9ed6b666f..12076d26f 100644
--- a/daras_ai_v2/meta_content.py
+++ b/daras_ai_v2/meta_content.py
@@ -1,7 +1,12 @@
+from django.utils.text import slugify
+from furl import furl
+
from bots.models import PublishedRun, SavedRun, WorkflowMetadata
+from daras_ai_v2 import settings
from daras_ai_v2.base import BasePage
from daras_ai_v2.breadcrumbs import get_title_breadcrumbs
from daras_ai_v2.meta_preview_url import meta_preview_url
+from daras_ai_v2.tabs_widget import MenuTabs
sep = " • "
@@ -15,32 +20,42 @@ def build_meta_tags(
uid: str,
example_id: str,
) -> list[dict]:
- sr, published_run = page.get_runs_from_query_params(example_id, run_id, uid)
+ sr, pr = page.get_runs_from_query_params(example_id, run_id, uid)
metadata = page.workflow.get_or_create_metadata()
title = meta_title_for_page(
page=page,
metadata=metadata,
sr=sr,
- published_run=published_run,
+ pr=pr,
)
description = meta_description_for_page(
metadata=metadata,
- published_run=published_run,
+ pr=pr,
)
image = meta_image_for_page(
page=page,
state=state,
sr=sr,
metadata=metadata,
- published_run=published_run,
+ pr=pr,
+ )
+ canonical_url = canonical_url_for_page(
+ page=page,
+ state=state,
+ sr=sr,
+ metadata=metadata,
+ pr=pr,
)
+ robots = robots_tag_for_page(page=page, sr=sr, pr=pr)
return raw_build_meta_tags(
url=url,
title=title,
description=description,
image=image,
+ canonical_url=canonical_url,
+ robots=robots,
)
@@ -50,6 +65,8 @@ def raw_build_meta_tags(
title: str,
description: str | None = None,
image: str | None = None,
+ canonical_url: str | None = None,
+ robots: str | None = None,
) -> list[dict[str, str]]:
ret = [
dict(title=title),
@@ -61,18 +78,27 @@ def raw_build_meta_tags(
dict(property="twitter:url", content=url),
dict(property="twitter:title", content=title),
]
+
if description:
ret += [
dict(name="description", content=description),
dict(property="og:description", content=description),
dict(property="twitter:description", content=description),
]
+
if image:
ret += [
dict(name="image", content=image),
dict(property="og:image", content=image),
dict(property="twitter:image", content=image),
]
+
+ if canonical_url:
+ ret += [dict(tagName="link", rel="canonical", href=canonical_url)]
+
+ if robots:
+ ret += [dict(name="robots", content=robots)]
+
return ret
@@ -81,9 +107,9 @@ def meta_title_for_page(
page: BasePage,
metadata: WorkflowMetadata,
sr: SavedRun,
- published_run: PublishedRun | None,
+ pr: PublishedRun | None,
) -> str:
- tbreadcrumbs = get_title_breadcrumbs(page, sr, published_run)
+ tbreadcrumbs = get_title_breadcrumbs(page, sr, pr)
parts = []
if tbreadcrumbs.published_title or tbreadcrumbs.root_title:
@@ -91,7 +117,7 @@ def meta_title_for_page(
# use the short title for non-root examples
part = metadata.short_title
if tbreadcrumbs.published_title:
- part = f"{published_run.title} {part}"
+ part = f"{pr.title} {part}"
# add the creator's name
user = sr.get_creator()
if user and user.display_name:
@@ -108,14 +134,14 @@ def meta_title_for_page(
def meta_description_for_page(
*,
metadata: WorkflowMetadata,
- published_run: PublishedRun | None,
+ pr: PublishedRun | None,
) -> str:
- if published_run and not published_run.is_root():
- description = published_run.notes or metadata.meta_description
+ if pr and not pr.is_root():
+ description = pr.notes or metadata.meta_description
else:
description = metadata.meta_description
- if not (published_run and published_run.is_root()) or not description:
+ if not (pr and pr.is_root()) or not description:
# for all non-root examples, or when there is no other description
description += sep + "AI API, workflow & prompt shared on Gooey.AI."
@@ -128,9 +154,9 @@ def meta_image_for_page(
state: dict,
metadata: WorkflowMetadata,
sr: SavedRun,
- published_run: PublishedRun | None,
+ pr: PublishedRun | None,
) -> str | None:
- if published_run and published_run.saved_run == sr and published_run.is_root():
+ if pr and pr.saved_run == sr and pr.is_root():
file_url = metadata.meta_image or page.preview_image(state)
else:
file_url = page.preview_image(state)
@@ -139,3 +165,91 @@ def meta_image_for_page(
file_url=file_url,
fallback_img=page.fallback_preivew_image(),
)
+
+
+def canonical_url_for_page(
+ *,
+ page: BasePage,
+ state: dict,
+ metadata: WorkflowMetadata,
+ sr: SavedRun,
+ pr: PublishedRun | None,
+) -> str:
+ """
+ Assumes that `page.tab` is a valid tab defined in MenuTabs
+ """
+
+ latest_slug = page.slug_versions[-1] # for recipe
+ recipe_url = furl(str(settings.APP_BASE_URL)) / latest_slug
+
+ if pr and pr.saved_run == sr and pr.is_root():
+ query_params = {}
+ pr_slug = ""
+ elif pr and pr.saved_run == sr:
+ query_params = {"example_id": pr.published_run_id}
+ pr_slug = (pr.title and slugify(pr.title)) or ""
+ else:
+ query_params = {"run_id": sr.run_id, "uid": sr.uid}
+ pr_slug = ""
+
+ tab_path = MenuTabs.paths[page.tab]
+ match page.tab:
+ case MenuTabs.examples:
+ # no query params / run_slug in this case
+ return str(recipe_url / tab_path / "/")
+ case MenuTabs.history, MenuTabs.saved:
+ # no run slug in this case
+ return str(furl(recipe_url, query_params=query_params) / tab_path / "/")
+ case _:
+ # all other cases
+ return str(
+ furl(recipe_url, query_params=query_params) / pr_slug / tab_path / "/"
+ )
+
+
+def robots_tag_for_page(
+ *,
+ page: BasePage,
+ sr: SavedRun,
+ pr: PublishedRun | None,
+) -> str:
+ is_root = pr and pr.saved_run == sr and pr.is_root()
+ is_example = pr and pr.saved_run == sr and not pr.is_root()
+
+ match page.tab:
+ case MenuTabs.run if is_root or is_example:
+ no_follow, no_index = False, False
+ case MenuTabs.run: # ordinary run (not example)
+ no_follow, no_index = False, True
+ case MenuTabs.examples:
+ no_follow, no_index = False, False
+ case MenuTabs.run_as_api:
+ no_follow, no_index = False, True
+ case MenuTabs.integrations:
+ no_follow, no_index = True, True
+ case MenuTabs.history:
+ no_follow, no_index = True, True
+ case MenuTabs.saved:
+ no_follow, no_index = True, True
+ case _:
+ raise ValueError(f"Unknown tab: {page.tab}")
+
+ parts = []
+ if no_follow:
+ parts.append("nofollow")
+ if no_index:
+ parts.append("noindex")
+ return ",".join(parts)
+
+
+def get_is_indexable_for_page(
+ *,
+ page: BasePage,
+ sr: SavedRun,
+ pr: PublishedRun | None,
+) -> bool:
+ if pr and pr.saved_run == sr and pr.is_root():
+ # index all tabs on root
+ return True
+
+ return bool(pr and pr.saved_run == sr and page.tab == MenuTabs.run)
diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py
index 48e661dcf..74f38ed5f 100644
--- a/daras_ai_v2/settings.py
+++ b/daras_ai_v2/settings.py
@@ -244,6 +244,17 @@
SUPPORT_EMAIL = "Gooey.AI Support "
SEND_RUN_EMAIL_AFTER_SEC = config("SEND_RUN_EMAIL_AFTER_SEC", 20)
+DISALLOWED_TITLE_SLUGS = config("DISALLOWED_TITLE_SLUGS", cast=Csv(), default="") + [
+ # tab names
+ "api",
+ "examples",
+ "history",
+ "saved",
+ "integrations",
+ # other
+ "docs",
+]
+
SAFTY_CHECKER_EXAMPLE_ID = "3rcxqx0r"
SAFTY_CHECKER_BILLING_EMAIL = "support+mods@gooey.ai"
diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py
index f85d4550a..60886c017 100644
--- a/recipes/LipsyncTTS.py
+++ b/recipes/LipsyncTTS.py
@@ -62,13 +62,12 @@ class ResponseModel(BaseModel):
def related_workflows(self) -> list:
from recipes.VideoBots import VideoBotsPage
from recipes.DeforumSD import DeforumSDPage
- from recipes.CompareText2Img import CompareText2ImgPage
return [
VideoBotsPage,
TextToSpeechPage,
DeforumSDPage,
- CompareText2ImgPage,
+ LipsyncPage,
]
def render_form_v2(self):
diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py
index eab559b84..082fdcdaa 100644
--- a/recipes/SocialLookupEmail.py
+++ b/recipes/SocialLookupEmail.py
@@ -25,7 +25,7 @@ class SocialLookupEmailPage(BasePage):
slug_versions = ["SocialLookupEmail", "email-writer-with-profile-lookup"]
sane_defaults = {
- "selected_model": LargeLanguageModels.text_davinci_003.name,
+ "selected_model": LargeLanguageModels.gpt_4.name,
}
class RequestModel(BaseModel):
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index 1d80162c9..149931ffb 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -1,4 +1,5 @@
import json
+import mimetypes
import os
import os.path
import typing
@@ -150,6 +151,7 @@ class RequestModel(BaseModel):
input_prompt: str
input_images: list[str] | None
+ input_documents: list[str] | None
# conversation history/context
messages: list[ConversationEntry] | None
@@ -458,16 +460,22 @@ def render_output(self):
# chat window
with st.div(className="pb-3"):
chat_list_view()
- pressed_send, new_input, new_input_images = chat_input_view()
+ (
+ pressed_send,
+ new_input,
+ new_input_images,
+ new_input_documents,
+ ) = chat_input_view()
if pressed_send:
- self.on_send(new_input, new_input_images)
+ self.on_send(new_input, new_input_images, new_input_documents)
# clear chat inputs
if st.button("🗑️ Clear"):
st.session_state["messages"] = []
st.session_state["input_prompt"] = ""
st.session_state["input_images"] = None
+ st.session_state["new_input_documents"] = None
st.session_state["raw_input_text"] = ""
self.clear_outputs()
st.session_state["final_keyword_query"] = ""
@@ -487,12 +495,18 @@ def render_output(self):
label_visibility="collapsed",
)
- def on_send(self, new_input: str, new_input_images: list[str]):
+ def on_send(
+ self,
+ new_input: str,
+ new_input_images: list[str],
+ new_input_documents: list[str],
+ ):
prev_input = st.session_state.get("raw_input_text") or ""
prev_output = (st.session_state.get("raw_output_text") or [""])[0]
prev_input_images = st.session_state.get("input_images")
+ prev_input_documents = st.session_state.get("input_documents")
- if (prev_input or prev_input_images) and prev_output:
+ if (prev_input or prev_input_images or prev_input_documents) and prev_output:
# append previous input to the history
st.session_state["messages"] = st.session_state.get("messages", []) + [
format_chat_entry(
@@ -507,8 +521,14 @@ def on_send(self, new_input: str, new_input_images: list[str]):
]
# add new input to the state
+ if new_input_documents:
+ filenames = ", ".join(
+ furl(url.strip("/")).path.segments[-1] for url in new_input_documents
+ )
+ new_input = f"Files: {filenames}\n\n{new_input}"
st.session_state["input_prompt"] = new_input
st.session_state["input_images"] = new_input_images or None
+ st.session_state["input_documents"] = new_input_documents or None
self.on_submit()
@@ -544,6 +564,7 @@ def render_steps(self):
if isinstance(final_prompt, str):
text_output("**Final Prompt**", value=final_prompt, height=300)
else:
+ st.write("**Final Prompt**")
st.json(final_prompt)
for idx, text in enumerate(st.session_state.get("raw_output_text", [])):
@@ -601,7 +622,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
"""
user_input = request.input_prompt.strip()
- if not (user_input or request.input_images):
+ if not (user_input or request.input_images or request.input_documents):
return
model = LargeLanguageModels[request.selected_model]
is_chat_model = model.is_chat_model()
@@ -609,13 +630,11 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
bot_script = request.bot_script
ocr_texts = []
- if request.input_images:
+ if request.document_model and (request.input_images or request.input_documents):
yield "Running Azure Form Recognizer..."
- for img in request.input_images:
+ for url in (request.input_images or []) + (request.input_documents or []):
ocr_text = (
- azure_form_recognizer(
- img, model_id=request.document_model or "prebuilt-read"
- )
+ azure_form_recognizer(url, model_id="prebuilt-read")
.get("content", "")
.strip()
)
@@ -641,7 +660,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
target_language="en",
)
for text in ocr_texts:
- user_input = f"Image: {text!r}\n{user_input}"
+ user_input = f"Exracted Text: {text!r}\n\n{user_input}"
# parse the bot script
# system_message, scripted_msgs = parse_script(bot_script)
@@ -666,18 +685,16 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
# except IndexError:
# user_display_name = CHATML_ROLE_USER
- # construct user prompt
+ # save raw input for reference
state["raw_input_text"] = user_input
- user_prompt = {
- "role": CHATML_ROLE_USER,
- "content": user_input,
- }
# if documents are provided, run doc search on the saved msgs and get back the references
references = None
if request.documents:
# formulate the search query as a history of all the messages
- query_msgs = saved_msgs + [user_prompt]
+ query_msgs = saved_msgs + [
+ format_chat_entry(role=CHATML_ROLE_USER, content=user_input)
+ ]
clip_idx = convo_window_clipper(
query_msgs, model_max_tokens[model] // 2, sep=" "
)
@@ -741,17 +758,22 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
if references:
# add task instructions
task_instructions = render_prompt_vars(request.task_instructions, state)
- user_prompt["content"] = (
+ user_input = (
references_as_prompt(references)
+ f"\n**********\n{task_instructions.strip()}\n**********\n"
- + user_prompt["content"]
+ + user_input
)
+ # construct user prompt
+ user_prompt = format_chat_entry(
+ role=CHATML_ROLE_USER, content=user_input, images=request.input_images
+ )
+
# truncate the history to fit the model's max tokens
history_window = scripted_msgs + saved_msgs
max_history_tokens = (
model_max_tokens[model]
- - calc_gpt_tokens([system_prompt, user_prompt], is_chat_model=is_chat_model)
+ - calc_gpt_tokens([system_prompt, user_input], is_chat_model=is_chat_model)
- request.max_tokens
- SAFETY_BUFFER
)
@@ -853,7 +875,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
tts_state = dict(state)
for text in state.get("raw_tts_text", state["raw_output_text"]):
tts_state["text_prompt"] = text
- yield from TextToSpeechPage().run(tts_state)
+ yield from TextToSpeechPage(
+ request=self.request, run_user=self.run_user
+ ).run(tts_state)
state["output_audio"].append(tts_state["audio_url"])
if not request.input_face:
@@ -861,7 +885,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
lip_state = dict(state)
for audio_url in state["output_audio"]:
lip_state["input_audio"] = audio_url
- yield from LipsyncPage().run(lip_state)
+ yield from LipsyncPage(request=self.request, run_user=self.run_user).run(
+ lip_state
+ )
state["output_video"].append(lip_state["output_video"])
def get_tabs(self):
@@ -1217,7 +1243,7 @@ def chat_list_view():
st.image(im, style={"maxHeight": "200px"})
-def chat_input_view() -> tuple[bool, str, list[str]]:
+def chat_input_view() -> tuple[bool, str, list[str], list[str]]:
with st.div(
className="px-3 pt-3 d-flex gap-1",
style=dict(background="rgba(239, 239, 239, 0.6)"),
@@ -1237,14 +1263,20 @@ def chat_input_view() -> tuple[bool, str, list[str]]:
pressed_send = st.button("✈ Send", style=dict(height="3.2rem"))
if show_uploader:
- new_input_images = st.file_uploader(
- "",
- accept_multiple_files=True,
- )
+ uploaded_files = st.file_uploader("", accept_multiple_files=True)
+ new_input_images = []
+ new_input_documents = []
+ for f in uploaded_files:
+ mime_type = mimetypes.guess_type(f)[0] or ""
+ if mime_type.startswith("image/"):
+ new_input_images.append(f)
+ else:
+ new_input_documents.append(f)
else:
new_input_images = None
+ new_input_documents = None
- return pressed_send, new_input, new_input_images
+ return pressed_send, new_input, new_input_images, new_input_documents
def msg_container_widget(role: str):
diff --git a/routers/root.py b/routers/root.py
index 6b40fca61..6f854ba20 100644
--- a/routers/root.py
+++ b/routers/root.py
@@ -38,6 +38,7 @@
from daras_ai_v2.meta_preview_url import meta_preview_url
from daras_ai_v2.query_params_util import extract_query_params
from daras_ai_v2.settings import templates
+from daras_ai_v2.tabs_widget import MenuTabs
from routers.api import request_form_files
app = APIRouter()
@@ -198,7 +199,7 @@ def file_upload(request: Request, form_data: FormData = Depends(request_form_fil
if content_type.startswith("image/"):
with Image(blob=data) as img:
- if img.format not in ["png", "jpeg", "jpg", "gif"]:
+ if img.format.lower() not in ["png", "jpeg", "jpg", "gif"]:
img.format = "png"
content_type = "image/png"
filename += ".png"
@@ -331,27 +332,38 @@ def _api_docs_page(request):
@app.post("/")
@app.post("/{page_slug}/")
-@app.post("/{page_slug}/{tab}/")
+@app.post("/{page_slug}/{run_slug_or_tab}/")
+@app.post("/{page_slug}/{run_slug_or_tab}/{tab}/")
def st_page(
request: Request,
page_slug="",
+ run_slug_or_tab="",
tab="",
json_data: dict = Depends(request_json),
):
+ run_slug, tab = _extract_run_slug_and_tab(run_slug_or_tab, tab)
+ try:
+ selected_tab = MenuTabs.paths_reverse[tab]
+ except KeyError:
+ raise HTTPException(status_code=404)
+
try:
page_cls = page_slug_map[normalize_slug(page_slug)]
except KeyError:
raise HTTPException(status_code=404)
+
# ensure the latest slug is used
latest_slug = page_cls.slug_versions[-1]
if latest_slug != page_slug:
return RedirectResponse(
- request.url.replace(path=os.path.join("/", latest_slug, tab, ""))
+ request.url.replace(path=os.path.join("/", latest_slug, run_slug, tab, ""))
)
example_id, run_id, uid = extract_query_params(request.query_params)
- page = page_cls(tab=tab, request=request, run_user=get_run_user(request, uid))
+ page = page_cls(
+ tab=selected_tab, request=request, run_user=get_run_user(request, uid)
+ )
state = json_data.get("state", {})
if not state:
@@ -372,19 +384,6 @@ def st_page(
except RedirectException as e:
return RedirectResponse(e.url, status_code=e.status_code)
- # Canonical URLs should not include uid or run_id (don't index specific runs).
- # In the case of examples, all tabs other than "Run" are duplicates of the page
- # without the `example_id`, and so their canonical shouldn't include `example_id`
- canonical_url = str(
- furl(
- str(settings.APP_BASE_URL),
- query_params={"example_id": example_id} if not tab and example_id else {},
- )
- / latest_slug
- / tab
- / "/" # preserve trailing slash
- )
-
ret |= {
"meta": build_meta_tags(
url=get_og_url_path(request),
@@ -394,11 +393,6 @@ def st_page(
uid=uid,
example_id=example_id,
)
- + [dict(tagName="link", rel="canonical", href=canonical_url)]
- # + [
- # dict(tagName="link", rel="icon", href="/static/favicon.ico"),
- # dict(tagName="link", rel="stylesheet", href="/static/css/app.css"),
- # ],
}
return ret
@@ -440,3 +434,12 @@ def page_wrapper(request: Request, render_fn: typing.Callable, **kwargs):
st.html(templates.get_template("footer.html").render(**context))
st.html(templates.get_template("login_scripts.html").render(**context))
+
+
+def _extract_run_slug_and_tab(run_slug_or_tab, tab) -> tuple[str, str]:
+ if run_slug_or_tab and tab:
+ return run_slug_or_tab, tab
+ elif run_slug_or_tab in MenuTabs.paths_reverse:
+ return "", run_slug_or_tab
+ else:
+ return run_slug_or_tab, ""
diff --git a/tests/test_public_endpoints.py b/tests/test_public_endpoints.py
index 4542277b4..be18647d8 100644
--- a/tests/test_public_endpoints.py
+++ b/tests/test_public_endpoints.py
@@ -12,7 +12,7 @@
client = TestClient(app)
excluded_endpoints = [
- facebook.fb_webhook_verify.__name__, # gives 403
+ facebook_api.fb_webhook_verify.__name__, # gives 403
slack_connect_redirect.__name__,
slack_connect_redirect_shortcuts.__name__,
"get_run_status", # needs query params