diff --git a/Dockerfile b/Dockerfile index 929e01ae5..83052f5ed 100644 --- a/Dockerfile +++ b/Dockerfile @@ -54,6 +54,9 @@ COPY . . ENV FORWARDED_ALLOW_IPS='*' ENV PYTHONUNBUFFERED=1 +ARG CAPROVER_GIT_COMMIT_SHA=${CAPROVER_GIT_COMMIT_SHA} +ENV CAPROVER_GIT_COMMIT_SHA=${CAPROVER_GIT_COMMIT_SHA} + EXPOSE 8000 EXPOSE 8501 diff --git a/app_users/admin.py b/app_users/admin.py index caa61d223..56f325fca 100644 --- a/app_users/admin.py +++ b/app_users/admin.py @@ -4,7 +4,8 @@ from app_users import models from bots.admin_links import open_in_new_tab, list_related_html_url -from bots.models import SavedRun +from bots.models import SavedRun, PublishedRun +from embeddings.models import EmbeddedFile from usage_costs.models import UsageCost @@ -31,7 +32,12 @@ class AppUserAdmin(admin.ModelAdmin): "is_paying", "disable_safety_checker", "disable_rate_limits", - ("user_runs", "view_transactions"), + ( + "view_saved_runs", + "view_published_runs", + "view_embedded_files", + "view_transactions", + ), "created_at", "upgraded_from_anonymous_at", ("open_in_firebase", "open_in_stripe"), @@ -86,7 +92,9 @@ class AppUserAdmin(admin.ModelAdmin): "total_usage_cost", "created_at", "upgraded_from_anonymous_at", - "user_runs", + "view_saved_runs", + "view_published_runs", + "view_embedded_files", "view_transactions", "open_in_firebase", "open_in_stripe", @@ -95,7 +103,7 @@ class AppUserAdmin(admin.ModelAdmin): autocomplete_fields = ["handle", "subscription"] @admin.display(description="User Runs") - def user_runs(self, user: models.AppUser): + def view_saved_runs(self, user: models.AppUser): return list_related_html_url( SavedRun.objects.filter(uid=user.uid), query_param="uid", @@ -103,6 +111,24 @@ def user_runs(self, user: models.AppUser): show_add=False, ) + @admin.display(description="Published Runs") + def view_published_runs(self, user: models.AppUser): + return list_related_html_url( + PublishedRun.objects.filter(created_by=user), + query_param="created_by", + instance_id=user.id, + show_add=False, + ) + + @admin.display(description="Embedded Files") + def view_embedded_files(self, user: models.AppUser): + return list_related_html_url( + EmbeddedFile.objects.filter(created_by=user), + query_param="created_by", + instance_id=user.id, + show_add=False, + ) + @admin.display(description="Total Payments") def total_payments(self, user: models.AppUser): return "$" + str( diff --git a/bots/admin.py b/bots/admin.py index 51f6e9aa9..4b6a731c8 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -77,8 +77,6 @@ "twilio_username", "twilio_password", "twilio_use_missed_call", - "twilio_tts_voice", - "twilio_asr_language", "twilio_initial_text", "twilio_initial_audio_url", "twilio_waiting_text", diff --git a/bots/migrations/0079_remove_botintegration_twilio_asr_language_and_more.py b/bots/migrations/0079_remove_botintegration_twilio_asr_language_and_more.py new file mode 100644 index 000000000..08bd81043 --- /dev/null +++ b/bots/migrations/0079_remove_botintegration_twilio_asr_language_and_more.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.7 on 2024-07-22 22:52 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0078_alter_botintegration_unique_together_and_more'), + ] + + operations = [ + migrations.RemoveField( + model_name='botintegration', + name='twilio_asr_language', + ), + migrations.RemoveField( + model_name='botintegration', + name='twilio_tts_voice', + ), + ] diff --git a/bots/models.py b/bots/models.py index 25a1a7e9d..0407569d5 100644 --- a/bots/models.py +++ b/bots/models.py @@ -694,16 +694,6 @@ class BotIntegration(models.Model): blank=True, help_text="The audio url to play to the user while waiting for a response if using voice", ) - twilio_tts_voice = models.TextField( - default="", - blank=True, - help_text="The voice to use for Twilio TTS ('man', 'woman', or Amazon Polly/Google Voices: https://www.twilio.com/docs/voice/twiml/say/text-speech#available-voices-and-languages)", - ) - twilio_asr_language = models.TextField( - default="", - blank=True, - help_text="The language to use for Twilio ASR (https://www.twilio.com/docs/voice/twiml/gather#languagetags)", - ) streaming_enabled = models.BooleanField( default=False, @@ -1239,10 +1229,9 @@ def to_df_format( metadata__mime_type__startswith="image/" ).values_list("url", flat=True) ), - "Audio Input": ", ".join( - message.attachments.filter( - metadata__mime_type__startswith="audio/" - ).values_list("url", flat=True) + "Audio Input": ( + (message.saved_run and message.saved_run.state.get("input_audio")) + or "" ), } rows.append(row) diff --git a/bots/tasks.py b/bots/tasks.py index 436098f24..70f381bf1 100644 --- a/bots/tasks.py +++ b/bots/tasks.py @@ -150,6 +150,7 @@ def send_broadcast_msgs_chunked( buttons: list[ReplyButton] = None, convo_qs: QuerySet[Conversation], bi: BotIntegration, + medium: str = "Voice Call", ): convo_ids = list(convo_qs.values_list("id", flat=True)) for i in range(0, len(convo_ids), 100): @@ -161,6 +162,7 @@ def send_broadcast_msgs_chunked( documents=documents, bi_id=bi.id, convo_ids=convo_ids[i : i + 100], + medium=medium, ) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 2e30e4379..794b5a061 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -5,13 +5,14 @@ from time import time from types import SimpleNamespace +import gooey_gui as gui import requests import sentry_sdk from django.db.models import Sum from django.utils import timezone from fastapi import HTTPException +from loguru import logger -import gooey_ui as st from app_users.models import AppUser, AppUserTransaction from bots.admin_links import change_obj_url from bots.models import SavedRun, Platform, Workflow @@ -22,8 +23,6 @@ from daras_ai_v2.exceptions import UserError from daras_ai_v2.send_email import send_email_via_postmark, send_low_balance_email from daras_ai_v2.settings import templates -from gooey_ui.pubsub import realtime_push -from gooey_ui.state import set_query_params from gooeysite.bg_db_conn import db_middleware from payments.auto_recharge import ( should_attempt_auto_recharge, @@ -41,6 +40,7 @@ def runner_task( run_id: str, uid: str, channel: str, + unsaved_state: dict[str, typing.Any] = None, ) -> int: start_time = time() error_msg = None @@ -68,7 +68,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # extract outputs from local state | { k: v - for k, v in st.session_state.items() + for k, v in gui.session_state.items() if k in page.ResponseModel.__fields__ } # add extra outputs from the run @@ -76,26 +76,27 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False ) # send outputs to ui - realtime_push(channel, output) + gui.realtime_push(channel, output) # save to db - page.dump_state_to_sr(st.session_state | output, sr) + page.dump_state_to_sr(gui.session_state | output, sr) user = AppUser.objects.get(id=user_id) page = page_cls(request=SimpleNamespace(user=user)) page.setup_sentry() sr = page.run_doc_sr(run_id, uid) - st.set_session_state(sr.to_dict()) - set_query_params(dict(run_id=run_id, uid=uid)) + gui.set_session_state(sr.to_dict() | (unsaved_state or {})) + gui.set_query_params(dict(run_id=run_id, uid=uid)) try: save_on_step() - for val in page.main(sr, st.session_state): + for val in page.main(sr, gui.session_state): save_on_step(val) # render errors nicely except Exception as e: if isinstance(e, UserError): sentry_level = e.sentry_level + logger.warning(e) else: sentry_level = "error" traceback.print_exc() @@ -106,7 +107,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # run completed successfully, deduct credits else: - sr.transaction, sr.price = page.deduct_credits(st.session_state) + sr.transaction, sr.price = page.deduct_credits(gui.session_state) # save everything, mark run as completed finally: diff --git a/components_doc.py b/components_doc.py index 74c9cdde2..f6ec26d8a 100644 --- a/components_doc.py +++ b/components_doc.py @@ -1,7 +1,7 @@ import inspect from functools import wraps -import gooey_ui as gui +import gooey_gui as gui META_TITLE = "Gooey Components" META_DESCRIPTION = "Explore the Gooey Component Library" diff --git a/conftest.py b/conftest.py index a38c6a11a..57d5d5cca 100644 --- a/conftest.py +++ b/conftest.py @@ -55,7 +55,7 @@ def mock_celery_tasks(): with ( patch("celeryapp.tasks.runner_task", _mock_runner_task), patch("celeryapp.tasks.post_runner_tasks", _mock_post_runner_tasks), - patch("daras_ai_v2.bots.realtime_subscribe", _mock_realtime_subscribe), + patch("gooey_gui.realtime_subscribe", _mock_realtime_subscribe), ): yield diff --git a/daras_ai/image_input.py b/daras_ai/image_input.py index 5e61fcba9..4ee0ab5f9 100644 --- a/daras_ai/image_input.py +++ b/daras_ai/image_input.py @@ -1,10 +1,11 @@ +import math import mimetypes import os import re +import typing import uuid from pathlib import Path -import math import numpy as np import requests from PIL import Image, ImageOps @@ -13,6 +14,9 @@ from daras_ai_v2 import settings from daras_ai_v2.exceptions import UserError +if typing.TYPE_CHECKING: + from google.cloud.storage import Blob, Bucket + def resize_img_pad(img_bytes: bytes, size: tuple[int, int]) -> bytes: img_cv2 = bytes_to_cv2_img(img_bytes) @@ -57,25 +61,38 @@ def upload_file_from_bytes( data: bytes, content_type: str = None, ) -> str: - if not content_type: - content_type = mimetypes.guess_type(filename)[0] - content_type = content_type or "application/octet-stream" - blob = storage_blob_for(filename) - blob.upload_from_string(data, content_type=content_type) + blob = gcs_blob_for(filename) + upload_gcs_blob_from_bytes(blob, data, content_type) return blob.public_url -def storage_blob_for(filename: str) -> "storage.storage.Blob": - from firebase_admin import storage - +def gcs_blob_for(filename: str) -> "Blob": filename = safe_filename(filename) - bucket = storage.bucket(settings.GS_BUCKET_NAME) + bucket = gcs_bucket() blob = bucket.blob( os.path.join(settings.GS_MEDIA_PATH, str(uuid.uuid1()), filename) ) return blob +def upload_gcs_blob_from_bytes( + blob: "Blob", + data: bytes, + content_type: str = None, +) -> str: + if not content_type: + content_type = mimetypes.guess_type(blob.path)[0] + content_type = content_type or "application/octet-stream" + blob.upload_from_string(data, content_type=content_type) + return blob.public_url + + +def gcs_bucket() -> "Bucket": + from firebase_admin import storage + + return storage.bucket(settings.GS_BUCKET_NAME) + + def cv2_img_to_bytes(img: np.ndarray) -> bytes: import cv2 diff --git a/daras_ai_v2/analysis_results.py b/daras_ai_v2/analysis_results.py index 7a361c4aa..5bd519982 100644 --- a/daras_ai_v2/analysis_results.py +++ b/daras_ai_v2/analysis_results.py @@ -4,14 +4,14 @@ from django.db.models import IntegerChoices -import gooey_ui as st +import gooey_gui as gui from app_users.models import AppUser from bots.models import BotIntegration, Message from daras_ai_v2.base import RecipeTabs from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_button from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.workflow_url_input import del_button -from gooey_ui import QueryParamsRedirectException +from gooey_gui import QueryParamsRedirectException from gooeysite.custom_filters import related_json_field_summary from recipes.BulkRunner import list_view_editor from recipes.VideoBots import VideoBotsPage @@ -41,7 +41,7 @@ def render_analysis_results_page( render_title_breadcrumb_share(bi, current_url, current_user) if title: - st.write(title) + gui.write(title) if graphs_json: graphs = json.loads(graphs_json) @@ -55,40 +55,40 @@ def render_analysis_results_page( except Message.DoesNotExist: results = None if not results: - with st.center(): - st.error("No analysis results found") + with gui.center(): + gui.error("No analysis results found") return - with st.div(className="pb-5 pt-3"): + with gui.div(className="pb-5 pt-3"): grid_layout(2, graphs, partial(render_graph_data, bi, results), separator=False) - st.checkbox("๐Ÿ”„ Refresh every 10s", key="autorefresh") + gui.checkbox("๐Ÿ”„ Refresh every 10s", key="autorefresh") - with st.expander("โœ๏ธ Edit"): - title = st.text_area("##### Title", value=title) + with gui.expander("โœ๏ธ Edit"): + title = gui.text_area("##### Title", value=title) - st.session_state.setdefault("selected_graphs", graphs) + gui.session_state.setdefault("selected_graphs", graphs) selected_graphs = list_view_editor( add_btn_label="โž• Add a Graph", key="selected_graphs", render_inputs=partial(render_inputs, results), ) - with st.center(): - if st.button("โœ… Update"): + with gui.center(): + if gui.button("โœ… Update"): _on_press_update(title, selected_graphs) def render_inputs(results: dict, key: str, del_key: str, d: dict): - ocol1, ocol2 = st.columns([11, 1], responsive=False) + ocol1, ocol2 = gui.columns([11, 1], responsive=False) with ocol1: - col1, col2, col3 = st.columns(3) + col1, col2, col3 = gui.columns(3) with ocol2: ocol2.node.props["style"] = dict(paddingTop="2rem") del_button(del_key) with col1: - d["key"] = st.selectbox( + d["key"] = gui.selectbox( label="##### Key", options=results.keys(), key=f"{key}_key", @@ -96,7 +96,7 @@ def render_inputs(results: dict, key: str, del_key: str, d: dict): ) with col2: col2.node.props["style"] = dict(paddingTop="0.45rem") - d["graph_type"] = st.selectbox( + d["graph_type"] = gui.selectbox( label="###### Graph Type", options=[g.value for g in GraphType], format_func=lambda x: GraphType(x).label, @@ -105,7 +105,7 @@ def render_inputs(results: dict, key: str, del_key: str, d: dict): ) with col3: col3.node.props["style"] = dict(paddingTop="0.45rem") - d["data_selection"] = st.selectbox( + d["data_selection"] = gui.selectbox( label="###### Data Selection", options=[d.value for d in DataSelection], format_func=lambda x: DataSelection(x).label, @@ -115,10 +115,10 @@ def render_inputs(results: dict, key: str, del_key: str, d: dict): def _autorefresh_script(): - if not st.session_state.get("autorefresh"): + if not gui.session_state.get("autorefresh"): return - st.session_state.pop("__cache__", None) - st.js( + gui.session_state.pop("__cache__", None) + gui.js( # language=JavaScript """ setTimeout(() => { @@ -139,23 +139,23 @@ def render_title_breadcrumb_share( else: run_title = bi.saved_run.page_title # this is mostly for backwards compat query_params = dict(run_id=bi.saved_run.run_id, uid=bi.saved_run.uid) - with st.div(className="d-flex justify-content-between mt-4"): - with st.div(className="d-lg-flex d-block align-items-center"): - with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): - with st.breadcrumbs(): + with gui.div(className="d-flex justify-content-between mt-4"): + with gui.div(className="d-lg-flex d-block align-items-center"): + with gui.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): + with gui.breadcrumbs(): metadata = VideoBotsPage.workflow.get_or_create_metadata() - st.breadcrumb_item( + gui.breadcrumb_item( metadata.short_title, link_to=VideoBotsPage.app_url(), className="text-muted", ) if not (bi.published_run_id and bi.published_run.is_root()): - st.breadcrumb_item( + gui.breadcrumb_item( run_title, link_to=VideoBotsPage.app_url(**query_params), className="text-muted", ) - st.breadcrumb_item( + gui.breadcrumb_item( "Integrations", link_to=VideoBotsPage.app_url( **query_params, @@ -170,8 +170,8 @@ def render_title_breadcrumb_share( show_as_link=current_user and VideoBotsPage.is_user_admin(current_user), ) - with st.div(className="d-flex align-items-center"): - with st.div(className="d-flex align-items-start right-action-icons"): + with gui.div(className="d-flex align-items-center"): + with gui.div(className="d-flex align-items-start right-action-icons"): copy_to_clipboard_button( f' Copy Link', value=current_url, @@ -197,7 +197,7 @@ def _on_press_update(title: str, selected_graphs: list[dict]): raise QueryParamsRedirectException(dict(title=title, graphs=graphs_json)) -@st.cache_in_session_state +@gui.cache_in_session_state def fetch_analysis_results(bi: BotIntegration) -> dict: msgs = Message.objects.filter( conversation__bot_integration=bi, @@ -217,7 +217,7 @@ def fetch_analysis_results(bi: BotIntegration) -> dict: def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict): key = graph_data["key"] - st.write(f"##### {key}") + gui.write(f"##### {key}") obj_key = f"analysis_result__{key}" if graph_data["data_selection"] == DataSelection.last.value: @@ -227,7 +227,7 @@ def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict): .latest() ) if not latest_msg: - st.write("No analysis results found") + gui.write("No analysis results found") return values = [[latest_msg.analysis_result.get(key), 1]] elif graph_data["data_selection"] == DataSelection.convo_last.value: @@ -249,7 +249,7 @@ def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict): for val in values: if not val: continue - st.write(val[0]) + gui.write(val[0]) case GraphType.table_count.value: render_table_count(values) case GraphType.bar_count.value: @@ -261,8 +261,8 @@ def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict): def render_table_count(values): - st.div(className="p-1") - st.data_table( + gui.div(className="p-1") + gui.data_table( [["Value", "Count"]] + [[result[0], result[1]] for result in values], ) @@ -318,4 +318,4 @@ def render_data_in_plotly(*data): dragmode="pan", ), ) - st.plotly_chart(fig) + gui.plotly_chart(fig) diff --git a/daras_ai_v2/api_examples_widget.py b/daras_ai_v2/api_examples_widget.py index 44086ba02..9cd9beeb1 100644 --- a/daras_ai_v2/api_examples_widget.py +++ b/daras_ai_v2/api_examples_widget.py @@ -5,7 +5,7 @@ from furl import furl -import gooey_ui as st +import gooey_gui as gui from auth.token_authentication import auth_keyword from daras_ai_v2 import settings from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url @@ -30,7 +30,7 @@ def get_filenames(request_body): def api_example_generator( *, api_url: furl, request_body: dict, as_form_data: bool, as_async: bool ): - js, python, curl = st.tabs(["`node.js`", "`python`", "`curl`"]) + js, python, curl = gui.tabs(["`node.js`", "`python`", "`curl`"]) filenames = [] if as_async: @@ -95,7 +95,7 @@ def api_example_generator( json=shlex.quote(json.dumps(request_body, indent=2)), ) - st.write( + gui.write( """ 1. Generate an api key [below๐Ÿ‘‡](#api-keys) @@ -193,7 +193,7 @@ def api_example_generator( from black.mode import Mode py_code = format_str(py_code, mode=Mode()) - st.write( + gui.write( rf""" 1. Generate an api key [below๐Ÿ‘‡](#api-keys) @@ -308,7 +308,7 @@ def api_example_generator( js_code += "\n}\n\ngooeyAPI();" - st.write( + gui.write( r""" 1. Generate an api key [below๐Ÿ‘‡](#api-keys) @@ -385,7 +385,7 @@ def bot_api_example_generator(integration_id: str): integration_id=integration_id, ) - st.write( + gui.write( f""" Your Integration ID: `{integration_id}` @@ -407,14 +407,14 @@ def bot_api_example_generator(integration_id: str): ) / "docs" ) - st.markdown( + gui.markdown( f""" Read our complete API for features like conversation history, input media files, and more. """, unsafe_allow_html=True, ) - st.js( + gui.js( """ document.startStreaming = async function() { document.getElementById('stream-output').style.display = 'flex'; @@ -426,7 +426,7 @@ def bot_api_example_generator(integration_id: str): ).strip() ) - st.html( + gui.html( f"""
diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index fea57f5ba..c09102fe4 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -10,7 +10,7 @@ from django.db.models import F from furl import furl -import gooey_ui as st +import gooey_gui as gui from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri from daras_ai_v2 import settings from daras_ai_v2.azure_asr import azure_asr @@ -297,7 +297,7 @@ def translation_language_selector( **kwargs, ) -> str | None: if not model: - st.session_state[key] = None + gui.session_state[key] = None return if model == TranslationModels.google: @@ -308,7 +308,7 @@ def translation_language_selector( raise ValueError("Unsupported translation model: " + str(model)) options = list(languages.keys()) - return st.selectbox( + return gui.selectbox( label=label, key=key, format_func=lang_format_func, @@ -351,7 +351,7 @@ def google_translate_language_selector( """ languages = google_translate_target_languages() options = list(languages.keys()) - return st.selectbox( + return gui.selectbox( label=label, key=key, format_func=lambda k: languages[k] if k else "โ€”โ€”โ€”", @@ -408,12 +408,10 @@ def asr_language_selector( label="##### Spoken Language", key="language", ): - import langcodes - # don't show language selector for models with forced language forced_lang = forced_asr_languages.get(selected_model) if forced_lang: - st.session_state[key] = forced_lang + gui.session_state[key] = forced_lang return forced_lang options = list(asr_supported_languages.get(selected_model, [])) @@ -421,18 +419,14 @@ def asr_language_selector( options.insert(0, None) # handle non-canonical language codes - old_val = st.session_state.get(key) - if old_val and old_val not in options: - old_val_lang = langcodes.Language.get(old_val).language - for opt in options: - try: - if opt and langcodes.Language.get(opt).language == old_val_lang: - st.session_state[key] = opt - break - except langcodes.LanguageTagError: - pass - - return st.selectbox( + old_lang = gui.session_state.get(key) + if old_lang: + try: + gui.session_state[key] = normalised_lang_in_collection(old_lang, options) + except UserError: + gui.session_state[key] = None + + return gui.selectbox( label=label, key=key, format_func=lang_format_func, @@ -584,14 +578,27 @@ def run_google_translate( def normalised_lang_in_collection(target: str, collection: typing.Iterable[str]) -> str: import langcodes - for candidate in collection: - if langcodes.get(candidate).language == langcodes.get(target).language: - return candidate - - raise UserError( + ERROR = UserError( f"Unsupported language: {target!r} | must be one of {set(collection)}" ) + if target in collection: + return target + + try: + target_lan = langcodes.Language.get(target).language + except langcodes.LanguageTagError: + raise ERROR + + for candidate in collection: + try: + if candidate and langcodes.Language.get(candidate).language == target_lan: + return candidate + except langcodes.LanguageTagError: + pass + + raise ERROR + def _translate_text( text: str, @@ -1069,7 +1076,15 @@ def iterate_subtitles( yield segment_start, segment_end, segment_text -def format_timestamp(seconds: float, always_include_hours: bool, decimal_marker: str): +INFINITY_SECONDS = 99 * 3600 + 59 * 60 + 59 # 99:59:59 in seconds + + +def format_timestamp( + seconds: float | None, always_include_hours: bool, decimal_marker: str +): + if seconds is None: + # treat None as end of time + seconds = INFINITY_SECONDS assert seconds >= 0, "non-negative timestamp expected" milliseconds = round(seconds * 1000.0) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 5be61de83..8e3610a28 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -12,6 +12,7 @@ from time import sleep from types import SimpleNamespace +import gooey_gui as gui import sentry_sdk from django.db.models import Sum from django.utils import timezone @@ -25,7 +26,6 @@ ) from starlette.requests import Request -import gooey_ui as st from app_users.models import AppUser, AppUserTransaction from bots.models import ( SavedRun, @@ -35,6 +35,7 @@ Workflow, RetentionPolicy, ) +from daras_ai.image_input import truncate_text_words from daras_ai.text_format import format_number_with_suffix from daras_ai_v2 import settings, urls from daras_ai_v2.api_examples_widget import api_example_generator @@ -55,9 +56,6 @@ from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_preview_url import meta_preview_url from daras_ai_v2.prompt_vars import variables_input -from daras_ai_v2.query_params import ( - gooey_get_query_params, -) from daras_ai_v2.query_params_util import ( extract_query_params, ) @@ -76,13 +74,6 @@ is_functions_enabled, render_called_functions, ) -from gooey_ui import ( - realtime_clear_subs, - RedirectException, -) -from gooey_ui.components.modal import Modal -from gooey_ui.components.pills import pill -from gooey_ui.pubsub import realtime_pull from payments.auto_recharge import ( should_attempt_auto_recharge, run_auto_recharge_gracefully, @@ -170,7 +161,7 @@ def __init__( @classmethod @property def endpoint(cls) -> str: - return f"/v2/{cls.slug_versions[0]}/" + return f"/v2/{cls.slug_versions[0]}" @classmethod def current_app_url( @@ -182,7 +173,7 @@ def current_app_url( ) -> str: if query_params is None: query_params = {} - example_id, run_id, uid = extract_query_params(gooey_get_query_params()) + example_id, run_id, uid = extract_query_params(gui.get_query_params()) return cls.app_url( tab=tab, example_id=example_id, @@ -276,7 +267,7 @@ def setup_sentry(self): def sentry_event_set_request(self, event, hint): request = event.setdefault("request", {}) request.setdefault("method", "POST") - request["data"] = st.session_state + request["data"] = gui.session_state if url := request.get("url"): f = furl(url) request["url"] = str( @@ -284,7 +275,7 @@ def sentry_event_set_request(self, event, hint): ) else: request["url"] = self.app_url( - tab=self.tab, query_params=st.get_query_params() + tab=self.tab, query_params=gui.get_query_params() ) return event @@ -313,36 +304,36 @@ def sentry_event_set_user(self, event, hint): return event def refresh_state(self): - _, run_id, uid = extract_query_params(gooey_get_query_params()) + _, run_id, uid = extract_query_params(gui.get_query_params()) channel = self.realtime_channel_name(run_id, uid) - output = realtime_pull([channel])[0] + output = gui.realtime_pull([channel])[0] if output: - st.session_state.update(output) + gui.session_state.update(output) def render(self): self.setup_sentry() - if self.get_run_state(st.session_state) == RecipeRunState.running: + if self.get_run_state(gui.session_state) == RecipeRunState.running: self.refresh_state() else: - realtime_clear_subs() + gui.realtime_clear_subs() self._user_disabled_check() self._check_if_flagged() - if st.session_state.get("show_report_workflow"): + if gui.session_state.get("show_report_workflow"): self.render_report_form() return self._render_header() - st.newline() + gui.newline() - with st.nav_tabs(): + with gui.nav_tabs(): for tab in self.get_tabs(): url = self.current_app_url(tab) - with st.nav_item(url, active=tab == self.tab): - st.html(tab.title) - with st.nav_tab_content(): + with gui.nav_item(url, active=tab == self.tab): + gui.html(tab.title) + with gui.nav_tab_content(): self.render_selected_tab() def _render_header(self): @@ -354,13 +345,13 @@ def _render_header(self): self, current_run, published_run, tab=self.tab ) - with st.div(className="d-flex justify-content-between mt-4"): - with st.div(className="d-lg-flex d-block align-items-center"): + with gui.div(className="d-flex justify-content-between mt-4"): + with gui.div(className="d-lg-flex d-block align-items-center"): if not tbreadcrumbs.has_breadcrumbs() and not self.run_user: self._render_title(tbreadcrumbs.h1_title) if tbreadcrumbs: - with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): + with gui.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): render_breadcrumbs( tbreadcrumbs, is_api_call=( @@ -376,7 +367,7 @@ def _render_header(self): if not is_root_example: self.render_author(author) - with st.div(className="d-flex align-items-center"): + with gui.div(className="d-flex align-items-center"): can_user_edit_run = self.can_user_edit_run(current_run, published_run) has_unpublished_changes = ( published_run @@ -389,8 +380,8 @@ def _render_header(self): if can_user_edit_run and has_unpublished_changes: self._render_unpublished_changes_indicator() - with st.div(className="d-flex align-items-start right-action-icons"): - st.html( + with gui.div(className="d-flex align-items-start right-action-icons"): + gui.html( """ """ ) - st.image(user.photo_url, className=class_name) + gui.image(user.photo_url, className=class_name) if user.display_name: name_style = {"fontSize": text_size} if text_size else {} - with st.tag("span", style=name_style): - st.html(html.escape(user.display_name)) + with gui.tag("span", style=name_style): + gui.html(html.escape(user.display_name)) def get_credits_click_url(self): if self.request.user and self.request.user.is_anonymous: @@ -1315,9 +1313,9 @@ def get_submit_container_props(self): ) def render_submit_button(self, key="--submit-1"): - with st.div(**self.get_submit_container_props()): - st.write("---") - col1, col2 = st.columns([2, 1], responsive=False) + with gui.div(**self.get_submit_container_props()): + gui.write("---") + col1, col2 = gui.columns([2, 1], responsive=False) col2.node.props[ "className" ] += " d-flex justify-content-end align-items-center" @@ -1325,25 +1323,25 @@ def render_submit_button(self, key="--submit-1"): with col1: self.render_run_cost() with col2: - submitted = st.button( + submitted = gui.button( "๐Ÿƒ Submit", key=key, type="primary", - # disabled=bool(st.session_state.get(StateKeys.run_status)), + # disabled=bool(gui.session_state.get(StateKeys.run_status)), ) if not submitted: return False try: self.validate_form_v2() except AssertionError as e: - st.error(str(e)) + gui.error(str(e)) return False else: return True def render_run_cost(self): url = self.get_credits_click_url() - run_cost = self.get_price_roundoff(st.session_state) + run_cost = self.get_price_roundoff(gui.session_state) ret = f'Run cost = {run_cost} credits' cost_note = self.get_cost_note() @@ -1354,18 +1352,18 @@ def render_run_cost(self): if additional_notes: ret += f" \n{additional_notes}" - st.caption(ret, line_clamp=1, unsafe_allow_html=True) + gui.caption(ret, line_clamp=1, unsafe_allow_html=True) def _render_step_row(self): key = "details-expander" - with st.expander("**โ„น๏ธ Details**", key=key): - if not st.session_state.get(key): + with gui.expander("**โ„น๏ธ Details**", key=key): + if not gui.session_state.get(key): return - col1, col2 = st.columns([1, 2]) + col1, col2 = gui.columns([1, 2]) with col1: self.render_description() with col2: - placeholder = st.div() + placeholder = gui.div() render_called_functions( saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre ) @@ -1375,33 +1373,33 @@ def _render_step_row(self): pass else: with placeholder: - st.write("##### ๐Ÿ‘ฃ Steps") + gui.write("##### ๐Ÿ‘ฃ Steps") render_called_functions( saved_run=self.get_current_sr(), trigger=FunctionTrigger.post ) def _render_help(self): - placeholder = st.div() + placeholder = gui.div() try: self.render_usage_guide() except NotImplementedError: pass else: with placeholder: - st.write( + gui.write( """ ## How to Use This Recipe """ ) key = "discord-expander" - with st.expander( + with gui.expander( f"**๐Ÿ™‹๐Ÿฝโ€โ™€๏ธ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**", key=key, ): - if not st.session_state.get(key): + if not gui.session_state.get(key): return - st.markdown( + gui.markdown( """
@@ -1458,34 +1456,34 @@ def run_v2( raise NotImplementedError def _render_report_button(self): - example_id, run_id, uid = extract_query_params(gooey_get_query_params()) + example_id, run_id, uid = extract_query_params(gui.get_query_params()) # only logged in users can report a run (but not examples/default runs) if not (self.request.user and run_id and uid): return - reported = st.button( + reported = gui.button( ' Report', type="tertiary" ) if not reported: return - st.session_state["show_report_workflow"] = reported - st.experimental_rerun() + gui.session_state["show_report_workflow"] = reported + gui.rerun() def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool): ref = self.run_doc_sr(uid=uid, run_id=run_id) ref.is_flagged = is_flagged ref.save(update_fields=["is_flagged"]) - st.session_state["is_flagged"] = is_flagged + gui.session_state["is_flagged"] = is_flagged # Functions in every recipe feels like overkill for now, hide it in settings functions_in_settings = True def _render_input_col(self): self.render_form_v2() - placeholder = st.div() + placeholder = gui.div() - with st.expander("โš™๏ธ Settings"): + with gui.expander("โš™๏ธ Settings"): if self.functions_in_settings: functions_input(self.request.user) self.render_settings() @@ -1494,8 +1492,8 @@ def _render_input_col(self): self.render_variables() submitted = self.render_submit_button() - with st.div(style={"textAlign": "right"}): - st.caption( + with gui.div(style={"textAlign": "right"}): + gui.caption( "_By submitting, you agree to Gooey.AI's [terms](https://gooey.ai/terms) & " "[privacy policy](https://gooey.ai/privacy)._" ) @@ -1503,7 +1501,7 @@ def _render_input_col(self): def render_variables(self): if not self.functions_in_settings: - st.write("---") + gui.write("---") functions_input(self.request.user) variables_input( template_keys=self.template_keys, allow_add=is_functions_enabled() @@ -1522,30 +1520,30 @@ def get_run_state(cls, state: dict[str, typing.Any]) -> RecipeRunState: return RecipeRunState.starting def render_deleted_output(self): - col1, *_ = st.columns(2) + col1, *_ = gui.columns(2) with col1: - st.error( + gui.error( "This data has been deleted as per the retention policy.", icon="๐Ÿ—‘๏ธ", color="rgba(255, 200, 100, 0.5)", ) - st.newline() + gui.newline() self._render_output_col(is_deleted=True) - st.newline() + gui.newline() self.render_run_cost() def _render_output_col(self, *, submitted: bool = False, is_deleted: bool = False): assert inspect.isgeneratorfunction(self.run) - if st.session_state.get(StateKeys.pressed_randomize): - st.session_state["seed"] = int(gooey_rng.randrange(MAX_SEED)) - st.session_state.pop(StateKeys.pressed_randomize, None) + if gui.session_state.get(StateKeys.pressed_randomize): + gui.session_state["seed"] = int(gooey_rng.randrange(MAX_SEED)) + gui.session_state.pop(StateKeys.pressed_randomize, None) submitted = True if submitted or self.should_submit_after_login(): self.on_submit() - run_state = self.get_run_state(st.session_state) + run_state = self.get_run_state(gui.session_state) match run_state: case RecipeRunState.completed: self._render_completed_output() @@ -1569,16 +1567,16 @@ def _render_completed_output(self): pass def _render_failed_output(self): - err_msg = st.session_state.get(StateKeys.error_msg) - st.error(err_msg, unsafe_allow_html=True) + err_msg = gui.session_state.get(StateKeys.error_msg) + gui.error(err_msg, unsafe_allow_html=True) def _render_running_output(self): - run_status = st.session_state.get(StateKeys.run_status) + run_status = gui.session_state.get(StateKeys.run_status) html_spinner(run_status) self.render_extra_waiting_output() def render_extra_waiting_output(self): - created_at = st.session_state.get("created_at") + created_at = gui.session_state.get("created_at") if not created_at: return @@ -1589,15 +1587,15 @@ def render_extra_waiting_output(self): estimated_run_time = self.estimate_run_duration() if not estimated_run_time: return - with st.countdown_timer( + with gui.countdown_timer( end_time=created_at + datetime.timedelta(seconds=estimated_run_time), delay_text="Sorry for the wait. Your run is taking longer than we expected.", ): if self.is_current_user_owner() and self.request.user.email: - st.write( + gui.write( f"""We'll email **{self.request.user.email}** when your workflow is done.""" ) - st.write( + gui.write( f"""In the meantime, check out [๐Ÿš€ Examples]({self.current_app_url(RecipeTabs.examples)}) for inspiration.""" ) @@ -1609,21 +1607,21 @@ def on_submit(self): try: sr = self.create_new_run(enable_rate_limits=True) except ValidationError as e: - st.session_state[StateKeys.run_status] = None - st.session_state[StateKeys.error_msg] = str(e) + gui.session_state[StateKeys.run_status] = None + gui.session_state[StateKeys.error_msg] = str(e) return except RateLimitExceeded as e: - st.session_state[StateKeys.run_status] = None - st.session_state[StateKeys.error_msg] = e.detail.get("error", "") + gui.session_state[StateKeys.run_status] = None + gui.session_state[StateKeys.error_msg] = e.detail.get("error", "") return self.call_runner_task(sr) - raise RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid)) + raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid)) def should_submit_after_login(self) -> bool: return ( - st.get_query_params().get(SUBMIT_AFTER_LOGIN_Q) + gui.get_query_params().get(SUBMIT_AFTER_LOGIN_Q) and self.request and self.request.user and not self.request.user.is_anonymous @@ -1632,9 +1630,9 @@ def should_submit_after_login(self) -> bool: def create_new_run( self, *, enable_rate_limits: bool = False, **defaults ) -> SavedRun: - st.session_state[StateKeys.run_status] = "Starting..." - st.session_state.pop(StateKeys.error_msg, None) - st.session_state.pop(StateKeys.run_time, None) + gui.session_state[StateKeys.run_status] = "Starting..." + gui.session_state.pop(StateKeys.error_msg, None) + gui.session_state.pop(StateKeys.run_time, None) self._setup_rng_seed() self.clear_outputs() @@ -1654,7 +1652,7 @@ def create_new_run( run_id = get_random_doc_id() parent_example_id, parent_run_id, parent_uid = extract_query_params( - gooey_get_query_params() + gui.get_query_params() ) parent = self.get_sr_from_query_params( parent_example_id, parent_run_id, parent_uid @@ -1673,8 +1671,8 @@ def create_new_run( ) # ensure the request is validated - state = st.session_state | json.loads( - self.RequestModel.parse_obj(st.session_state).json(exclude_unset=True) + state = gui.session_state | json.loads( + self.RequestModel.parse_obj(gui.session_state).json(exclude_unset=True) ) self.dump_state_to_sr(state, sr) @@ -1699,6 +1697,7 @@ def call_runner_task(self, sr: SavedRun): run_id=sr.run_id, uid=sr.uid, channel=self.realtime_channel_name(sr.run_id, sr.uid), + unsaved_state=self._unsaved_state(), ) | post_runner_tasks.s() ) @@ -1741,28 +1740,28 @@ def generate_credit_error_message(self, run_id, uid) -> str: return error_msg def _setup_rng_seed(self): - seed = st.session_state.get("seed") + seed = gui.session_state.get("seed") if not seed: return gooey_rng.seed(seed) def clear_outputs(self): # clear error msg - st.session_state.pop(StateKeys.error_msg, None) + gui.session_state.pop(StateKeys.error_msg, None) # clear outputs for field_name in self.ResponseModel.__fields__: - st.session_state.pop(field_name, None) + gui.session_state.pop(field_name, None) def _render_after_output(self): self._render_report_button() if "seed" in self.RequestModel.schema_json(): - randomize = st.button( + randomize = gui.button( ' Regenerate', type="tertiary" ) if randomize: - st.session_state[StateKeys.pressed_randomize] = True - st.experimental_rerun() + gui.session_state[StateKeys.pressed_randomize] = True + gui.rerun() @classmethod def load_state_from_sr(cls, sr: SavedRun) -> dict: @@ -1782,7 +1781,7 @@ def load_state_defaults(cls, state: dict): state.setdefault(k, v) return state - def fields_to_save(self) -> [str]: + def fields_to_save(self) -> list[str]: # only save the fields in request/response return [ field_name @@ -1794,6 +1793,18 @@ def fields_to_save(self) -> [str]: StateKeys.run_time, ] + def _unsaved_state(self) -> dict[str, typing.Any]: + result = {} + for field in self.fields_not_to_save(): + try: + result[field] = gui.session_state[field] + except KeyError: + pass + return result + + def fields_not_to_save(self) -> list[str]: + return [] + def _examples_tab(self): allow_hide = self.is_current_user_admin() @@ -1823,12 +1834,12 @@ def _saved_tab(self): created_by=self.request.user, )[:50] if not published_runs: - st.write("No published runs yet") + gui.write("No published runs yet") return def _render(pr: PublishedRun): - with st.div(className="mb-2", style={"font-size": "0.9rem"}): - pill( + with gui.div(className="mb-2", style={"font-size": "0.9rem"}): + gui.pill( PublishedRunVisibility(pr.visibility).get_badge_html(), unsafe_allow_html=True, className="border border-dark", @@ -1845,7 +1856,7 @@ def _history_tab(self): if self.is_current_user_admin(): uid = self.request.query_params.get("uid", uid) - before = gooey_get_query_params().get("updated_at__lt", None) + before = gui.get_query_params().get("updated_at__lt", None) if before: before = datetime.datetime.fromisoformat(before) else: @@ -1858,7 +1869,7 @@ def _history_tab(self): )[:25] ) if not run_history: - st.write("No history yet") + gui.write("No history yet") return grid_layout(3, run_history, self._render_run_preview) @@ -1867,15 +1878,15 @@ def _history_tab(self): RecipeTabs.history, query_params={"updated_at__lt": run_history[-1].to_dict()["updated_at"]}, ) - with st.link(to=str(next_url)): - st.html( + with gui.link(to=str(next_url)): + gui.html( # language=HTML f"""""" ) def ensure_authentication(self, next_url: str | None = None, anon_ok: bool = False): if not self.request.user or (self.request.user.is_anonymous and not anon_ok): - raise RedirectException(self.get_auth_url(next_url)) + raise gui.RedirectException(self.get_auth_url(next_url)) def get_auth_url(self, next_url: str | None = None) -> str: from routers.root import login @@ -1890,10 +1901,10 @@ def _render_run_preview(self, saved_run: SavedRun): is_latest_version = published_run and published_run.saved_run == saved_run tb = get_title_breadcrumbs(self, sr=saved_run, pr=published_run) - with st.link(to=saved_run.get_app_url()): - with st.div(className="mb-1", style={"fontSize": "0.9rem"}): + with gui.link(to=saved_run.get_app_url()): + with gui.div(className="mb-1", style={"fontSize": "0.9rem"}): if is_latest_version: - pill( + gui.pill( PublishedRunVisibility( published_run.visibility ).get_badge_html(), @@ -1901,7 +1912,7 @@ def _render_run_preview(self, saved_run: SavedRun): className="border border-dark", ) - st.write(f"#### {tb.h1_title}") + gui.write(f"#### {tb.h1_title}") updated_at = saved_run.updated_at if ( @@ -1909,34 +1920,34 @@ def _render_run_preview(self, saved_run: SavedRun): and isinstance(updated_at, datetime.datetime) and not saved_run.run_status ): - st.caption("Loading...", **render_local_dt_attrs(updated_at)) + gui.caption("Loading...", **render_local_dt_attrs(updated_at)) if saved_run.run_status: started_at_text(saved_run.created_at) html_spinner(saved_run.run_status, scroll_into_view=False) elif saved_run.error_msg: - st.error(saved_run.error_msg, unsafe_allow_html=True) + gui.error(saved_run.error_msg, unsafe_allow_html=True) return self.render_example(saved_run.to_dict()) def render_published_run_preview(self, published_run: PublishedRun): tb = get_title_breadcrumbs(self, published_run.saved_run, published_run) - with st.link(to=published_run.get_app_url()): - st.write(f"#### {tb.h1_title}") + with gui.link(to=published_run.get_app_url()): + gui.write(f"#### {tb.h1_title}") - with st.div(className="d-flex align-items-center justify-content-between"): - with st.div(): + with gui.div(className="d-flex align-items-center justify-content-between"): + with gui.div(): updated_at = published_run.saved_run.updated_at if updated_at and isinstance(updated_at, datetime.datetime): - st.caption("Loading...", **render_local_dt_attrs(updated_at)) + gui.caption("Loading...", **render_local_dt_attrs(updated_at)) if published_run.visibility == PublishedRunVisibility.PUBLIC: run_icon = '' run_count = format_number_with_suffix(published_run.get_run_count()) - st.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True) + gui.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True) if published_run.notes: - st.caption(published_run.notes, line_clamp=2) + gui.caption(published_run.notes, line_clamp=2) doc = published_run.saved_run.to_dict() self.render_example(doc) @@ -1950,28 +1961,28 @@ def _render_example_preview( tb = get_title_breadcrumbs(self, published_run.saved_run, published_run) if published_run.created_by: - with st.div(className="mb-1 text-truncate", style={"height": "1.5rem"}): + with gui.div(className="mb-1 text-truncate", style={"height": "1.5rem"}): self.render_author( published_run.created_by, image_size="20px", text_size="0.9rem", ) - with st.link(to=published_run.get_app_url()): - st.write(f"#### {tb.h1_title}") + with gui.link(to=published_run.get_app_url()): + gui.write(f"#### {tb.h1_title}") - with st.div(className="d-flex align-items-center justify-content-between"): - with st.div(): + with gui.div(className="d-flex align-items-center justify-content-between"): + with gui.div(): updated_at = published_run.saved_run.updated_at if updated_at and isinstance(updated_at, datetime.datetime): - st.caption("Loading...", **render_local_dt_attrs(updated_at)) + gui.caption("Loading...", **render_local_dt_attrs(updated_at)) run_icon = '' run_count = format_number_with_suffix(published_run.get_run_count()) - st.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True) + gui.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True) if published_run.notes: - st.caption(published_run.notes, line_clamp=2) + gui.caption(published_run.notes, line_clamp=2) if allow_hide: self._example_hide_button(published_run=published_run) @@ -1980,7 +1991,7 @@ def _render_example_preview( self.render_example(doc) def _example_hide_button(self, published_run: PublishedRun): - pressed_delete = st.button( + pressed_delete = gui.button( "๐Ÿ™ˆ๏ธ Hide", key=f"delete_example_{published_run.published_run_id}", style={"color": "red"}, @@ -1990,11 +2001,11 @@ def _example_hide_button(self, published_run: PublishedRun): self.set_hidden(published_run=published_run, hidden=True) def set_hidden(self, *, published_run: PublishedRun, hidden: bool): - with st.spinner("Hiding..."): + with gui.spinner("Hiding..."): published_run.is_approved_example = not hidden published_run.save() - st.experimental_rerun() + gui.rerun() def render_example(self, state: dict): pass @@ -2036,45 +2047,45 @@ def run_as_api_tab(self): / "docs" ) - st.markdown( + gui.markdown( f'๐Ÿ“– To learn more, take a look at our complete API', unsafe_allow_html=True, ) - st.write("#### ๐Ÿ“ค Example Request") + gui.write("#### ๐Ÿ“ค Example Request") - include_all = st.checkbox("##### Show all fields") - as_async = st.checkbox("##### Run Async") - as_form_data = st.checkbox("##### Upload Files via Form Data") + include_all = gui.checkbox("##### Show all fields") + as_async = gui.checkbox("##### Run Async") + as_form_data = gui.checkbox("##### Upload Files via Form Data") pr = self.get_current_published_run() api_url, request_body = self.get_example_request( - st.session_state, + gui.session_state, include_all=include_all, pr=pr, ) response_body = self.get_example_response_body( - st.session_state, as_async=as_async, include_all=include_all + gui.session_state, as_async=as_async, include_all=include_all ) api_example_generator( api_url=api_url, - request_body=request_body, + request_body=request_body | self._unsaved_state(), as_form_data=as_form_data, as_async=as_async, ) - st.write("") + gui.write("") - st.write("#### ๐ŸŽ Example Response") - st.json(response_body, expanded=True) + gui.write("#### ๐ŸŽ Example Response") + gui.json(response_body, expanded=True) if not self.request.user or self.request.user.is_anonymous: - st.write("**Please Login to generate the `$GOOEY_API_KEY`**") + gui.write("**Please Login to generate the `$GOOEY_API_KEY`**") return - st.write("---") - with st.tag("a", id="api-keys"): - st.write("### ๐Ÿ” API keys") + gui.write("---") + with gui.tag("a", id="api-keys"): + gui.write("### ๐Ÿ” API keys") manage_api_keys(self.request.user) @@ -2168,7 +2179,7 @@ def get_example_response_body( include_all: bool = False, ) -> dict: run_id = get_random_doc_id() - created_at = st.session_state.get( + created_at = gui.session_state.get( StateKeys.created_at, datetime.datetime.utcnow().isoformat() ) web_url = self.app_url( @@ -2182,7 +2193,7 @@ def get_example_response_body( run_id=run_id, web_url=web_url, created_at=created_at, - run_time_sec=st.session_state.get(StateKeys.run_time, 0), + run_time_sec=gui.session_state.get(StateKeys.run_time, 0), status="completed", output=output, ) @@ -2220,12 +2231,12 @@ def is_current_user_owner(self) -> bool: def started_at_text(dt: datetime.datetime): - with st.div(className="d-flex"): + with gui.div(className="d-flex"): text = "Started" - if seed := st.session_state.get("seed"): + if seed := gui.session_state.get("seed"): text += f' with seed {seed}' - st.caption(text + " on ", unsafe_allow_html=True) - st.caption( + gui.caption(text + " on ", unsafe_allow_html=True) + gui.caption( "...", className="text-black", **render_local_dt_attrs(dt), @@ -2235,23 +2246,23 @@ def started_at_text(dt: datetime.datetime): def render_output_caption(): caption = "" - run_time = st.session_state.get(StateKeys.run_time, 0) + run_time = gui.session_state.get(StateKeys.run_time, 0) if run_time: caption += f'Generated in {run_time :.1f}s' - if seed := st.session_state.get("seed"): + if seed := gui.session_state.get("seed"): caption += f' with seed {seed} ' - updated_at = st.session_state.get(StateKeys.updated_at, datetime.datetime.today()) + updated_at = gui.session_state.get(StateKeys.updated_at, datetime.datetime.today()) if updated_at: if isinstance(updated_at, str): updated_at = datetime.datetime.fromisoformat(updated_at) caption += " on " - with st.div(className="d-flex"): - st.caption(caption, unsafe_allow_html=True) + with gui.div(className="d-flex"): + gui.caption(caption, unsafe_allow_html=True) if updated_at: - st.caption( + gui.caption( "...", className="text-black", **render_local_dt_attrs( diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 0e16d25db..09b48d375 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -1,18 +1,15 @@ from typing import Literal +import gooey_gui as gui import stripe from django.core.exceptions import ValidationError -import gooey_ui as st from app_users.models import AppUser, PaymentProvider from daras_ai_v2 import icons, settings, paypal -from daras_ai_v2.base import RedirectException from daras_ai_v2.fastapi_tricks import get_app_route_url from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.settings import templates from daras_ai_v2.user_date_widgets import render_local_date_attrs -from gooey_ui.components.modal import Modal -from gooey_ui.components.pills import pill from payments.models import PaymentMethodSummary from payments.plans import PricingPlan from scripts.migrate_existing_subscriptions import available_subscriptions @@ -29,30 +26,30 @@ def billing_page(user: AppUser): if user.subscription: render_current_plan(user) - with st.div(className="my-5"): + with gui.div(className="my-5"): render_credit_balance(user) - with st.div(className="my-5"): + with gui.div(className="my-5"): selected_payment_provider = render_all_plans(user) - with st.div(className="my-5"): + with gui.div(className="my-5"): render_addon_section(user, selected_payment_provider) if user.subscription and user.subscription.payment_provider: if user.subscription.payment_provider == PaymentProvider.STRIPE: - with st.div(className="my-5"): + with gui.div(className="my-5"): render_auto_recharge_section(user) - with st.div(className="my-5"): + with gui.div(className="my-5"): render_payment_information(user) - with st.div(className="my-5"): + with gui.div(className="my-5"): render_billing_history(user) def render_payments_setup(): from routers.account import payment_processing_route - st.html( + gui.html( templates.get_template("payment_setup.html").render( settings=settings, payment_processing_url=get_app_route_url(payment_processing_route), @@ -68,25 +65,25 @@ def render_current_plan(user: AppUser): else None ) - with st.div(className=f"{rounded_border} border-dark"): + with gui.div(className=f"{rounded_border} border-dark"): # ROW 1: Plan title and next invoice date left, right = left_and_right() with left: - st.write(f"#### Gooey.AI {plan.title}") + gui.write(f"#### Gooey.AI {plan.title}") if provider: - st.write( + gui.write( f"[{icons.edit} Manage Subscription](#payment-information)", unsafe_allow_html=True, ) - with right, st.div(className="d-flex align-items-center gap-1"): + with right, gui.div(className="d-flex align-items-center gap-1"): if provider and ( - next_invoice_ts := st.run_in_thread( + next_invoice_ts := gui.run_in_thread( user.subscription.get_next_invoice_timestamp, cache=True ) ): - st.html("Next invoice on ") - pill( + gui.html("Next invoice on ") + gui.pill( "...", text_bg="dark", **render_local_date_attrs( @@ -102,22 +99,22 @@ def render_current_plan(user: AppUser): # ROW 2: Plan pricing details left, right = left_and_right(className="mt-5") with left: - st.write(f"# {plan.pricing_title()}", className="no-margin") + gui.write(f"# {plan.pricing_title()}", className="no-margin") if plan.monthly_charge: provider_text = f" **via {provider.label}**" if provider else "" - st.caption("per month" + provider_text) + gui.caption("per month" + provider_text) - with right, st.div(className="text-end"): - st.write(f"# {plan.credits:,} credits", className="no-margin") + with right, gui.div(className="text-end"): + gui.write(f"# {plan.credits:,} credits", className="no-margin") if plan.monthly_charge: - st.write( + gui.write( f"**${plan.monthly_charge:,}** monthly renewal for {plan.credits:,} credits" ) def render_credit_balance(user: AppUser): - st.write(f"## Credit Balance: {user.balance:,}") - st.caption( + gui.write(f"## Credit Balance: {user.balance:,}") + gui.caption( "Every time you submit a workflow or make an API call, we deduct credits from your account." ) @@ -130,13 +127,13 @@ def render_all_plans(user: AppUser) -> PaymentProvider: ) all_plans = [plan for plan in PricingPlan if not plan.deprecated] - st.write("## All Plans") - plans_div = st.div(className="mb-1") + gui.write("## All Plans") + plans_div = gui.div(className="mb-1") if user.subscription and user.subscription.payment_provider: selected_payment_provider = None else: - with st.div(): + with gui.div(): selected_payment_provider = PaymentProvider[ payment_provider_radio() or PaymentProvider.STRIPE.name ] @@ -146,8 +143,8 @@ def _render_plan(plan: PricingPlan): extra_class = "border-dark" else: extra_class = "bg-light" - with st.div(className="d-flex flex-column h-100"): - with st.div( + with gui.div(className="d-flex flex-column h-100"): + with gui.div( className=f"{rounded_border} flex-grow-1 d-flex flex-column p-3 mb-2 {extra_class}" ): _render_plan_details(plan) @@ -158,30 +155,32 @@ def _render_plan(plan: PricingPlan): with plans_div: grid_layout(4, all_plans, _render_plan, separator=False) - with st.div(className="my-2 d-flex justify-content-center"): - st.caption(f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**") + with gui.div(className="my-2 d-flex justify-content-center"): + gui.caption( + f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**" + ) return selected_payment_provider def _render_plan_details(plan: PricingPlan): - with st.div(className="flex-grow-1"): - with st.div(className="mb-4"): - with st.tag("h4", className="mb-0"): - st.html(plan.title) - st.caption( + with gui.div(className="flex-grow-1"): + with gui.div(className="mb-4"): + with gui.tag("h4", className="mb-0"): + gui.html(plan.title) + gui.caption( plan.description, style={ "minHeight": "calc(var(--bs-body-line-height) * 2em)", "display": "block", }, ) - with st.div(className="my-3 w-100"): - with st.tag("h4", className="my-0 d-inline me-2"): - st.html(plan.pricing_title()) - with st.tag("span", className="text-muted my-0"): - st.html(plan.pricing_caption()) - st.write(plan.long_description, unsafe_allow_html=True) + with gui.div(className="my-3 w-100"): + with gui.tag("h4", className="my-0 d-inline me-2"): + gui.html(plan.pricing_title()) + with gui.tag("span", className="text-muted my-0"): + gui.html(plan.pricing_caption()) + gui.write(plan.long_description, unsafe_allow_html=True) def _render_plan_action_button( @@ -192,13 +191,13 @@ def _render_plan_action_button( ): btn_classes = "w-100 mt-3" if plan == current_plan: - st.button("Your Plan", className=btn_classes, disabled=True, type="tertiary") + gui.button("Your Plan", className=btn_classes, disabled=True, type="tertiary") elif plan.contact_us_link: - with st.link( + with gui.link( to=plan.contact_us_link, className=btn_classes + " btn btn-theme btn-primary", ): - st.html("Contact Us") + gui.html("Contact Us") elif user.subscription and not user.subscription.payment_provider: # don't show upgrade/downgrade buttons for enterprise customers # assumption: anyone without a payment provider attached is admin/enterprise @@ -262,11 +261,11 @@ def _render_update_subscription_button( key = f"change-sub-{plan.key}" match label: case "Downgrade": - downgrade_modal = Modal( + downgrade_modal = gui.Modal( "Confirm downgrade", key=f"downgrade-plan-modal-{plan.key}", ) - if st.button( + if gui.button( label, className=className, key=key, @@ -275,28 +274,28 @@ def _render_update_subscription_button( if downgrade_modal.is_open(): with downgrade_modal.container(): - st.write( + gui.write( f""" Are you sure you want to change from: **{current_plan.title} ({fmt_price(current_plan)})** to **{plan.title} ({fmt_price(plan)})**? """, className="d-block py-4", ) - with st.div(className="d-flex w-100"): - if st.button( + with gui.div(className="d-flex w-100"): + if gui.button( "Downgrade", className="btn btn-theme bg-danger border-danger text-white", key=f"{key}-confirm", ): change_subscription(user, plan) - if st.button( + if gui.button( "Cancel", className="border border-danger text-danger", key=f"{key}-cancel", ): downgrade_modal.close() case _: - if st.button(label, className=className, key=key): + if gui.button(label, className=className, key=key): change_subscription( user, plan, @@ -319,19 +318,19 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs): current_plan = PricingPlan.from_sub(user.subscription) if new_plan == current_plan: - raise RedirectException(get_app_route_url(account_route), status_code=303) + raise gui.RedirectException(get_app_route_url(account_route), status_code=303) if new_plan == PricingPlan.STARTER: user.subscription.cancel() user.subscription.delete() - raise RedirectException( + raise gui.RedirectException( get_app_route_url(payment_processing_route), status_code=303 ) match user.subscription.payment_provider: case PaymentProvider.STRIPE: if not new_plan.supports_stripe(): - st.error(f"Stripe subscription not available for {new_plan}") + gui.error(f"Stripe subscription not available for {new_plan}") subscription = stripe.Subscription.retrieve(user.subscription.external_id) stripe.Subscription.modify( @@ -346,13 +345,13 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs): **kwargs, proration_behavior="none", ) - raise RedirectException( + raise gui.RedirectException( get_app_route_url(payment_processing_route), status_code=303 ) case PaymentProvider.PAYPAL: if not new_plan.supports_paypal(): - st.error(f"Paypal subscription not available for {new_plan}") + gui.error(f"Paypal subscription not available for {new_plan}") subscription = paypal.Subscription.retrieve(user.subscription.external_id) paypal_plan_info = new_plan.get_paypal_plan() @@ -360,16 +359,16 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs): plan_id=paypal_plan_info["plan_id"], plan=paypal_plan_info["plan"], ) - raise RedirectException(approval_url, status_code=303) + raise gui.RedirectException(approval_url, status_code=303) case _: - st.error("Not implemented for this payment provider") + gui.error("Not implemented for this payment provider") def payment_provider_radio(**props) -> str | None: - with st.div(className="d-flex"): - st.write("###### Pay Via", className="d-block me-3") - return st.radio( + with gui.div(className="d-flex"): + gui.write("###### Pay Via", className="d-block me-3") + return gui.radio( "", options=PaymentProvider.names, format_func=lambda name: f'{PaymentProvider[name].label}', @@ -379,10 +378,10 @@ def payment_provider_radio(**props) -> str | None: def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvider): if user.subscription: - st.write("# Purchase More Credits") + gui.write("# Purchase More Credits") else: - st.write("# Purchase Credits") - st.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") + gui.write("# Purchase Credits") + gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") if user.subscription: provider = PaymentProvider(user.subscription.payment_provider) @@ -396,22 +395,22 @@ def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvid def render_paypal_addon_buttons(): - selected_amt = st.horizontal_radio( + selected_amt = gui.horizontal_radio( "", settings.ADDON_AMOUNT_CHOICES, format_func=lambda amt: f"${amt:,}", checked_by_default=False, ) if selected_amt: - st.js( + gui.js( f"setPaypalAddonQuantity({int(selected_amt) * settings.ADDON_CREDITS_PER_DOLLAR})" ) - st.div( + gui.div( id="paypal-addon-buttons", className="mt-2", style={"width": "fit-content"}, ) - st.div(id="paypal-result-message") + gui.div(id="paypal-result-message") def render_stripe_addon_buttons(user: AppUser): @@ -420,10 +419,10 @@ def render_stripe_addon_buttons(user: AppUser): def render_stripe_addon_button(dollat_amt: int, user: AppUser): - confirm_purchase_modal = Modal( + confirm_purchase_modal = gui.Modal( "Confirm Purchase", key=f"confirm-purchase-{dollat_amt}" ) - if st.button(f"${dollat_amt:,}", type="primary"): + if gui.button(f"${dollat_amt:,}", type="primary"): if user.subscription: confirm_purchase_modal.open() else: @@ -432,32 +431,32 @@ def render_stripe_addon_button(dollat_amt: int, user: AppUser): if not confirm_purchase_modal.is_open(): return with confirm_purchase_modal.container(): - st.write( + gui.write( f""" Please confirm your purchase: **{dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollat_amt}**. """, className="py-4 d-block text-center", ) - with st.div(className="d-flex w-100 justify-content-end"): - if st.session_state.get("--confirm-purchase"): - success = st.run_in_thread( + with gui.div(className="d-flex w-100 justify-content-end"): + if gui.session_state.get("--confirm-purchase"): + success = gui.run_in_thread( user.subscription.stripe_attempt_addon_purchase, args=[dollat_amt], placeholder="Processing payment...", ) if success is None: return - st.session_state.pop("--confirm-purchase") + gui.session_state.pop("--confirm-purchase") if success: confirm_purchase_modal.close() else: - st.error("Payment failed... Please try again.") + gui.error("Payment failed... Please try again.") return - if st.button("Cancel", className="border border-danger text-danger me-2"): + if gui.button("Cancel", className="border border-danger text-danger me-2"): confirm_purchase_modal.close() - st.button("Buy", type="primary", key="--confirm-purchase") + gui.button("Buy", type="primary", key="--confirm-purchase") def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int): @@ -478,7 +477,7 @@ def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int): "payment_method_save": "enabled", }, ) - raise RedirectException(checkout_session.url, status_code=303) + raise gui.RedirectException(checkout_session.url, status_code=303) def render_stripe_subscription_button( @@ -490,13 +489,13 @@ def render_stripe_subscription_button( key: str, ): if not plan.supports_stripe(): - st.write("Stripe subscription not available") + gui.write("Stripe subscription not available") return # IMPORTANT: key=... is needed here to maintain uniqueness # of buttons with the same label. otherwise, all buttons # will be the same to the server - if st.button(label, key=key, type=btn_type): + if gui.button(label, key=key, type=btn_type): stripe_subscription_checkout_redirect(user=user, plan=plan) @@ -522,7 +521,7 @@ def stripe_subscription_checkout_redirect(user: AppUser, plan: PricingPlan): "payment_method_save": "enabled", }, ) - raise RedirectException(checkout_session.url, status_code=303) + raise gui.RedirectException(checkout_session.url, status_code=303) def render_paypal_subscription_button( @@ -530,11 +529,11 @@ def render_paypal_subscription_button( plan: PricingPlan, ): if not plan.supports_paypal(): - st.write("Paypal subscription not available") + gui.write("Paypal subscription not available") return lookup_key = plan.key - st.html( + gui.html( f"""
str: @@ -617,8 +616,8 @@ def render_billing_history(user: AppUser, limit: int = 50): if not txns: return - st.write("## Billing History", className="d-block") - st.table( + gui.write("## Billing History", className="d-block") + gui.table( pd.DataFrame.from_records( [ { @@ -633,7 +632,7 @@ def render_billing_history(user: AppUser, limit: int = 50): ), ) if txns.count() > limit: - st.caption(f"Showing only the most recent {limit} transactions.") + gui.caption(f"Showing only the most recent {limit} transactions.") def render_auto_recharge_section(user: AppUser): @@ -643,9 +642,9 @@ def render_auto_recharge_section(user: AppUser): ) subscription = user.subscription - st.write("## Auto Recharge & Limits") - with st.div(className="h4"): - auto_recharge_enabled = st.checkbox( + gui.write("## Auto Recharge & Limits") + with gui.div(className="h4"): + auto_recharge_enabled = gui.checkbox( "Enable auto recharge", value=subscription.auto_recharge_enabled, ) @@ -656,20 +655,20 @@ def render_auto_recharge_section(user: AppUser): subscription.save(update_fields=["auto_recharge_enabled"]) if not auto_recharge_enabled: - st.caption( + gui.caption( "Enable auto recharge to automatically keep your credit balance topped up." ) return - col1, col2 = st.columns(2) - with col1, st.div(className="mb-2"): - subscription.auto_recharge_topup_amount = st.selectbox( + col1, col2 = gui.columns(2) + with col1, gui.div(className="mb-2"): + subscription.auto_recharge_topup_amount = gui.selectbox( "###### Automatically purchase", options=settings.ADDON_AMOUNT_CHOICES, format_func=lambda amt: f"{settings.ADDON_CREDITS_PER_DOLLAR * int(amt):,} credits for ${amt}", value=subscription.auto_recharge_topup_amount, ) - subscription.auto_recharge_balance_threshold = st.selectbox( + subscription.auto_recharge_balance_threshold = gui.selectbox( "###### when balance falls below", options=settings.AUTO_RECHARGE_BALANCE_THRESHOLD_CHOICES, format_func=lambda c: f"{c:,} credits", @@ -677,49 +676,51 @@ def render_auto_recharge_section(user: AppUser): ) with col2: - st.write("###### Monthly Recharge Budget") - st.caption( + gui.write("###### Monthly Recharge Budget") + gui.caption( """ If your account exceeds this budget in a given calendar month, subsequent runs & API requests will be rejected. """, ) - with st.div(className="d-flex align-items-center"): - user.subscription.monthly_spending_budget = st.number_input( + with gui.div(className="d-flex align-items-center"): + user.subscription.monthly_spending_budget = gui.number_input( "", min_value=10, value=user.subscription.monthly_spending_budget, key="monthly-spending-budget", ) - st.write("USD", className="d-block ms-2") + gui.write("USD", className="d-block ms-2") - st.write("###### Email Notification Threshold") - st.caption( + gui.write("###### Email Notification Threshold") + gui.caption( """ If your account purchases exceed this threshold in a given calendar month, you will receive an email notification. """ ) - with st.div(className="d-flex align-items-center"): - user.subscription.monthly_spending_notification_threshold = st.number_input( - "", - min_value=10, - value=user.subscription.monthly_spending_notification_threshold, - key="monthly-spending-notification-threshold", + with gui.div(className="d-flex align-items-center"): + user.subscription.monthly_spending_notification_threshold = ( + gui.number_input( + "", + min_value=10, + value=user.subscription.monthly_spending_notification_threshold, + key="monthly-spending-notification-threshold", + ) ) - st.write("USD", className="d-block ms-2") + gui.write("USD", className="d-block ms-2") - if st.button("Save", type="primary", key="save-auto-recharge-and-limits"): + if gui.button("Save", type="primary", key="save-auto-recharge-and-limits"): try: subscription.full_clean() except ValidationError as e: - st.error(str(e)) + gui.error(str(e)) else: subscription.save() - st.success("Settings saved!") + gui.success("Settings saved!") def left_and_right(*, className: str = "", **props): className += " d-flex flex-row justify-content-between align-items-center" - with st.div(className=className, **props): - return st.div(), st.div() + with gui.div(className=className, **props): + return gui.div(), gui.div() diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index b6324ce5f..4958f389d 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -6,7 +6,7 @@ from django.utils.text import slugify from furl import furl -import gooey_ui as st +import gooey_gui as gui from app_users.models import AppUser from bots.models import BotIntegration, BotIntegrationAnalysisRun, Platform from daras_ai_v2 import settings, icons @@ -19,52 +19,49 @@ def general_integration_settings(bi: BotIntegration, current_user: AppUser): - if st.session_state.get(f"_bi_reset_{bi.id}"): - st.session_state[f"_bi_streaming_enabled_{bi.id}"] = ( + if gui.session_state.get(f"_bi_reset_{bi.id}"): + gui.session_state[f"_bi_streaming_enabled_{bi.id}"] = ( BotIntegration._meta.get_field("streaming_enabled").default ) - st.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = ( + gui.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = ( BotIntegration._meta.get_field("show_feedback_buttons").default ) - st.session_state["analysis_urls"] = [] - st.session_state.pop("--list-view:analysis_urls", None) - - bi.streaming_enabled = st.checkbox( - "**๐Ÿ“ก Streaming Enabled**", - value=bi.streaming_enabled, - key=f"_bi_streaming_enabled_{bi.id}", - ) - st.caption("Responses will be streamed to the user in real-time if enabled.") - bi.show_feedback_buttons = st.checkbox( - "**๐Ÿ‘๐Ÿพ ๐Ÿ‘Ž๐Ÿฝ Show Feedback Buttons**", - value=bi.show_feedback_buttons, - key=f"_bi_show_feedback_buttons_{bi.id}", - ) - st.caption( - "Users can rate and provide feedback on every copilot response if enabled." - ) - - st.caption( - "Please note that this language is distinct from the one provided in the workflow settings. Hence, this allows you to integrate the same bot in many languages." - ) + gui.session_state["analysis_urls"] = [] + gui.session_state.pop("--list-view:analysis_urls", None) + + if bi.platform != Platform.TWILIO: + bi.streaming_enabled = gui.checkbox( + "**๐Ÿ“ก Streaming Enabled**", + value=bi.streaming_enabled, + key=f"_bi_streaming_enabled_{bi.id}", + ) + gui.caption("Responses will be streamed to the user in real-time if enabled.") + bi.show_feedback_buttons = gui.checkbox( + "**๐Ÿ‘๐Ÿพ ๐Ÿ‘Ž๐Ÿฝ Show Feedback Buttons**", + value=bi.show_feedback_buttons, + key=f"_bi_show_feedback_buttons_{bi.id}", + ) + gui.caption( + "Users can rate and provide feedback on every copilot response if enabled." + ) - st.write( + gui.write( """ ##### ๐Ÿง  Analysis Scripts Analyze each incoming message and the copilot's response using a Gooey.AI /LLM workflow. Must return a JSON object. [Learn more](https://gooey.ai/docs/guides/build-your-ai-copilot/conversation-analysis). """ ) - if "analysis_urls" not in st.session_state: - st.session_state["analysis_urls"] = [ + if "analysis_urls" not in gui.session_state: + gui.session_state["analysis_urls"] = [ (anal.published_run or anal.saved_run).get_app_url() for anal in bi.analysis_runs.all() ] - if st.session_state.get("analysis_urls"): + if gui.session_state.get("analysis_urls"): from recipes.VideoBots import VideoBotsPage - st.anchor( + gui.anchor( "๐Ÿ“Š View Results", str( furl( @@ -80,7 +77,7 @@ def general_integration_settings(bi: BotIntegration, current_user: AppUser): input_analysis_runs = [] def render_workflow_url_input(key: str, del_key: str | None, d: dict): - with st.columns([3, 2])[0]: + with gui.columns([3, 2])[0]: ret = workflow_url_input( page_cls=CompareLLMPage, key=key, @@ -103,10 +100,10 @@ def render_workflow_url_input(key: str, del_key: str | None, d: dict): flatten_dict_key="url", ) - with st.center(): - with st.div(): - pressed_update = st.button("โœ… Save") - pressed_reset = st.button( + with gui.center(): + with gui.div(): + pressed_update = gui.button("โœ… Save") + pressed_reset = gui.button( "Reset", key=f"_bi_reset_{bi.id}", type="tertiary" ) if pressed_update or pressed_reset: @@ -124,18 +121,61 @@ def render_workflow_url_input(key: str, del_key: str | None, d: dict): # delete any analysis runs that were removed bi.analysis_runs.all().exclude(id__in=input_analysis_runs).delete() except ValidationError as e: - st.error(str(e)) - st.write("---") + gui.error(str(e)) + gui.write("---") + + +def twilio_specific_settings(bi: BotIntegration): + SETTINGS_FIELDS = ["twilio_use_missed_call", "twilio_initial_text", "twilio_initial_audio_url", "twilio_waiting_text", "twilio_waiting_audio_url"] # fmt:skip + if gui.session_state.get(f"_bi_reset_{bi.id}"): + for field in SETTINGS_FIELDS: + gui.session_state[f"_bi_{field}_{bi.id}"] = BotIntegration._meta.get_field( + field + ).default + + bi.twilio_initial_text = gui.text_area( + "###### ๐Ÿ“ Initial Text (said at the beginning of each call)", + value=bi.twilio_initial_text, + key=f"_bi_twilio_initial_text_{bi.id}", + ) + bi.twilio_initial_audio_url = ( + gui.file_uploader( + "###### ๐Ÿ”Š Initial Audio (played at the beginning of each call)", + accept=["audio/*"], + key=f"_bi_twilio_initial_audio_url_{bi.id}", + ) + or "" + ) + bi.twilio_waiting_audio_url = ( + gui.file_uploader( + "###### ๐ŸŽต Waiting Audio (played while waiting for a response -- Voice)", + accept=["audio/*"], + key=f"_bi_twilio_waiting_audio_url_{bi.id}", + ) + or "" + ) + bi.twilio_waiting_text = gui.text_area( + "###### ๐Ÿ“ Waiting Text (texted while waiting for a response -- SMS)", + key=f"_bi_twilio_waiting_text_{bi.id}", + ) + bi.twilio_use_missed_call = gui.checkbox( + "๐Ÿ“ž Use Missed Call", + value=bi.twilio_use_missed_call, + key=f"_bi_twilio_use_missed_call_{bi.id}", + ) + gui.caption( + "When enabled, immediately hangs up incoming calls and calls back the user so they don't incur charges (depending on their carrier/plan)." + ) def slack_specific_settings(bi: BotIntegration, default_name: str): - if st.session_state.get(f"_bi_reset_{bi.id}"): - st.session_state[f"_bi_name_{bi.id}"] = default_name - st.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( + if gui.session_state.get(f"_bi_reset_{bi.id}"): + gui.session_state[f"_bi_name_{bi.id}"] = default_name + gui.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( BotIntegration._meta.get_field("slack_read_receipt_msg").default ) - bi.slack_read_receipt_msg = st.text_input( + bi.slack_read_receipt_msg = gui.text_input( """ ##### โœ… Read Receipt This message is sent immediately after recieving a user message and replaced with the copilot's response once it's ready. @@ -145,7 +185,7 @@ def slack_specific_settings(bi: BotIntegration, default_name: str): value=bi.slack_read_receipt_msg, key=f"_bi_slack_read_receipt_msg_{bi.id}", ) - bi.name = st.text_input( + bi.name = gui.text_input( """ ##### ๐Ÿชช Channel Specific Bot Name This is the name the bot will post as in this specific channel (to be displayed in Slack) @@ -154,7 +194,7 @@ def slack_specific_settings(bi: BotIntegration, default_name: str): value=bi.name, key=f"_bi_name_{bi.id}", ) - st.caption("Enable streaming messages to Slack in real-time.") + gui.caption("Enable streaming messages to Slack in real-time.") def broadcast_input(bi: BotIntegration): @@ -169,7 +209,7 @@ def broadcast_input(bi: BotIntegration): ) / "docs" ) - text = st.text_area( + text = gui.text_area( f""" ###### Broadcast Message ๐Ÿ“ข Broadcast a message to all users of this integration using this bot account. \\ @@ -178,40 +218,50 @@ def broadcast_input(bi: BotIntegration): key=key + ":text", placeholder="Type your message here...", ) - audio = st.file_uploader( + audio = gui.file_uploader( "**๐ŸŽค Audio**", key=key + ":audio", help="Attach a video to this message.", optional=True, accept=["audio/*"], ) - video = st.file_uploader( - "**๐ŸŽฅ Video**", - key=key + ":video", - help="Attach a video to this message.", - optional=True, - accept=["video/*"], - ) - documents = st.file_uploader( - "**๐Ÿ“„ Documents**", - key=key + ":documents", - help="Attach documents to this message.", - accept_multiple_files=True, - optional=True, - ) + video = None + documents = None + medium = "Voice Call" + if bi.platform == Platform.TWILIO: + medium = gui.selectbox( + "###### ๐Ÿ“ฑ Medium", + ["Voice Call", "SMS/MMS"], + key=key + ":medium", + ) + else: + video = gui.file_uploader( + "**๐ŸŽฅ Video**", + key=key + ":video", + help="Attach a video to this message.", + optional=True, + accept=["video/*"], + ) + documents = gui.file_uploader( + "**๐Ÿ“„ Documents**", + key=key + ":documents", + help="Attach documents to this message.", + accept_multiple_files=True, + optional=True, + ) should_confirm_key = key + ":should_confirm" confirmed_send_btn = key + ":confirmed_send" - if st.button("๐Ÿ“ค Send Broadcast", style=dict(height="3.2rem"), key=key + ":send"): - st.session_state[should_confirm_key] = True - if not st.session_state.get(should_confirm_key): + if gui.button("๐Ÿ“ค Send Broadcast", style=dict(height="3.2rem"), key=key + ":send"): + gui.session_state[should_confirm_key] = True + if not gui.session_state.get(should_confirm_key): return convos = bi.conversations.all() - if st.session_state.get(confirmed_send_btn): - st.success("Started sending broadcast!") - st.session_state.pop(confirmed_send_btn) - st.session_state.pop(should_confirm_key) + if gui.session_state.get(confirmed_send_btn): + gui.success("Started sending broadcast!") + gui.session_state.pop(confirmed_send_btn) + gui.session_state.pop(should_confirm_key) send_broadcast_msgs_chunked( text=text, audio=audio, @@ -219,15 +269,16 @@ def broadcast_input(bi: BotIntegration): documents=documents, bi=bi, convo_qs=convos, + medium=medium, ) else: if not convos.exists(): - st.error("No users have interacted with this bot yet.", icon="โš ๏ธ") + gui.error("No users have interacted with this bot yet.", icon="โš ๏ธ") return - st.write( + gui.write( f"Are you sure? This will send a message to all {convos.count()} users that have ever interacted with this bot.\n" ) - st.button("โœ… Yes, Send", key=confirmed_send_btn) + gui.button("โœ… Yes, Send", key=confirmed_send_btn) def get_bot_test_link(bi: BotIntegration) -> str | None: @@ -254,6 +305,8 @@ def get_bot_test_link(bi: BotIntegration) -> str | None: integration_name=slugify(bi.name) or "untitled", ), ) + elif bi.twilio_phone_number: + return str(furl("tel:") / bi.twilio_phone_number.as_e164) else: return None @@ -265,7 +318,7 @@ def get_web_widget_embed_code(bi: BotIntegration) -> str: integration_id=bi.api_integration_id(), integration_name=slugify(bi.name) or "untitled", ), - ) + ).rstrip("/") return dedent( f"""
@@ -275,43 +328,43 @@ def get_web_widget_embed_code(bi: BotIntegration) -> str: def web_widget_config(bi: BotIntegration, user: AppUser | None): - with st.div(style={"width": "100%", "textAlign": "left"}): - col1, col2 = st.columns(2) + with gui.div(style={"width": "100%", "textAlign": "left"}): + col1, col2 = gui.columns(2) with col1: - if st.session_state.get("--update-display-picture"): - display_pic = st.file_uploader( + if gui.session_state.get("--update-display-picture"): + display_pic = gui.file_uploader( label="###### Display Picture", accept=["image/*"], ) if display_pic: bi.photo_url = display_pic else: - if st.button(f"{icons.camera} Change Photo"): - st.session_state["--update-display-picture"] = True - st.experimental_rerun() - bi.name = st.text_input("###### Name", value=bi.name) - bi.descripton = st.text_area( + if gui.button(f"{icons.camera} Change Photo"): + gui.session_state["--update-display-picture"] = True + gui.rerun() + bi.name = gui.text_input("###### Name", value=bi.name) + bi.descripton = gui.text_area( "###### Description", value=bi.descripton, ) - scol1, scol2 = st.columns(2) + scol1, scol2 = gui.columns(2) with scol1: - bi.by_line = st.text_input( + bi.by_line = gui.text_input( "###### By Line", value=bi.by_line or (user and f"By {user.display_name}"), ) with scol2: - bi.website_url = st.text_input( + bi.website_url = gui.text_input( "###### Website Link", value=bi.website_url or (user and user.website_url), ) - st.write("###### Conversation Starters") + gui.write("###### Conversation Starters") bi.conversation_starters = list( filter( None, [ - st.text_input("", key=f"--question-{i}", value=value) + gui.text_input("", key=f"--question-{i}", value=value) for i, value in zip_longest(range(4), bi.conversation_starters) ], ) @@ -332,39 +385,39 @@ def web_widget_config(bi: BotIntegration, user: AppUser | None): | bi.web_config_extras ) - scol1, scol2 = st.columns(2) + scol1, scol2 = gui.columns(2) with scol1: - config["showSources"] = st.checkbox( + config["showSources"] = gui.checkbox( "Show Sources", value=config["showSources"] ) - config["enablePhotoUpload"] = st.checkbox( + config["enablePhotoUpload"] = gui.checkbox( "Allow Photo Upload", value=config["enablePhotoUpload"] ) with scol2: - config["enableAudioMessage"] = st.checkbox( + config["enableAudioMessage"] = gui.checkbox( "Enable Audio Message", value=config["enableAudioMessage"] ) - config["enableLipsyncVideo"] = st.checkbox( + config["enableLipsyncVideo"] = gui.checkbox( "Enable Lipsync Video", value=config["enableLipsyncVideo"] ) - # config["branding"]["showPoweredByGooey"] = st.checkbox( + # config["branding"]["showPoweredByGooey"] = gui.checkbox( # "Show Powered By Gooey", value=config["branding"]["showPoweredByGooey"] # ) - with st.expander("Embed Settings"): - st.caption( + with gui.expander("Embed Settings"): + gui.caption( "These settings will take effect when you embed the widget on your website." ) - scol1, scol2 = st.columns(2) + scol1, scol2 = gui.columns(2) with scol1: - config["mode"] = st.selectbox( + config["mode"] = gui.selectbox( "###### Mode", ["popup", "inline", "fullscreen"], value=config["mode"], format_func=lambda x: x.capitalize(), ) if config["mode"] == "popup": - config["branding"]["fabLabel"] = st.text_input( + config["branding"]["fabLabel"] = gui.text_input( "###### Label", value=config["branding"].get("fabLabel", "Help"), ) @@ -374,28 +427,28 @@ def web_widget_config(bi: BotIntegration, user: AppUser | None): # remove defaults bi.web_config_extras = config - with st.div(className="d-flex justify-content-end"): - if st.button( + with gui.div(className="d-flex justify-content-end"): + if gui.button( f"{icons.save} Update Web Preview", type="primary", className="align-right", ): bi.save() - st.experimental_rerun() + gui.rerun() with col2: - with st.center(), st.div(): + with gui.center(), gui.div(): web_preview_tab = f"{icons.chat} Web Preview" api_tab = f"{icons.api} API" - selected = st.horizontal_radio("", [web_preview_tab, api_tab]) + selected = gui.horizontal_radio("", [web_preview_tab, api_tab]) if selected == web_preview_tab: - st.html( + gui.html( # language=html f"""
""" ) - st.js( + gui.js( # language=javascript """ async function loadGooeyEmbed() { diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index e00805385..58a154590 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -2,6 +2,7 @@ import typing from datetime import datetime +import gooey_gui as gui from django.db import transaction from django.utils import timezone from fastapi import HTTPException @@ -24,7 +25,6 @@ from daras_ai_v2.base import BasePage, RecipeRunState, StateKeys from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT from daras_ai_v2.vector_search import doc_url_to_file_metadata -from gooey_ui.pubsub import realtime_subscribe from gooeysite.bg_db_conn import db_middleware, get_celery_result_db_safe from recipes.VideoBots import VideoBotsPage, ReplyButton from routers.api import submit_api_call @@ -402,7 +402,7 @@ def _process_and_send_msg( if bot.streaming_enabled: # subscribe to the realtime channel for updates channel = page.realtime_channel_name(run_id, uid) - with realtime_subscribe(channel) as realtime_gen: + with gui.realtime_subscribe(channel) as realtime_gen: for state in realtime_gen: bot.recipe_run_state = page.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" @@ -497,7 +497,7 @@ def build_run_vars(convo: Conversation, user_msg_id: str): bi = convo.bot_integration if bi.platform == Platform.WEB: - user_msg_id = user_msg_id.lstrip(MSG_ID_PREFIX) + user_msg_id = user_msg_id.removeprefix(MSG_ID_PREFIX) variables = dict( platform=Platform(bi.platform).name, integration_id=bi.api_integration_id(), diff --git a/daras_ai_v2/breadcrumbs.py b/daras_ai_v2/breadcrumbs.py index c6d389b24..a0d06101e 100644 --- a/daras_ai_v2/breadcrumbs.py +++ b/daras_ai_v2/breadcrumbs.py @@ -1,6 +1,6 @@ import typing -import gooey_ui as st +import gooey_gui as gui from bots.models import ( SavedRun, PublishedRun, @@ -32,7 +32,7 @@ def has_breadcrumbs(self): def render_breadcrumbs(breadcrumbs: TitleBreadCrumbs, *, is_api_call: bool = False): - st.html( + gui.html( """