diff --git a/bots/models.py b/bots/models.py index 8f358c50b..1d77b57fd 100644 --- a/bots/models.py +++ b/bots/models.py @@ -8,7 +8,7 @@ from django.contrib.auth import get_user_model from django.db import models, transaction from django.db.models import Q -from django.utils.text import Truncator +from django.utils.text import Truncator, slugify from furl import furl from phonenumber_field.modelfields import PhoneNumberField @@ -93,7 +93,7 @@ class Workflow(models.IntegerChoices): def short_slug(self): return min(self.page_cls.slug_versions, key=len) - def get_app_url(self, example_id: str, run_id: str, uid: str): + def get_app_url(self, example_id: str, run_id: str, uid: str, run_slug: str = ""): """return the url to the gooey app""" query_params = {} if run_id and uid: @@ -102,7 +102,8 @@ def get_app_url(self, example_id: str, run_id: str, uid: str): query_params |= dict(example_id=example_id) return str( furl(settings.APP_BASE_URL, query_params=query_params) - / self.short_slug + / self.page_cls.slug_versions[-1] + / run_slug / "/" ) @@ -903,7 +904,9 @@ def as_llm_context(self, limit: int = 100) -> list["ConversationEntry"]: entries[i] = format_chat_entry( role=msg.role, content=msg.content, - images=msg.attachments.values_list("url", flat=True), + images=msg.attachments.filter( + metadata__mime_type__startswith="image/" + ).values_list("url", flat=True), ) return entries @@ -1288,7 +1291,10 @@ def duplicate( def get_app_url(self): return Workflow(self.workflow).get_app_url( - example_id=self.published_run_id, run_id="", uid="" + example_id=self.published_run_id, + run_id="", + uid="", + run_slug=self.title and slugify(self.title), ) def add_version( diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 0a4dd2d1b..edd3d9ec8 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -16,6 +16,7 @@ import requests import sentry_sdk from django.utils import timezone +from django.utils.text import slugify from fastapi import HTTPException from 
firebase_admin import auth from furl import furl @@ -136,11 +137,17 @@ def app_url( query_params = cls.clean_query_params( example_id=example_id, run_id=run_id, uid=uid ) | (query_params or {}) - f = furl(settings.APP_BASE_URL, query_params=query_params) / ( - cls.slug_versions[-1] + "/" + f = ( + furl(settings.APP_BASE_URL, query_params=query_params) + / cls.slug_versions[-1] ) + if example_id := query_params.get("example_id"): + pr = cls.get_published_run(published_run_id=example_id) + if pr and pr.title: + f /= slugify(pr.title) if tab_name: - f /= tab_name + "/" + f /= tab_name + f /= "/" # keep trailing slash return str(f) @classmethod @@ -180,12 +187,18 @@ def setup_render(self): "/" + self.slug_versions[0], source=TRANSACTION_SOURCE_ROUTE ) + def refresh_state(self): + _, run_id, uid = extract_query_params(gooey_get_query_params()) + channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}" + output = realtime_pull([channel])[0] + if output: + st.session_state.update(output) + + def render(self): + self.setup_render() + if self.get_run_state() == RecipeRunState.running: - _, run_id, uid = extract_query_params(gooey_get_query_params()) - channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}" - output = realtime_pull([channel])[0] - if output: - st.session_state.update(output) + self.refresh_state() else: realtime_clear_subs() @@ -196,9 +209,22 @@ def setup_render(self): self.render_report_form() return - def render(self): - self.setup_render() + self._render_header() + + self._render_tab_menu(selected_tab=self.tab) + with st.nav_tab_content(): + self.render_selected_tab(self.tab) + + def _render_tab_menu(self, selected_tab: str): + assert selected_tab in MenuTabs.paths + + with st.nav_tabs(): + for name in self.get_tabs(): + url = self.get_tab_url(name) + with st.nav_item(url, active=name == selected_tab): + st.html(name) + def _render_header(self): current_run = self.get_current_sr() published_run = self.get_current_published_run() 
is_root_example = ( @@ -207,6 +233,7 @@ def render(self): and published_run.saved_run == current_run ) tbreadcrumbs = get_title_breadcrumbs(self, current_run, published_run) + with st.div(className="d-flex justify-content-between mt-4"): with st.div(className="d-lg-flex d-block align-items-center"): if not tbreadcrumbs.has_breadcrumbs() and not self.run_user: @@ -268,21 +295,6 @@ def render(self): elif is_root_example: st.write(self.preview_description(current_run.to_dict())) - try: - selected_tab = MenuTabs.paths_reverse[self.tab] - except KeyError: - st.error(f"## 404 - Tab {self.tab!r} Not found") - return - - with st.nav_tabs(): - tab_names = self.get_tabs() - for name in tab_names: - url = self.get_tab_url(name) - with st.nav_item(url, active=name == selected_tab): - st.html(name) - with st.nav_tab_content(): - self.render_selected_tab(selected_tab) - def _render_title(self, title: str): st.write(f"# {title}") @@ -450,14 +462,14 @@ def _render_publish_modal( if not pressed_save: return - recipe_title = self.get_root_published_run().title or self.title + is_root_published_run = is_update_mode and published_run.is_root() - if ( - not is_root_published_run - and published_run_title.strip() == recipe_title.strip() - ): - st.error("Title can't be the same as the recipe title", icon="⚠️") - return + if not is_root_published_run: + try: + self._validate_published_run_title(published_run_title) + except TitleValidationError as e: + st.error(str(e)) + return if is_update_mode: updates = dict( @@ -483,6 +495,18 @@ def _render_publish_modal( ) force_redirect(published_run.get_app_url()) + def _validate_published_run_title(self, title: str): + if slugify(title) in settings.DISALLOWED_TITLE_SLUGS: + raise TitleValidationError( + "This title is not allowed. Please choose a different title." + ) + elif title.strip() == self.get_recipe_title(): + raise TitleValidationError( + "Please choose a different title for your published run." 
+ ) + elif title.strip() == "": + raise TitleValidationError("Title cannot be empty.") + def _has_published_run_changed( self, *, @@ -911,7 +935,7 @@ def get_current_published_run(cls) -> PublishedRun | None: @classmethod def get_pr_from_query_params( cls, example_id: str, run_id: str, uid: str - ) -> PublishedRun: + ) -> PublishedRun | None: if run_id and uid: sr = cls.get_sr_from_query_params(example_id, run_id, uid) return ( @@ -1883,6 +1907,17 @@ def err_msg_for_exc(e): return f"{type(e).__name__}: {e}" +def force_redirect(url: str): + # note: assumes sanitized URLs + st.html( + f""" + + """ + ) + + class RedirectException(Exception): def __init__(self, url, status_code=302): self.url = url @@ -1896,16 +1931,5 @@ def __init__(self, query_params: dict, status_code=303): super().__init__(url, status_code) -def force_redirect(url: str): - # note: assumes sanitized URLs - st.html( - f""" - - """ - ) - - -def reverse_enumerate(start, iterator): - return zip(range(start, -1, -1), iterator) +class TitleValidationError(Exception): + pass diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 31f64e762..dfc00c794 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -131,9 +131,12 @@ def get_input_audio(self) -> str | None: def get_input_images(self) -> list[str] | None: raise NotImplementedError + def get_input_documents(self) -> list[str] | None: + raise NotImplementedError + def nice_filename(self, mime_type: str) -> str: ext = mimetypes.guess_extension(mime_type) or "" - return f"{self.platform}_{self.input_type}_from_{self.user_id}_to_{self.bot_id}{ext}" + return f"{self.platform.name}_{self.input_type}_from_{self.user_id}_to_{self.bot_id}{ext}" def _unpack_bot_integration(self): bi = self.convo.bot_integration @@ -198,6 +201,7 @@ def _mock_api_output(input_text): def _on_msg(bot: BotInterface): speech_run = None input_images = None + input_documents = None if not bot.page_cls: bot.send_msg(text=PAGE_NOT_CONNECTED_ERROR) return @@ -240,6 +244,17 @@ 
def _on_msg(bot: BotInterface): status_code=400, detail="No image found in request." ) input_text = (bot.get_input_text() or "").strip() + case "document": + input_documents = bot.get_input_documents() + if not input_documents: + raise HTTPException( + status_code=400, detail="No documents found in request." + ) + filenames = ", ".join( + furl(url.strip("/")).path.segments[-1] for url in input_documents + ) + input_text = (bot.get_input_text() or "").strip() + input_text = f"Files: {filenames}\n\n{input_text}" case "text": input_text = (bot.get_input_text() or "").strip() if not input_text: @@ -268,6 +283,7 @@ def _on_msg(bot: BotInterface): billing_account_user=billing_account_user, bot=bot, input_images=input_images, + input_documents=input_documents, input_text=input_text, speech_run=speech_run, ) @@ -306,6 +322,7 @@ def _process_and_send_msg( billing_account_user: AppUser, bot: BotInterface, input_images: list[str] | None, + input_documents: list[str] | None, input_text: str, speech_run: str | None, ): @@ -324,6 +341,7 @@ def _process_and_send_msg( user_language=bot.language, speech_run=speech_run, input_images=input_images, + input_documents=input_documents, ) except HTTPException as e: traceback.print_exc() @@ -345,6 +363,7 @@ def _process_and_send_msg( _save_msgs( bot=bot, input_images=input_images, + input_documents=input_documents, input_text=input_text, speech_run=speech_run, platform_msg_id=msg_id, @@ -356,6 +375,7 @@ def _process_and_send_msg( def _save_msgs( bot: BotInterface, input_images: list[str] | None, + input_documents: list[str] | None, input_text: str, speech_run: str | None, platform_msg_id: str | None, @@ -376,10 +396,10 @@ def _save_msgs( else None, ) attachments = [] - for img in input_images or []: - metadata = doc_url_to_file_metadata(img) + for f_url in (input_images or []) + (input_documents or []): + metadata = doc_url_to_file_metadata(f_url) attachments.append( - MessageAttachment(message=user_msg, url=img, metadata=metadata) + 
MessageAttachment(message=user_msg, url=f_url, metadata=metadata) ) assistant_msg = Message( platform_msg_id=platform_msg_id, @@ -407,6 +427,7 @@ def _process_msg( query_params: dict, convo: Conversation, input_images: list[str] | None, + input_documents: list[str] | None, input_text: str, user_language: str, speech_run: str | None, @@ -426,6 +447,7 @@ def _process_msg( request_body={ "input_prompt": input_text, "input_images": input_images, + "input_documents": input_documents, "messages": saved_msgs, "user_language": user_language, }, diff --git a/daras_ai_v2/facebook_bots.py b/daras_ai_v2/facebook_bots.py index 13004f8e1..2a22bbddd 100644 --- a/daras_ai_v2/facebook_bots.py +++ b/daras_ai_v2/facebook_bots.py @@ -71,6 +71,13 @@ def get_input_images(self) -> list[str] | None: return None return [self._download_wa_media(media_id)] + def get_input_documents(self) -> list[str] | None: + try: + media_id = self.input_message["document"]["id"] + except KeyError: + return None + return [self._download_wa_media(media_id)] + def _download_wa_media(self, media_id: str) -> str: # download file from whatsapp data, mime_type = retrieve_wa_media_by_id(media_id) diff --git a/daras_ai_v2/meta_content.py b/daras_ai_v2/meta_content.py index 9ed6b666f..12076d26f 100644 --- a/daras_ai_v2/meta_content.py +++ b/daras_ai_v2/meta_content.py @@ -1,7 +1,12 @@ +from django.utils.text import slugify +from furl import furl + from bots.models import PublishedRun, SavedRun, WorkflowMetadata +from daras_ai_v2 import settings from daras_ai_v2.base import BasePage from daras_ai_v2.breadcrumbs import get_title_breadcrumbs from daras_ai_v2.meta_preview_url import meta_preview_url +from daras_ai_v2.tabs_widget import MenuTabs sep = " • " @@ -15,32 +20,42 @@ def build_meta_tags( uid: str, example_id: str, ) -> list[dict]: - sr, published_run = page.get_runs_from_query_params(example_id, run_id, uid) + sr, pr = page.get_runs_from_query_params(example_id, run_id, uid) metadata = 
page.workflow.get_or_create_metadata() title = meta_title_for_page( page=page, metadata=metadata, sr=sr, - published_run=published_run, + pr=pr, ) description = meta_description_for_page( metadata=metadata, - published_run=published_run, + pr=pr, ) image = meta_image_for_page( page=page, state=state, sr=sr, metadata=metadata, - published_run=published_run, + pr=pr, + ) + canonical_url = canonical_url_for_page( + page=page, + state=state, + sr=sr, + metadata=metadata, + pr=pr, ) + robots = robots_tag_for_page(page=page, sr=sr, pr=pr) return raw_build_meta_tags( url=url, title=title, description=description, image=image, + canonical_url=canonical_url, + robots=robots, ) @@ -50,6 +65,8 @@ def raw_build_meta_tags( title: str, description: str | None = None, image: str | None = None, + canonical_url: str | None = None, + robots: str | None = None, ) -> list[dict[str, str]]: ret = [ dict(title=title), @@ -61,18 +78,27 @@ def raw_build_meta_tags( dict(property="twitter:url", content=url), dict(property="twitter:title", content=title), ] + if description: ret += [ dict(name="description", content=description), dict(property="og:description", content=description), dict(property="twitter:description", content=description), ] + if image: ret += [ dict(name="image", content=image), dict(property="og:image", content=image), dict(property="twitter:image", content=image), ] + + if canonical_url: + ret += [dict(tagName="link", rel="canonical", href=canonical_url)] + + if robots: + ret += [dict(name="robots", content=robots)] + return ret @@ -81,9 +107,9 @@ def meta_title_for_page( page: BasePage, metadata: WorkflowMetadata, sr: SavedRun, - published_run: PublishedRun | None, + pr: PublishedRun | None, ) -> str: - tbreadcrumbs = get_title_breadcrumbs(page, sr, published_run) + tbreadcrumbs = get_title_breadcrumbs(page, sr, pr) parts = [] if tbreadcrumbs.published_title or tbreadcrumbs.root_title: @@ -91,7 +117,7 @@ def meta_title_for_page( # use the short title for non-root 
examples part = metadata.short_title if tbreadcrumbs.published_title: - part = f"{published_run.title} {part}" + part = f"{pr.title} {part}" # add the creator's name user = sr.get_creator() if user and user.display_name: @@ -108,14 +134,14 @@ def meta_title_for_page( def meta_description_for_page( *, metadata: WorkflowMetadata, - published_run: PublishedRun | None, + pr: PublishedRun | None, ) -> str: - if published_run and not published_run.is_root(): - description = published_run.notes or metadata.meta_description + if pr and not pr.is_root(): + description = pr.notes or metadata.meta_description else: description = metadata.meta_description - if not (published_run and published_run.is_root()) or not description: + if not (pr and pr.is_root()) or not description: # for all non-root examples, or when there is no other description description += sep + "AI API, workflow & prompt shared on Gooey.AI." @@ -128,9 +154,9 @@ def meta_image_for_page( state: dict, metadata: WorkflowMetadata, sr: SavedRun, - published_run: PublishedRun | None, + pr: PublishedRun | None, ) -> str | None: - if published_run and published_run.saved_run == sr and published_run.is_root(): + if pr and pr.saved_run == sr and pr.is_root(): file_url = metadata.meta_image or page.preview_image(state) else: file_url = page.preview_image(state) @@ -139,3 +165,91 @@ def meta_image_for_page( file_url=file_url, fallback_img=page.fallback_preivew_image(), ) + + +def canonical_url_for_page( + *, + page: BasePage, + state: dict, + metadata: WorkflowMetadata, + sr: SavedRun, + pr: PublishedRun | None, +) -> str: + """ + Assumes that `page.tab` is a valid tab defined in MenuTabs + """ + + latest_slug = page.slug_versions[-1] # for recipe + recipe_url = furl(str(settings.APP_BASE_URL)) / latest_slug + + if pr and pr.saved_run == sr and pr.is_root(): + query_params = {} + pr_slug = "" + elif pr and pr.saved_run == sr: + query_params = {"example_id": pr.published_run_id} + pr_slug = (pr.title and 
slugify(pr.title)) or "" + else: + query_params = {"run_id": sr.run_id, "uid": sr.uid} + pr_slug = "" + + tab_path = MenuTabs.paths[page.tab] + match page.tab: + case MenuTabs.examples: + # no query params / run_slug in this case + return str(recipe_url / tab_path / "/") + case MenuTabs.history | MenuTabs.saved: + # no run slug in this case + return str(furl(recipe_url, query_params=query_params) / tab_path / "/") + case _: + # all other cases + return str( + furl(recipe_url, query_params=query_params) / pr_slug / tab_path / "/" + ) + + +def robots_tag_for_page( + *, + page: BasePage, + sr: SavedRun, + pr: PublishedRun | None, +) -> str: + is_root = pr and pr.saved_run == sr and pr.is_root() + is_example = pr and pr.saved_run == sr and not pr.is_root() + + match page.tab: + case MenuTabs.run if is_root or is_example: + no_follow, no_index = False, False + case MenuTabs.run: # ordinary run (not example) + no_follow, no_index = False, True + case MenuTabs.examples: + no_follow, no_index = False, False + case MenuTabs.run_as_api: + no_follow, no_index = False, True + case MenuTabs.integrations: + no_follow, no_index = True, True + case MenuTabs.history: + no_follow, no_index = True, True + case MenuTabs.saved: + no_follow, no_index = True, True + case _: + raise ValueError(f"Unknown tab: {page.tab}") + + parts = [] + if no_follow: + parts.append("nofollow") + if no_index: + parts.append("noindex") + return ",".join(parts) + + +def get_is_indexable_for_page( + *, + page: BasePage, + sr: SavedRun, + pr: PublishedRun | None, +) -> bool: + if pr and pr.saved_run == sr and pr.is_root(): + # index all tabs on root + return True + + return bool(pr and pr.saved_run == sr and page.tab == MenuTabs.run) diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 48e661dcf..74f38ed5f 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -244,6 +244,17 @@ SUPPORT_EMAIL = "Gooey.AI Support " SEND_RUN_EMAIL_AFTER_SEC = config("SEND_RUN_EMAIL_AFTER_SEC", 20)
+DISALLOWED_TITLE_SLUGS = config("DISALLOWED_TITLE_SLUGS", cast=Csv(), default="") + [ + # tab names + "api", + "examples", + "history", + "saved", + "integrations", + # other + "docs", +] + SAFTY_CHECKER_EXAMPLE_ID = "3rcxqx0r" SAFTY_CHECKER_BILLING_EMAIL = "support+mods@gooey.ai" diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index f85d4550a..60886c017 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -62,13 +62,12 @@ class ResponseModel(BaseModel): def related_workflows(self) -> list: from recipes.VideoBots import VideoBotsPage from recipes.DeforumSD import DeforumSDPage - from recipes.CompareText2Img import CompareText2ImgPage return [ VideoBotsPage, TextToSpeechPage, DeforumSDPage, - CompareText2ImgPage, + LipsyncPage, ] def render_form_v2(self): diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index eab559b84..082fdcdaa 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -25,7 +25,7 @@ class SocialLookupEmailPage(BasePage): slug_versions = ["SocialLookupEmail", "email-writer-with-profile-lookup"] sane_defaults = { - "selected_model": LargeLanguageModels.text_davinci_003.name, + "selected_model": LargeLanguageModels.gpt_4.name, } class RequestModel(BaseModel): diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 1d80162c9..149931ffb 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -1,4 +1,5 @@ import json +import mimetypes import os import os.path import typing @@ -150,6 +151,7 @@ class RequestModel(BaseModel): input_prompt: str input_images: list[str] | None + input_documents: list[str] | None # conversation history/context messages: list[ConversationEntry] | None @@ -458,16 +460,22 @@ def render_output(self): # chat window with st.div(className="pb-3"): chat_list_view() - pressed_send, new_input, new_input_images = chat_input_view() + ( + pressed_send, + new_input, + new_input_images, + new_input_documents, + ) = chat_input_view() if pressed_send: - 
self.on_send(new_input, new_input_images) + self.on_send(new_input, new_input_images, new_input_documents) # clear chat inputs if st.button("🗑️ Clear"): st.session_state["messages"] = [] st.session_state["input_prompt"] = "" st.session_state["input_images"] = None + st.session_state["input_documents"] = None st.session_state["raw_input_text"] = "" self.clear_outputs() st.session_state["final_keyword_query"] = "" @@ -487,12 +495,18 @@ def render_output(self): label_visibility="collapsed", ) - def on_send(self, new_input: str, new_input_images: list[str]): + def on_send( + self, + new_input: str, + new_input_images: list[str], + new_input_documents: list[str], + ): prev_input = st.session_state.get("raw_input_text") or "" prev_output = (st.session_state.get("raw_output_text") or [""])[0] prev_input_images = st.session_state.get("input_images") + prev_input_documents = st.session_state.get("input_documents") - if (prev_input or prev_input_images) and prev_output: + if (prev_input or prev_input_images or prev_input_documents) and prev_output: # append previous input to the history st.session_state["messages"] = st.session_state.get("messages", []) + [ format_chat_entry( @@ -507,8 +521,14 @@ def on_send(self, new_input: str, new_input_images: list[str]): ] # add new input to the state + if new_input_documents: + filenames = ", ".join( + furl(url.strip("/")).path.segments[-1] for url in new_input_documents + ) + new_input = f"Files: {filenames}\n\n{new_input}" st.session_state["input_prompt"] = new_input st.session_state["input_images"] = new_input_images or None + st.session_state["input_documents"] = new_input_documents or None self.on_submit() @@ -544,6 +564,7 @@ def render_steps(self): if isinstance(final_prompt, str): text_output("**Final Prompt**", value=final_prompt, height=300) else: + st.write("**Final Prompt**") st.json(final_prompt) for idx, text in enumerate(st.session_state.get("raw_output_text", [])): @@ -601,7 +622,7 @@ def run(self, state: dict) ->
typing.Iterator[str | None]: """ user_input = request.input_prompt.strip() - if not (user_input or request.input_images): + if not (user_input or request.input_images or request.input_documents): return model = LargeLanguageModels[request.selected_model] is_chat_model = model.is_chat_model() @@ -609,13 +630,11 @@ def run(self, state: dict) -> typing.Iterator[str | None]: bot_script = request.bot_script ocr_texts = [] - if request.input_images: + if request.document_model and (request.input_images or request.input_documents): yield "Running Azure Form Recognizer..." - for img in request.input_images: + for url in (request.input_images or []) + (request.input_documents or []): ocr_text = ( - azure_form_recognizer( - img, model_id=request.document_model or "prebuilt-read" - ) + azure_form_recognizer(url, model_id="prebuilt-read") .get("content", "") .strip() ) @@ -641,7 +660,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: target_language="en", ) for text in ocr_texts: - user_input = f"Image: {text!r}\n{user_input}" + user_input = f"Extracted Text: {text!r}\n\n{user_input}" # parse the bot script # system_message, scripted_msgs = parse_script(bot_script) @@ -666,18 +685,16 @@ def run(self, state: dict) -> typing.Iterator[str | None]: # except IndexError: # user_display_name = CHATML_ROLE_USER - # construct user prompt + # save raw input for reference state["raw_input_text"] = user_input - user_prompt = { - "role": CHATML_ROLE_USER, - "content": user_input, - } # if documents are provided, run doc search on the saved msgs and get back the references references = None if request.documents: # formulate the search query as a history of all the messages - query_msgs = saved_msgs + [user_prompt] + query_msgs = saved_msgs + [ + format_chat_entry(role=CHATML_ROLE_USER, content=user_input) + ] clip_idx = convo_window_clipper( query_msgs, model_max_tokens[model] // 2, sep=" " ) @@ -741,17 +758,22 @@ def run(self, state: dict) -> typing.Iterator[str | None]: if
references: # add task instructions task_instructions = render_prompt_vars(request.task_instructions, state) - user_prompt["content"] = ( + user_input = ( references_as_prompt(references) + f"\n**********\n{task_instructions.strip()}\n**********\n" - + user_prompt["content"] + + user_input ) + # construct user prompt + user_prompt = format_chat_entry( + role=CHATML_ROLE_USER, content=user_input, images=request.input_images + ) + # truncate the history to fit the model's max tokens history_window = scripted_msgs + saved_msgs max_history_tokens = ( model_max_tokens[model] - - calc_gpt_tokens([system_prompt, user_prompt], is_chat_model=is_chat_model) + - calc_gpt_tokens([system_prompt, user_input], is_chat_model=is_chat_model) - request.max_tokens - SAFETY_BUFFER ) @@ -853,7 +875,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: tts_state = dict(state) for text in state.get("raw_tts_text", state["raw_output_text"]): tts_state["text_prompt"] = text - yield from TextToSpeechPage().run(tts_state) + yield from TextToSpeechPage( + request=self.request, run_user=self.run_user + ).run(tts_state) state["output_audio"].append(tts_state["audio_url"]) if not request.input_face: @@ -861,7 +885,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: lip_state = dict(state) for audio_url in state["output_audio"]: lip_state["input_audio"] = audio_url - yield from LipsyncPage().run(lip_state) + yield from LipsyncPage(request=self.request, run_user=self.run_user).run( + lip_state + ) state["output_video"].append(lip_state["output_video"]) def get_tabs(self): @@ -1217,7 +1243,7 @@ def chat_list_view(): st.image(im, style={"maxHeight": "200px"}) -def chat_input_view() -> tuple[bool, str, list[str]]: +def chat_input_view() -> tuple[bool, str, list[str], list[str]]: with st.div( className="px-3 pt-3 d-flex gap-1", style=dict(background="rgba(239, 239, 239, 0.6)"), @@ -1237,14 +1263,20 @@ def chat_input_view() -> tuple[bool, str, list[str]]: pressed_send = 
st.button("✈ Send", style=dict(height="3.2rem")) if show_uploader: - new_input_images = st.file_uploader( - "", - accept_multiple_files=True, - ) + uploaded_files = st.file_uploader("", accept_multiple_files=True) + new_input_images = [] + new_input_documents = [] + for f in uploaded_files: + mime_type = mimetypes.guess_type(f)[0] or "" + if mime_type.startswith("image/"): + new_input_images.append(f) + else: + new_input_documents.append(f) else: new_input_images = None + new_input_documents = None - return pressed_send, new_input, new_input_images + return pressed_send, new_input, new_input_images, new_input_documents def msg_container_widget(role: str): diff --git a/routers/root.py b/routers/root.py index 6b40fca61..6f854ba20 100644 --- a/routers/root.py +++ b/routers/root.py @@ -38,6 +38,7 @@ from daras_ai_v2.meta_preview_url import meta_preview_url from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.settings import templates +from daras_ai_v2.tabs_widget import MenuTabs from routers.api import request_form_files app = APIRouter() @@ -198,7 +199,7 @@ def file_upload(request: Request, form_data: FormData = Depends(request_form_fil if content_type.startswith("image/"): with Image(blob=data) as img: - if img.format not in ["png", "jpeg", "jpg", "gif"]: + if img.format.lower() not in ["png", "jpeg", "jpg", "gif"]: img.format = "png" content_type = "image/png" filename += ".png" @@ -331,27 +332,38 @@ def _api_docs_page(request): @app.post("/") @app.post("/{page_slug}/") -@app.post("/{page_slug}/{tab}/") +@app.post("/{page_slug}/{run_slug_or_tab}/") +@app.post("/{page_slug}/{run_slug_or_tab}/{tab}/") def st_page( request: Request, page_slug="", + run_slug_or_tab="", tab="", json_data: dict = Depends(request_json), ): + run_slug, tab = _extract_run_slug_and_tab(run_slug_or_tab, tab) + try: + selected_tab = MenuTabs.paths_reverse[tab] + except KeyError: + raise HTTPException(status_code=404) + try: page_cls = 
page_slug_map[normalize_slug(page_slug)] except KeyError: raise HTTPException(status_code=404) + # ensure the latest slug is used latest_slug = page_cls.slug_versions[-1] if latest_slug != page_slug: return RedirectResponse( - request.url.replace(path=os.path.join("/", latest_slug, tab, "")) + request.url.replace(path=os.path.join("/", latest_slug, run_slug, tab, "")) ) example_id, run_id, uid = extract_query_params(request.query_params) - page = page_cls(tab=tab, request=request, run_user=get_run_user(request, uid)) + page = page_cls( + tab=selected_tab, request=request, run_user=get_run_user(request, uid) + ) state = json_data.get("state", {}) if not state: @@ -372,19 +384,6 @@ def st_page( except RedirectException as e: return RedirectResponse(e.url, status_code=e.status_code) - # Canonical URLs should not include uid or run_id (don't index specific runs). - # In the case of examples, all tabs other than "Run" are duplicates of the page - # without the `example_id`, and so their canonical shouldn't include `example_id` - canonical_url = str( - furl( - str(settings.APP_BASE_URL), - query_params={"example_id": example_id} if not tab and example_id else {}, - ) - / latest_slug - / tab - / "/" # preserve trailing slash - ) - ret |= { "meta": build_meta_tags( url=get_og_url_path(request), @@ -394,11 +393,6 @@ def st_page( uid=uid, example_id=example_id, ) - + [dict(tagName="link", rel="canonical", href=canonical_url)] - # + [ - # dict(tagName="link", rel="icon", href="/static/favicon.ico"), - # dict(tagName="link", rel="stylesheet", href="/static/css/app.css"), - # ], } return ret @@ -440,3 +434,12 @@ def page_wrapper(request: Request, render_fn: typing.Callable, **kwargs): st.html(templates.get_template("footer.html").render(**context)) st.html(templates.get_template("login_scripts.html").render(**context)) + + +def _extract_run_slug_and_tab(run_slug_or_tab, tab) -> tuple[str, str]: + if run_slug_or_tab and tab: + return run_slug_or_tab, tab + elif run_slug_or_tab 
in MenuTabs.paths_reverse: + return "", run_slug_or_tab + else: + return run_slug_or_tab, "" diff --git a/tests/test_public_endpoints.py b/tests/test_public_endpoints.py index 4542277b4..be18647d8 100644 --- a/tests/test_public_endpoints.py +++ b/tests/test_public_endpoints.py @@ -12,7 +12,7 @@ client = TestClient(app) excluded_endpoints = [ - facebook.fb_webhook_verify.__name__, # gives 403 + facebook_api.fb_webhook_verify.__name__, # gives 403 slack_connect_redirect.__name__, slack_connect_redirect_shortcuts.__name__, "get_run_status", # needs query params