diff --git a/Dockerfile b/Dockerfile
index 929e01ae5..83052f5ed 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -54,6 +54,9 @@ COPY . .
ENV FORWARDED_ALLOW_IPS='*'
ENV PYTHONUNBUFFERED=1
+ARG CAPROVER_GIT_COMMIT_SHA=${CAPROVER_GIT_COMMIT_SHA}
+ENV CAPROVER_GIT_COMMIT_SHA=${CAPROVER_GIT_COMMIT_SHA}
+
EXPOSE 8000
EXPOSE 8501
diff --git a/app_users/admin.py b/app_users/admin.py
index caa61d223..56f325fca 100644
--- a/app_users/admin.py
+++ b/app_users/admin.py
@@ -4,7 +4,8 @@
from app_users import models
from bots.admin_links import open_in_new_tab, list_related_html_url
-from bots.models import SavedRun
+from bots.models import SavedRun, PublishedRun
+from embeddings.models import EmbeddedFile
from usage_costs.models import UsageCost
@@ -31,7 +32,12 @@ class AppUserAdmin(admin.ModelAdmin):
"is_paying",
"disable_safety_checker",
"disable_rate_limits",
- ("user_runs", "view_transactions"),
+ (
+ "view_saved_runs",
+ "view_published_runs",
+ "view_embedded_files",
+ "view_transactions",
+ ),
"created_at",
"upgraded_from_anonymous_at",
("open_in_firebase", "open_in_stripe"),
@@ -86,7 +92,9 @@ class AppUserAdmin(admin.ModelAdmin):
"total_usage_cost",
"created_at",
"upgraded_from_anonymous_at",
- "user_runs",
+ "view_saved_runs",
+ "view_published_runs",
+ "view_embedded_files",
"view_transactions",
"open_in_firebase",
"open_in_stripe",
@@ -95,7 +103,7 @@ class AppUserAdmin(admin.ModelAdmin):
autocomplete_fields = ["handle", "subscription"]
@admin.display(description="User Runs")
- def user_runs(self, user: models.AppUser):
+ def view_saved_runs(self, user: models.AppUser):
return list_related_html_url(
SavedRun.objects.filter(uid=user.uid),
query_param="uid",
@@ -103,6 +111,24 @@ def user_runs(self, user: models.AppUser):
show_add=False,
)
+ @admin.display(description="Published Runs")
+ def view_published_runs(self, user: models.AppUser):
+ return list_related_html_url(
+ PublishedRun.objects.filter(created_by=user),
+ query_param="created_by",
+ instance_id=user.id,
+ show_add=False,
+ )
+
+ @admin.display(description="Embedded Files")
+ def view_embedded_files(self, user: models.AppUser):
+ return list_related_html_url(
+ EmbeddedFile.objects.filter(created_by=user),
+ query_param="created_by",
+ instance_id=user.id,
+ show_add=False,
+ )
+
@admin.display(description="Total Payments")
def total_payments(self, user: models.AppUser):
return "$" + str(
diff --git a/bots/admin.py b/bots/admin.py
index 51f6e9aa9..4b6a731c8 100644
--- a/bots/admin.py
+++ b/bots/admin.py
@@ -77,8 +77,6 @@
"twilio_username",
"twilio_password",
"twilio_use_missed_call",
- "twilio_tts_voice",
- "twilio_asr_language",
"twilio_initial_text",
"twilio_initial_audio_url",
"twilio_waiting_text",
diff --git a/bots/migrations/0079_remove_botintegration_twilio_asr_language_and_more.py b/bots/migrations/0079_remove_botintegration_twilio_asr_language_and_more.py
new file mode 100644
index 000000000..08bd81043
--- /dev/null
+++ b/bots/migrations/0079_remove_botintegration_twilio_asr_language_and_more.py
@@ -0,0 +1,21 @@
+# Generated by Django 4.2.7 on 2024-07-22 22:52
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('bots', '0078_alter_botintegration_unique_together_and_more'),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name='botintegration',
+ name='twilio_asr_language',
+ ),
+ migrations.RemoveField(
+ model_name='botintegration',
+ name='twilio_tts_voice',
+ ),
+ ]
diff --git a/bots/models.py b/bots/models.py
index 25a1a7e9d..0407569d5 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -694,16 +694,6 @@ class BotIntegration(models.Model):
blank=True,
help_text="The audio url to play to the user while waiting for a response if using voice",
)
- twilio_tts_voice = models.TextField(
- default="",
- blank=True,
- help_text="The voice to use for Twilio TTS ('man', 'woman', or Amazon Polly/Google Voices: https://www.twilio.com/docs/voice/twiml/say/text-speech#available-voices-and-languages)",
- )
- twilio_asr_language = models.TextField(
- default="",
- blank=True,
- help_text="The language to use for Twilio ASR (https://www.twilio.com/docs/voice/twiml/gather#languagetags)",
- )
streaming_enabled = models.BooleanField(
default=False,
@@ -1239,10 +1229,9 @@ def to_df_format(
metadata__mime_type__startswith="image/"
).values_list("url", flat=True)
),
- "Audio Input": ", ".join(
- message.attachments.filter(
- metadata__mime_type__startswith="audio/"
- ).values_list("url", flat=True)
+ "Audio Input": (
+ (message.saved_run and message.saved_run.state.get("input_audio"))
+ or ""
),
}
rows.append(row)
diff --git a/bots/tasks.py b/bots/tasks.py
index 436098f24..70f381bf1 100644
--- a/bots/tasks.py
+++ b/bots/tasks.py
@@ -150,6 +150,7 @@ def send_broadcast_msgs_chunked(
buttons: list[ReplyButton] = None,
convo_qs: QuerySet[Conversation],
bi: BotIntegration,
+ medium: str = "Voice Call",
):
convo_ids = list(convo_qs.values_list("id", flat=True))
for i in range(0, len(convo_ids), 100):
@@ -161,6 +162,7 @@ def send_broadcast_msgs_chunked(
documents=documents,
bi_id=bi.id,
convo_ids=convo_ids[i : i + 100],
+ medium=medium,
)
diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py
index 2e30e4379..794b5a061 100644
--- a/celeryapp/tasks.py
+++ b/celeryapp/tasks.py
@@ -5,13 +5,14 @@
from time import time
from types import SimpleNamespace
+import gooey_gui as gui
import requests
import sentry_sdk
from django.db.models import Sum
from django.utils import timezone
from fastapi import HTTPException
+from loguru import logger
-import gooey_ui as st
from app_users.models import AppUser, AppUserTransaction
from bots.admin_links import change_obj_url
from bots.models import SavedRun, Platform, Workflow
@@ -22,8 +23,6 @@
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.send_email import send_email_via_postmark, send_low_balance_email
from daras_ai_v2.settings import templates
-from gooey_ui.pubsub import realtime_push
-from gooey_ui.state import set_query_params
from gooeysite.bg_db_conn import db_middleware
from payments.auto_recharge import (
should_attempt_auto_recharge,
@@ -41,6 +40,7 @@ def runner_task(
run_id: str,
uid: str,
channel: str,
+ unsaved_state: dict[str, typing.Any] = None,
) -> int:
start_time = time()
error_msg = None
@@ -68,7 +68,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
# extract outputs from local state
| {
k: v
- for k, v in st.session_state.items()
+ for k, v in gui.session_state.items()
if k in page.ResponseModel.__fields__
}
# add extra outputs from the run
@@ -76,26 +76,27 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
)
# send outputs to ui
- realtime_push(channel, output)
+ gui.realtime_push(channel, output)
# save to db
- page.dump_state_to_sr(st.session_state | output, sr)
+ page.dump_state_to_sr(gui.session_state | output, sr)
user = AppUser.objects.get(id=user_id)
page = page_cls(request=SimpleNamespace(user=user))
page.setup_sentry()
sr = page.run_doc_sr(run_id, uid)
- st.set_session_state(sr.to_dict())
- set_query_params(dict(run_id=run_id, uid=uid))
+ gui.set_session_state(sr.to_dict() | (unsaved_state or {}))
+ gui.set_query_params(dict(run_id=run_id, uid=uid))
try:
save_on_step()
- for val in page.main(sr, st.session_state):
+ for val in page.main(sr, gui.session_state):
save_on_step(val)
# render errors nicely
except Exception as e:
if isinstance(e, UserError):
sentry_level = e.sentry_level
+ logger.warning(e)
else:
sentry_level = "error"
traceback.print_exc()
@@ -106,7 +107,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
# run completed successfully, deduct credits
else:
- sr.transaction, sr.price = page.deduct_credits(st.session_state)
+ sr.transaction, sr.price = page.deduct_credits(gui.session_state)
# save everything, mark run as completed
finally:
diff --git a/components_doc.py b/components_doc.py
index 74c9cdde2..f6ec26d8a 100644
--- a/components_doc.py
+++ b/components_doc.py
@@ -1,7 +1,7 @@
import inspect
from functools import wraps
-import gooey_ui as gui
+import gooey_gui as gui
META_TITLE = "Gooey Components"
META_DESCRIPTION = "Explore the Gooey Component Library"
diff --git a/conftest.py b/conftest.py
index a38c6a11a..57d5d5cca 100644
--- a/conftest.py
+++ b/conftest.py
@@ -55,7 +55,7 @@ def mock_celery_tasks():
with (
patch("celeryapp.tasks.runner_task", _mock_runner_task),
patch("celeryapp.tasks.post_runner_tasks", _mock_post_runner_tasks),
- patch("daras_ai_v2.bots.realtime_subscribe", _mock_realtime_subscribe),
+ patch("gooey_gui.realtime_subscribe", _mock_realtime_subscribe),
):
yield
diff --git a/daras_ai/image_input.py b/daras_ai/image_input.py
index 5e61fcba9..4ee0ab5f9 100644
--- a/daras_ai/image_input.py
+++ b/daras_ai/image_input.py
@@ -1,10 +1,11 @@
+import math
import mimetypes
import os
import re
+import typing
import uuid
from pathlib import Path
-import math
import numpy as np
import requests
from PIL import Image, ImageOps
@@ -13,6 +14,9 @@
from daras_ai_v2 import settings
from daras_ai_v2.exceptions import UserError
+if typing.TYPE_CHECKING:
+ from google.cloud.storage import Blob, Bucket
+
def resize_img_pad(img_bytes: bytes, size: tuple[int, int]) -> bytes:
img_cv2 = bytes_to_cv2_img(img_bytes)
@@ -57,25 +61,38 @@ def upload_file_from_bytes(
data: bytes,
content_type: str = None,
) -> str:
- if not content_type:
- content_type = mimetypes.guess_type(filename)[0]
- content_type = content_type or "application/octet-stream"
- blob = storage_blob_for(filename)
- blob.upload_from_string(data, content_type=content_type)
+ blob = gcs_blob_for(filename)
+ upload_gcs_blob_from_bytes(blob, data, content_type)
return blob.public_url
-def storage_blob_for(filename: str) -> "storage.storage.Blob":
- from firebase_admin import storage
-
+def gcs_blob_for(filename: str) -> "Blob":
filename = safe_filename(filename)
- bucket = storage.bucket(settings.GS_BUCKET_NAME)
+ bucket = gcs_bucket()
blob = bucket.blob(
os.path.join(settings.GS_MEDIA_PATH, str(uuid.uuid1()), filename)
)
return blob
+def upload_gcs_blob_from_bytes(
+ blob: "Blob",
+ data: bytes,
+ content_type: str = None,
+) -> str:
+ if not content_type:
+ content_type = mimetypes.guess_type(blob.path)[0]
+ content_type = content_type or "application/octet-stream"
+ blob.upload_from_string(data, content_type=content_type)
+ return blob.public_url
+
+
+def gcs_bucket() -> "Bucket":
+ from firebase_admin import storage
+
+ return storage.bucket(settings.GS_BUCKET_NAME)
+
+
def cv2_img_to_bytes(img: np.ndarray) -> bytes:
import cv2
diff --git a/daras_ai_v2/analysis_results.py b/daras_ai_v2/analysis_results.py
index 7a361c4aa..5bd519982 100644
--- a/daras_ai_v2/analysis_results.py
+++ b/daras_ai_v2/analysis_results.py
@@ -4,14 +4,14 @@
from django.db.models import IntegerChoices
-import gooey_ui as st
+import gooey_gui as gui
from app_users.models import AppUser
from bots.models import BotIntegration, Message
from daras_ai_v2.base import RecipeTabs
from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_button
from daras_ai_v2.grid_layout_widget import grid_layout
from daras_ai_v2.workflow_url_input import del_button
-from gooey_ui import QueryParamsRedirectException
+from gooey_gui import QueryParamsRedirectException
from gooeysite.custom_filters import related_json_field_summary
from recipes.BulkRunner import list_view_editor
from recipes.VideoBots import VideoBotsPage
@@ -41,7 +41,7 @@ def render_analysis_results_page(
render_title_breadcrumb_share(bi, current_url, current_user)
if title:
- st.write(title)
+ gui.write(title)
if graphs_json:
graphs = json.loads(graphs_json)
@@ -55,40 +55,40 @@ def render_analysis_results_page(
except Message.DoesNotExist:
results = None
if not results:
- with st.center():
- st.error("No analysis results found")
+ with gui.center():
+ gui.error("No analysis results found")
return
- with st.div(className="pb-5 pt-3"):
+ with gui.div(className="pb-5 pt-3"):
grid_layout(2, graphs, partial(render_graph_data, bi, results), separator=False)
- st.checkbox("๐ Refresh every 10s", key="autorefresh")
+ gui.checkbox("๐ Refresh every 10s", key="autorefresh")
- with st.expander("โ๏ธ Edit"):
- title = st.text_area("##### Title", value=title)
+ with gui.expander("โ๏ธ Edit"):
+ title = gui.text_area("##### Title", value=title)
- st.session_state.setdefault("selected_graphs", graphs)
+ gui.session_state.setdefault("selected_graphs", graphs)
selected_graphs = list_view_editor(
add_btn_label="โ Add a Graph",
key="selected_graphs",
render_inputs=partial(render_inputs, results),
)
- with st.center():
- if st.button("โ
Update"):
+ with gui.center():
+ if gui.button("โ
Update"):
_on_press_update(title, selected_graphs)
def render_inputs(results: dict, key: str, del_key: str, d: dict):
- ocol1, ocol2 = st.columns([11, 1], responsive=False)
+ ocol1, ocol2 = gui.columns([11, 1], responsive=False)
with ocol1:
- col1, col2, col3 = st.columns(3)
+ col1, col2, col3 = gui.columns(3)
with ocol2:
ocol2.node.props["style"] = dict(paddingTop="2rem")
del_button(del_key)
with col1:
- d["key"] = st.selectbox(
+ d["key"] = gui.selectbox(
label="##### Key",
options=results.keys(),
key=f"{key}_key",
@@ -96,7 +96,7 @@ def render_inputs(results: dict, key: str, del_key: str, d: dict):
)
with col2:
col2.node.props["style"] = dict(paddingTop="0.45rem")
- d["graph_type"] = st.selectbox(
+ d["graph_type"] = gui.selectbox(
label="###### Graph Type",
options=[g.value for g in GraphType],
format_func=lambda x: GraphType(x).label,
@@ -105,7 +105,7 @@ def render_inputs(results: dict, key: str, del_key: str, d: dict):
)
with col3:
col3.node.props["style"] = dict(paddingTop="0.45rem")
- d["data_selection"] = st.selectbox(
+ d["data_selection"] = gui.selectbox(
label="###### Data Selection",
options=[d.value for d in DataSelection],
format_func=lambda x: DataSelection(x).label,
@@ -115,10 +115,10 @@ def render_inputs(results: dict, key: str, del_key: str, d: dict):
def _autorefresh_script():
- if not st.session_state.get("autorefresh"):
+ if not gui.session_state.get("autorefresh"):
return
- st.session_state.pop("__cache__", None)
- st.js(
+ gui.session_state.pop("__cache__", None)
+ gui.js(
# language=JavaScript
"""
setTimeout(() => {
@@ -139,23 +139,23 @@ def render_title_breadcrumb_share(
else:
run_title = bi.saved_run.page_title # this is mostly for backwards compat
query_params = dict(run_id=bi.saved_run.run_id, uid=bi.saved_run.uid)
- with st.div(className="d-flex justify-content-between mt-4"):
- with st.div(className="d-lg-flex d-block align-items-center"):
- with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"):
- with st.breadcrumbs():
+ with gui.div(className="d-flex justify-content-between mt-4"):
+ with gui.div(className="d-lg-flex d-block align-items-center"):
+ with gui.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"):
+ with gui.breadcrumbs():
metadata = VideoBotsPage.workflow.get_or_create_metadata()
- st.breadcrumb_item(
+ gui.breadcrumb_item(
metadata.short_title,
link_to=VideoBotsPage.app_url(),
className="text-muted",
)
if not (bi.published_run_id and bi.published_run.is_root()):
- st.breadcrumb_item(
+ gui.breadcrumb_item(
run_title,
link_to=VideoBotsPage.app_url(**query_params),
className="text-muted",
)
- st.breadcrumb_item(
+ gui.breadcrumb_item(
"Integrations",
link_to=VideoBotsPage.app_url(
**query_params,
@@ -170,8 +170,8 @@ def render_title_breadcrumb_share(
show_as_link=current_user and VideoBotsPage.is_user_admin(current_user),
)
- with st.div(className="d-flex align-items-center"):
- with st.div(className="d-flex align-items-start right-action-icons"):
+ with gui.div(className="d-flex align-items-center"):
+ with gui.div(className="d-flex align-items-start right-action-icons"):
copy_to_clipboard_button(
f' Copy Link',
value=current_url,
@@ -197,7 +197,7 @@ def _on_press_update(title: str, selected_graphs: list[dict]):
raise QueryParamsRedirectException(dict(title=title, graphs=graphs_json))
-@st.cache_in_session_state
+@gui.cache_in_session_state
def fetch_analysis_results(bi: BotIntegration) -> dict:
msgs = Message.objects.filter(
conversation__bot_integration=bi,
@@ -217,7 +217,7 @@ def fetch_analysis_results(bi: BotIntegration) -> dict:
def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict):
key = graph_data["key"]
- st.write(f"##### {key}")
+ gui.write(f"##### {key}")
obj_key = f"analysis_result__{key}"
if graph_data["data_selection"] == DataSelection.last.value:
@@ -227,7 +227,7 @@ def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict):
.latest()
)
if not latest_msg:
- st.write("No analysis results found")
+ gui.write("No analysis results found")
return
values = [[latest_msg.analysis_result.get(key), 1]]
elif graph_data["data_selection"] == DataSelection.convo_last.value:
@@ -249,7 +249,7 @@ def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict):
for val in values:
if not val:
continue
- st.write(val[0])
+ gui.write(val[0])
case GraphType.table_count.value:
render_table_count(values)
case GraphType.bar_count.value:
@@ -261,8 +261,8 @@ def render_graph_data(bi: BotIntegration, results: dict, graph_data: dict):
def render_table_count(values):
- st.div(className="p-1")
- st.data_table(
+ gui.div(className="p-1")
+ gui.data_table(
[["Value", "Count"]] + [[result[0], result[1]] for result in values],
)
@@ -318,4 +318,4 @@ def render_data_in_plotly(*data):
dragmode="pan",
),
)
- st.plotly_chart(fig)
+ gui.plotly_chart(fig)
diff --git a/daras_ai_v2/api_examples_widget.py b/daras_ai_v2/api_examples_widget.py
index 44086ba02..9cd9beeb1 100644
--- a/daras_ai_v2/api_examples_widget.py
+++ b/daras_ai_v2/api_examples_widget.py
@@ -5,7 +5,7 @@
from furl import furl
-import gooey_ui as st
+import gooey_gui as gui
from auth.token_authentication import auth_keyword
from daras_ai_v2 import settings
from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url
@@ -30,7 +30,7 @@ def get_filenames(request_body):
def api_example_generator(
*, api_url: furl, request_body: dict, as_form_data: bool, as_async: bool
):
- js, python, curl = st.tabs(["`node.js`", "`python`", "`curl`"])
+ js, python, curl = gui.tabs(["`node.js`", "`python`", "`curl`"])
filenames = []
if as_async:
@@ -95,7 +95,7 @@ def api_example_generator(
json=shlex.quote(json.dumps(request_body, indent=2)),
)
- st.write(
+ gui.write(
"""
1. Generate an api key [below๐](#api-keys)
@@ -193,7 +193,7 @@ def api_example_generator(
from black.mode import Mode
py_code = format_str(py_code, mode=Mode())
- st.write(
+ gui.write(
rf"""
1. Generate an api key [below๐](#api-keys)
@@ -308,7 +308,7 @@ def api_example_generator(
js_code += "\n}\n\ngooeyAPI();"
- st.write(
+ gui.write(
r"""
1. Generate an api key [below๐](#api-keys)
@@ -385,7 +385,7 @@ def bot_api_example_generator(integration_id: str):
integration_id=integration_id,
)
- st.write(
+ gui.write(
f"""
Your Integration ID: `{integration_id}`
@@ -407,14 +407,14 @@ def bot_api_example_generator(integration_id: str):
)
/ "docs"
)
- st.markdown(
+ gui.markdown(
f"""
Read our complete API for features like conversation history, input media files, and more.
""",
unsafe_allow_html=True,
)
- st.js(
+ gui.js(
"""
document.startStreaming = async function() {
document.getElementById('stream-output').style.display = 'flex';
@@ -426,7 +426,7 @@ def bot_api_example_generator(integration_id: str):
).strip()
)
- st.html(
+ gui.html(
f"""
diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py
index fea57f5ba..c09102fe4 100644
--- a/daras_ai_v2/asr.py
+++ b/daras_ai_v2/asr.py
@@ -10,7 +10,7 @@
from django.db.models import F
from furl import furl
-import gooey_ui as st
+import gooey_gui as gui
from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri
from daras_ai_v2 import settings
from daras_ai_v2.azure_asr import azure_asr
@@ -297,7 +297,7 @@ def translation_language_selector(
**kwargs,
) -> str | None:
if not model:
- st.session_state[key] = None
+ gui.session_state[key] = None
return
if model == TranslationModels.google:
@@ -308,7 +308,7 @@ def translation_language_selector(
raise ValueError("Unsupported translation model: " + str(model))
options = list(languages.keys())
- return st.selectbox(
+ return gui.selectbox(
label=label,
key=key,
format_func=lang_format_func,
@@ -351,7 +351,7 @@ def google_translate_language_selector(
"""
languages = google_translate_target_languages()
options = list(languages.keys())
- return st.selectbox(
+ return gui.selectbox(
label=label,
key=key,
format_func=lambda k: languages[k] if k else "โโโ",
@@ -408,12 +408,10 @@ def asr_language_selector(
label="##### Spoken Language",
key="language",
):
- import langcodes
-
# don't show language selector for models with forced language
forced_lang = forced_asr_languages.get(selected_model)
if forced_lang:
- st.session_state[key] = forced_lang
+ gui.session_state[key] = forced_lang
return forced_lang
options = list(asr_supported_languages.get(selected_model, []))
@@ -421,18 +419,14 @@ def asr_language_selector(
options.insert(0, None)
# handle non-canonical language codes
- old_val = st.session_state.get(key)
- if old_val and old_val not in options:
- old_val_lang = langcodes.Language.get(old_val).language
- for opt in options:
- try:
- if opt and langcodes.Language.get(opt).language == old_val_lang:
- st.session_state[key] = opt
- break
- except langcodes.LanguageTagError:
- pass
-
- return st.selectbox(
+ old_lang = gui.session_state.get(key)
+ if old_lang:
+ try:
+ gui.session_state[key] = normalised_lang_in_collection(old_lang, options)
+ except UserError:
+ gui.session_state[key] = None
+
+ return gui.selectbox(
label=label,
key=key,
format_func=lang_format_func,
@@ -584,14 +578,27 @@ def run_google_translate(
def normalised_lang_in_collection(target: str, collection: typing.Iterable[str]) -> str:
import langcodes
- for candidate in collection:
- if langcodes.get(candidate).language == langcodes.get(target).language:
- return candidate
-
- raise UserError(
+ ERROR = UserError(
f"Unsupported language: {target!r} | must be one of {set(collection)}"
)
+ if target in collection:
+ return target
+
+ try:
+ target_lan = langcodes.Language.get(target).language
+ except langcodes.LanguageTagError:
+ raise ERROR
+
+ for candidate in collection:
+ try:
+ if candidate and langcodes.Language.get(candidate).language == target_lan:
+ return candidate
+ except langcodes.LanguageTagError:
+ pass
+
+ raise ERROR
+
def _translate_text(
text: str,
@@ -1069,7 +1076,15 @@ def iterate_subtitles(
yield segment_start, segment_end, segment_text
-def format_timestamp(seconds: float, always_include_hours: bool, decimal_marker: str):
+INFINITY_SECONDS = 99 * 3600 + 59 * 60 + 59 # 99:59:59 in seconds
+
+
+def format_timestamp(
+ seconds: float | None, always_include_hours: bool, decimal_marker: str
+):
+ if seconds is None:
+ # treat None as end of time
+ seconds = INFINITY_SECONDS
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 5be61de83..8e3610a28 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -12,6 +12,7 @@
from time import sleep
from types import SimpleNamespace
+import gooey_gui as gui
import sentry_sdk
from django.db.models import Sum
from django.utils import timezone
@@ -25,7 +26,6 @@
)
from starlette.requests import Request
-import gooey_ui as st
from app_users.models import AppUser, AppUserTransaction
from bots.models import (
SavedRun,
@@ -35,6 +35,7 @@
Workflow,
RetentionPolicy,
)
+from daras_ai.image_input import truncate_text_words
from daras_ai.text_format import format_number_with_suffix
from daras_ai_v2 import settings, urls
from daras_ai_v2.api_examples_widget import api_example_generator
@@ -55,9 +56,6 @@
from daras_ai_v2.manage_api_keys_widget import manage_api_keys
from daras_ai_v2.meta_preview_url import meta_preview_url
from daras_ai_v2.prompt_vars import variables_input
-from daras_ai_v2.query_params import (
- gooey_get_query_params,
-)
from daras_ai_v2.query_params_util import (
extract_query_params,
)
@@ -76,13 +74,6 @@
is_functions_enabled,
render_called_functions,
)
-from gooey_ui import (
- realtime_clear_subs,
- RedirectException,
-)
-from gooey_ui.components.modal import Modal
-from gooey_ui.components.pills import pill
-from gooey_ui.pubsub import realtime_pull
from payments.auto_recharge import (
should_attempt_auto_recharge,
run_auto_recharge_gracefully,
@@ -170,7 +161,7 @@ def __init__(
@classmethod
@property
def endpoint(cls) -> str:
- return f"/v2/{cls.slug_versions[0]}/"
+ return f"/v2/{cls.slug_versions[0]}"
@classmethod
def current_app_url(
@@ -182,7 +173,7 @@ def current_app_url(
) -> str:
if query_params is None:
query_params = {}
- example_id, run_id, uid = extract_query_params(gooey_get_query_params())
+ example_id, run_id, uid = extract_query_params(gui.get_query_params())
return cls.app_url(
tab=tab,
example_id=example_id,
@@ -276,7 +267,7 @@ def setup_sentry(self):
def sentry_event_set_request(self, event, hint):
request = event.setdefault("request", {})
request.setdefault("method", "POST")
- request["data"] = st.session_state
+ request["data"] = gui.session_state
if url := request.get("url"):
f = furl(url)
request["url"] = str(
@@ -284,7 +275,7 @@ def sentry_event_set_request(self, event, hint):
)
else:
request["url"] = self.app_url(
- tab=self.tab, query_params=st.get_query_params()
+ tab=self.tab, query_params=gui.get_query_params()
)
return event
@@ -313,36 +304,36 @@ def sentry_event_set_user(self, event, hint):
return event
def refresh_state(self):
- _, run_id, uid = extract_query_params(gooey_get_query_params())
+ _, run_id, uid = extract_query_params(gui.get_query_params())
channel = self.realtime_channel_name(run_id, uid)
- output = realtime_pull([channel])[0]
+ output = gui.realtime_pull([channel])[0]
if output:
- st.session_state.update(output)
+ gui.session_state.update(output)
def render(self):
self.setup_sentry()
- if self.get_run_state(st.session_state) == RecipeRunState.running:
+ if self.get_run_state(gui.session_state) == RecipeRunState.running:
self.refresh_state()
else:
- realtime_clear_subs()
+ gui.realtime_clear_subs()
self._user_disabled_check()
self._check_if_flagged()
- if st.session_state.get("show_report_workflow"):
+ if gui.session_state.get("show_report_workflow"):
self.render_report_form()
return
self._render_header()
- st.newline()
+ gui.newline()
- with st.nav_tabs():
+ with gui.nav_tabs():
for tab in self.get_tabs():
url = self.current_app_url(tab)
- with st.nav_item(url, active=tab == self.tab):
- st.html(tab.title)
- with st.nav_tab_content():
+ with gui.nav_item(url, active=tab == self.tab):
+ gui.html(tab.title)
+ with gui.nav_tab_content():
self.render_selected_tab()
def _render_header(self):
@@ -354,13 +345,13 @@ def _render_header(self):
self, current_run, published_run, tab=self.tab
)
- with st.div(className="d-flex justify-content-between mt-4"):
- with st.div(className="d-lg-flex d-block align-items-center"):
+ with gui.div(className="d-flex justify-content-between mt-4"):
+ with gui.div(className="d-lg-flex d-block align-items-center"):
if not tbreadcrumbs.has_breadcrumbs() and not self.run_user:
self._render_title(tbreadcrumbs.h1_title)
if tbreadcrumbs:
- with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"):
+ with gui.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"):
render_breadcrumbs(
tbreadcrumbs,
is_api_call=(
@@ -376,7 +367,7 @@ def _render_header(self):
if not is_root_example:
self.render_author(author)
- with st.div(className="d-flex align-items-center"):
+ with gui.div(className="d-flex align-items-center"):
can_user_edit_run = self.can_user_edit_run(current_run, published_run)
has_unpublished_changes = (
published_run
@@ -389,8 +380,8 @@ def _render_header(self):
if can_user_edit_run and has_unpublished_changes:
self._render_unpublished_changes_indicator()
- with st.div(className="d-flex align-items-start right-action-icons"):
- st.html(
+ with gui.div(className="d-flex align-items-start right-action-icons"):
+ gui.html(
"""
"""
)
- st.image(user.photo_url, className=class_name)
+ gui.image(user.photo_url, className=class_name)
if user.display_name:
name_style = {"fontSize": text_size} if text_size else {}
- with st.tag("span", style=name_style):
- st.html(html.escape(user.display_name))
+ with gui.tag("span", style=name_style):
+ gui.html(html.escape(user.display_name))
def get_credits_click_url(self):
if self.request.user and self.request.user.is_anonymous:
@@ -1315,9 +1313,9 @@ def get_submit_container_props(self):
)
def render_submit_button(self, key="--submit-1"):
- with st.div(**self.get_submit_container_props()):
- st.write("---")
- col1, col2 = st.columns([2, 1], responsive=False)
+ with gui.div(**self.get_submit_container_props()):
+ gui.write("---")
+ col1, col2 = gui.columns([2, 1], responsive=False)
col2.node.props[
"className"
] += " d-flex justify-content-end align-items-center"
@@ -1325,25 +1323,25 @@ def render_submit_button(self, key="--submit-1"):
with col1:
self.render_run_cost()
with col2:
- submitted = st.button(
+ submitted = gui.button(
"๐ Submit",
key=key,
type="primary",
- # disabled=bool(st.session_state.get(StateKeys.run_status)),
+ # disabled=bool(gui.session_state.get(StateKeys.run_status)),
)
if not submitted:
return False
try:
self.validate_form_v2()
except AssertionError as e:
- st.error(str(e))
+ gui.error(str(e))
return False
else:
return True
def render_run_cost(self):
url = self.get_credits_click_url()
- run_cost = self.get_price_roundoff(st.session_state)
+ run_cost = self.get_price_roundoff(gui.session_state)
ret = f'Run cost = {run_cost} credits'
cost_note = self.get_cost_note()
@@ -1354,18 +1352,18 @@ def render_run_cost(self):
if additional_notes:
ret += f" \n{additional_notes}"
- st.caption(ret, line_clamp=1, unsafe_allow_html=True)
+ gui.caption(ret, line_clamp=1, unsafe_allow_html=True)
def _render_step_row(self):
key = "details-expander"
- with st.expander("**โน๏ธ Details**", key=key):
- if not st.session_state.get(key):
+ with gui.expander("**โน๏ธ Details**", key=key):
+ if not gui.session_state.get(key):
return
- col1, col2 = st.columns([1, 2])
+ col1, col2 = gui.columns([1, 2])
with col1:
self.render_description()
with col2:
- placeholder = st.div()
+ placeholder = gui.div()
render_called_functions(
saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre
)
@@ -1375,33 +1373,33 @@ def _render_step_row(self):
pass
else:
with placeholder:
- st.write("##### ๐ฃ Steps")
+ gui.write("##### ๐ฃ Steps")
render_called_functions(
saved_run=self.get_current_sr(), trigger=FunctionTrigger.post
)
def _render_help(self):
- placeholder = st.div()
+ placeholder = gui.div()
try:
self.render_usage_guide()
except NotImplementedError:
pass
else:
with placeholder:
- st.write(
+ gui.write(
"""
## How to Use This Recipe
"""
)
key = "discord-expander"
- with st.expander(
+ with gui.expander(
f"**๐๐ฝโโ๏ธ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**",
key=key,
):
- if not st.session_state.get(key):
+ if not gui.session_state.get(key):
return
- st.markdown(
+ gui.markdown(
"""
@@ -1458,34 +1456,34 @@ def run_v2(
raise NotImplementedError
def _render_report_button(self):
- example_id, run_id, uid = extract_query_params(gooey_get_query_params())
+ example_id, run_id, uid = extract_query_params(gui.get_query_params())
# only logged in users can report a run (but not examples/default runs)
if not (self.request.user and run_id and uid):
return
- reported = st.button(
+ reported = gui.button(
'
Report', type="tertiary"
)
if not reported:
return
- st.session_state["show_report_workflow"] = reported
- st.experimental_rerun()
+ gui.session_state["show_report_workflow"] = reported
+ gui.rerun()
def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool):
ref = self.run_doc_sr(uid=uid, run_id=run_id)
ref.is_flagged = is_flagged
ref.save(update_fields=["is_flagged"])
- st.session_state["is_flagged"] = is_flagged
+ gui.session_state["is_flagged"] = is_flagged
# Functions in every recipe feels like overkill for now, hide it in settings
functions_in_settings = True
def _render_input_col(self):
self.render_form_v2()
- placeholder = st.div()
+ placeholder = gui.div()
- with st.expander("โ๏ธ Settings"):
+ with gui.expander("โ๏ธ Settings"):
if self.functions_in_settings:
functions_input(self.request.user)
self.render_settings()
@@ -1494,8 +1492,8 @@ def _render_input_col(self):
self.render_variables()
submitted = self.render_submit_button()
- with st.div(style={"textAlign": "right"}):
- st.caption(
+ with gui.div(style={"textAlign": "right"}):
+ gui.caption(
"_By submitting, you agree to Gooey.AI's [terms](https://gooey.ai/terms) & "
"[privacy policy](https://gooey.ai/privacy)._"
)
@@ -1503,7 +1501,7 @@ def _render_input_col(self):
def render_variables(self):
if not self.functions_in_settings:
- st.write("---")
+ gui.write("---")
functions_input(self.request.user)
variables_input(
template_keys=self.template_keys, allow_add=is_functions_enabled()
@@ -1522,30 +1520,30 @@ def get_run_state(cls, state: dict[str, typing.Any]) -> RecipeRunState:
return RecipeRunState.starting
def render_deleted_output(self):
- col1, *_ = st.columns(2)
+ col1, *_ = gui.columns(2)
with col1:
- st.error(
+ gui.error(
"This data has been deleted as per the retention policy.",
icon="๐๏ธ",
color="rgba(255, 200, 100, 0.5)",
)
- st.newline()
+ gui.newline()
self._render_output_col(is_deleted=True)
- st.newline()
+ gui.newline()
self.render_run_cost()
def _render_output_col(self, *, submitted: bool = False, is_deleted: bool = False):
assert inspect.isgeneratorfunction(self.run)
- if st.session_state.get(StateKeys.pressed_randomize):
- st.session_state["seed"] = int(gooey_rng.randrange(MAX_SEED))
- st.session_state.pop(StateKeys.pressed_randomize, None)
+ if gui.session_state.get(StateKeys.pressed_randomize):
+ gui.session_state["seed"] = int(gooey_rng.randrange(MAX_SEED))
+ gui.session_state.pop(StateKeys.pressed_randomize, None)
submitted = True
if submitted or self.should_submit_after_login():
self.on_submit()
- run_state = self.get_run_state(st.session_state)
+ run_state = self.get_run_state(gui.session_state)
match run_state:
case RecipeRunState.completed:
self._render_completed_output()
@@ -1569,16 +1567,16 @@ def _render_completed_output(self):
pass
def _render_failed_output(self):
- err_msg = st.session_state.get(StateKeys.error_msg)
- st.error(err_msg, unsafe_allow_html=True)
+ err_msg = gui.session_state.get(StateKeys.error_msg)
+ gui.error(err_msg, unsafe_allow_html=True)
def _render_running_output(self):
- run_status = st.session_state.get(StateKeys.run_status)
+ run_status = gui.session_state.get(StateKeys.run_status)
html_spinner(run_status)
self.render_extra_waiting_output()
def render_extra_waiting_output(self):
- created_at = st.session_state.get("created_at")
+ created_at = gui.session_state.get("created_at")
if not created_at:
return
@@ -1589,15 +1587,15 @@ def render_extra_waiting_output(self):
estimated_run_time = self.estimate_run_duration()
if not estimated_run_time:
return
- with st.countdown_timer(
+ with gui.countdown_timer(
end_time=created_at + datetime.timedelta(seconds=estimated_run_time),
delay_text="Sorry for the wait. Your run is taking longer than we expected.",
):
if self.is_current_user_owner() and self.request.user.email:
- st.write(
+ gui.write(
f"""We'll email **{self.request.user.email}** when your workflow is done."""
)
- st.write(
+ gui.write(
f"""In the meantime, check out [๐ Examples]({self.current_app_url(RecipeTabs.examples)})
for inspiration."""
)
@@ -1609,21 +1607,21 @@ def on_submit(self):
try:
sr = self.create_new_run(enable_rate_limits=True)
except ValidationError as e:
- st.session_state[StateKeys.run_status] = None
- st.session_state[StateKeys.error_msg] = str(e)
+ gui.session_state[StateKeys.run_status] = None
+ gui.session_state[StateKeys.error_msg] = str(e)
return
except RateLimitExceeded as e:
- st.session_state[StateKeys.run_status] = None
- st.session_state[StateKeys.error_msg] = e.detail.get("error", "")
+ gui.session_state[StateKeys.run_status] = None
+ gui.session_state[StateKeys.error_msg] = e.detail.get("error", "")
return
self.call_runner_task(sr)
- raise RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))
+ raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))
def should_submit_after_login(self) -> bool:
return (
- st.get_query_params().get(SUBMIT_AFTER_LOGIN_Q)
+ gui.get_query_params().get(SUBMIT_AFTER_LOGIN_Q)
and self.request
and self.request.user
and not self.request.user.is_anonymous
@@ -1632,9 +1630,9 @@ def should_submit_after_login(self) -> bool:
def create_new_run(
self, *, enable_rate_limits: bool = False, **defaults
) -> SavedRun:
- st.session_state[StateKeys.run_status] = "Starting..."
- st.session_state.pop(StateKeys.error_msg, None)
- st.session_state.pop(StateKeys.run_time, None)
+ gui.session_state[StateKeys.run_status] = "Starting..."
+ gui.session_state.pop(StateKeys.error_msg, None)
+ gui.session_state.pop(StateKeys.run_time, None)
self._setup_rng_seed()
self.clear_outputs()
@@ -1654,7 +1652,7 @@ def create_new_run(
run_id = get_random_doc_id()
parent_example_id, parent_run_id, parent_uid = extract_query_params(
- gooey_get_query_params()
+ gui.get_query_params()
)
parent = self.get_sr_from_query_params(
parent_example_id, parent_run_id, parent_uid
@@ -1673,8 +1671,8 @@ def create_new_run(
)
# ensure the request is validated
- state = st.session_state | json.loads(
- self.RequestModel.parse_obj(st.session_state).json(exclude_unset=True)
+ state = gui.session_state | json.loads(
+ self.RequestModel.parse_obj(gui.session_state).json(exclude_unset=True)
)
self.dump_state_to_sr(state, sr)
@@ -1699,6 +1697,7 @@ def call_runner_task(self, sr: SavedRun):
run_id=sr.run_id,
uid=sr.uid,
channel=self.realtime_channel_name(sr.run_id, sr.uid),
+ unsaved_state=self._unsaved_state(),
)
| post_runner_tasks.s()
)
@@ -1741,28 +1740,28 @@ def generate_credit_error_message(self, run_id, uid) -> str:
return error_msg
def _setup_rng_seed(self):
- seed = st.session_state.get("seed")
+ seed = gui.session_state.get("seed")
if not seed:
return
gooey_rng.seed(seed)
def clear_outputs(self):
# clear error msg
- st.session_state.pop(StateKeys.error_msg, None)
+ gui.session_state.pop(StateKeys.error_msg, None)
# clear outputs
for field_name in self.ResponseModel.__fields__:
- st.session_state.pop(field_name, None)
+ gui.session_state.pop(field_name, None)
def _render_after_output(self):
self._render_report_button()
if "seed" in self.RequestModel.schema_json():
- randomize = st.button(
+ randomize = gui.button(
'
Regenerate', type="tertiary"
)
if randomize:
- st.session_state[StateKeys.pressed_randomize] = True
- st.experimental_rerun()
+ gui.session_state[StateKeys.pressed_randomize] = True
+ gui.rerun()
@classmethod
def load_state_from_sr(cls, sr: SavedRun) -> dict:
@@ -1782,7 +1781,7 @@ def load_state_defaults(cls, state: dict):
state.setdefault(k, v)
return state
- def fields_to_save(self) -> [str]:
+ def fields_to_save(self) -> list[str]:
# only save the fields in request/response
return [
field_name
@@ -1794,6 +1793,18 @@ def fields_to_save(self) -> [str]:
StateKeys.run_time,
]
+ def _unsaved_state(self) -> dict[str, typing.Any]:
+ result = {}
+ for field in self.fields_not_to_save():
+ try:
+ result[field] = gui.session_state[field]
+ except KeyError:
+ pass
+ return result
+
+ def fields_not_to_save(self) -> list[str]:
+ return []
+
def _examples_tab(self):
allow_hide = self.is_current_user_admin()
@@ -1823,12 +1834,12 @@ def _saved_tab(self):
created_by=self.request.user,
)[:50]
if not published_runs:
- st.write("No published runs yet")
+ gui.write("No published runs yet")
return
def _render(pr: PublishedRun):
- with st.div(className="mb-2", style={"font-size": "0.9rem"}):
- pill(
+ with gui.div(className="mb-2", style={"font-size": "0.9rem"}):
+ gui.pill(
PublishedRunVisibility(pr.visibility).get_badge_html(),
unsafe_allow_html=True,
className="border border-dark",
@@ -1845,7 +1856,7 @@ def _history_tab(self):
if self.is_current_user_admin():
uid = self.request.query_params.get("uid", uid)
- before = gooey_get_query_params().get("updated_at__lt", None)
+ before = gui.get_query_params().get("updated_at__lt", None)
if before:
before = datetime.datetime.fromisoformat(before)
else:
@@ -1858,7 +1869,7 @@ def _history_tab(self):
)[:25]
)
if not run_history:
- st.write("No history yet")
+ gui.write("No history yet")
return
grid_layout(3, run_history, self._render_run_preview)
@@ -1867,15 +1878,15 @@ def _history_tab(self):
RecipeTabs.history,
query_params={"updated_at__lt": run_history[-1].to_dict()["updated_at"]},
)
- with st.link(to=str(next_url)):
- st.html(
+ with gui.link(to=str(next_url)):
+ gui.html(
# language=HTML
f"""
"""
)
def ensure_authentication(self, next_url: str | None = None, anon_ok: bool = False):
if not self.request.user or (self.request.user.is_anonymous and not anon_ok):
- raise RedirectException(self.get_auth_url(next_url))
+ raise gui.RedirectException(self.get_auth_url(next_url))
def get_auth_url(self, next_url: str | None = None) -> str:
from routers.root import login
@@ -1890,10 +1901,10 @@ def _render_run_preview(self, saved_run: SavedRun):
is_latest_version = published_run and published_run.saved_run == saved_run
tb = get_title_breadcrumbs(self, sr=saved_run, pr=published_run)
- with st.link(to=saved_run.get_app_url()):
- with st.div(className="mb-1", style={"fontSize": "0.9rem"}):
+ with gui.link(to=saved_run.get_app_url()):
+ with gui.div(className="mb-1", style={"fontSize": "0.9rem"}):
if is_latest_version:
- pill(
+ gui.pill(
PublishedRunVisibility(
published_run.visibility
).get_badge_html(),
@@ -1901,7 +1912,7 @@ def _render_run_preview(self, saved_run: SavedRun):
className="border border-dark",
)
- st.write(f"#### {tb.h1_title}")
+ gui.write(f"#### {tb.h1_title}")
updated_at = saved_run.updated_at
if (
@@ -1909,34 +1920,34 @@ def _render_run_preview(self, saved_run: SavedRun):
and isinstance(updated_at, datetime.datetime)
and not saved_run.run_status
):
- st.caption("Loading...", **render_local_dt_attrs(updated_at))
+ gui.caption("Loading...", **render_local_dt_attrs(updated_at))
if saved_run.run_status:
started_at_text(saved_run.created_at)
html_spinner(saved_run.run_status, scroll_into_view=False)
elif saved_run.error_msg:
- st.error(saved_run.error_msg, unsafe_allow_html=True)
+ gui.error(saved_run.error_msg, unsafe_allow_html=True)
return self.render_example(saved_run.to_dict())
def render_published_run_preview(self, published_run: PublishedRun):
tb = get_title_breadcrumbs(self, published_run.saved_run, published_run)
- with st.link(to=published_run.get_app_url()):
- st.write(f"#### {tb.h1_title}")
+ with gui.link(to=published_run.get_app_url()):
+ gui.write(f"#### {tb.h1_title}")
- with st.div(className="d-flex align-items-center justify-content-between"):
- with st.div():
+ with gui.div(className="d-flex align-items-center justify-content-between"):
+ with gui.div():
updated_at = published_run.saved_run.updated_at
if updated_at and isinstance(updated_at, datetime.datetime):
- st.caption("Loading...", **render_local_dt_attrs(updated_at))
+ gui.caption("Loading...", **render_local_dt_attrs(updated_at))
if published_run.visibility == PublishedRunVisibility.PUBLIC:
run_icon = '
'
run_count = format_number_with_suffix(published_run.get_run_count())
- st.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True)
+ gui.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True)
if published_run.notes:
- st.caption(published_run.notes, line_clamp=2)
+ gui.caption(published_run.notes, line_clamp=2)
doc = published_run.saved_run.to_dict()
self.render_example(doc)
@@ -1950,28 +1961,28 @@ def _render_example_preview(
tb = get_title_breadcrumbs(self, published_run.saved_run, published_run)
if published_run.created_by:
- with st.div(className="mb-1 text-truncate", style={"height": "1.5rem"}):
+ with gui.div(className="mb-1 text-truncate", style={"height": "1.5rem"}):
self.render_author(
published_run.created_by,
image_size="20px",
text_size="0.9rem",
)
- with st.link(to=published_run.get_app_url()):
- st.write(f"#### {tb.h1_title}")
+ with gui.link(to=published_run.get_app_url()):
+ gui.write(f"#### {tb.h1_title}")
- with st.div(className="d-flex align-items-center justify-content-between"):
- with st.div():
+ with gui.div(className="d-flex align-items-center justify-content-between"):
+ with gui.div():
updated_at = published_run.saved_run.updated_at
if updated_at and isinstance(updated_at, datetime.datetime):
- st.caption("Loading...", **render_local_dt_attrs(updated_at))
+ gui.caption("Loading...", **render_local_dt_attrs(updated_at))
run_icon = '
'
run_count = format_number_with_suffix(published_run.get_run_count())
- st.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True)
+ gui.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True)
if published_run.notes:
- st.caption(published_run.notes, line_clamp=2)
+ gui.caption(published_run.notes, line_clamp=2)
if allow_hide:
self._example_hide_button(published_run=published_run)
@@ -1980,7 +1991,7 @@ def _render_example_preview(
self.render_example(doc)
def _example_hide_button(self, published_run: PublishedRun):
- pressed_delete = st.button(
+ pressed_delete = gui.button(
"๐๏ธ Hide",
key=f"delete_example_{published_run.published_run_id}",
style={"color": "red"},
@@ -1990,11 +2001,11 @@ def _example_hide_button(self, published_run: PublishedRun):
self.set_hidden(published_run=published_run, hidden=True)
def set_hidden(self, *, published_run: PublishedRun, hidden: bool):
- with st.spinner("Hiding..."):
+ with gui.spinner("Hiding..."):
published_run.is_approved_example = not hidden
published_run.save()
- st.experimental_rerun()
+ gui.rerun()
def render_example(self, state: dict):
pass
@@ -2036,45 +2047,45 @@ def run_as_api_tab(self):
/ "docs"
)
- st.markdown(
+ gui.markdown(
f'๐ To learn more, take a look at our
complete API',
unsafe_allow_html=True,
)
- st.write("#### ๐ค Example Request")
+ gui.write("#### ๐ค Example Request")
- include_all = st.checkbox("##### Show all fields")
- as_async = st.checkbox("##### Run Async")
- as_form_data = st.checkbox("##### Upload Files via Form Data")
+ include_all = gui.checkbox("##### Show all fields")
+ as_async = gui.checkbox("##### Run Async")
+ as_form_data = gui.checkbox("##### Upload Files via Form Data")
pr = self.get_current_published_run()
api_url, request_body = self.get_example_request(
- st.session_state,
+ gui.session_state,
include_all=include_all,
pr=pr,
)
response_body = self.get_example_response_body(
- st.session_state, as_async=as_async, include_all=include_all
+ gui.session_state, as_async=as_async, include_all=include_all
)
api_example_generator(
api_url=api_url,
- request_body=request_body,
+ request_body=request_body | self._unsaved_state(),
as_form_data=as_form_data,
as_async=as_async,
)
- st.write("")
+ gui.write("")
- st.write("#### ๐ Example Response")
- st.json(response_body, expanded=True)
+ gui.write("#### ๐ Example Response")
+ gui.json(response_body, expanded=True)
if not self.request.user or self.request.user.is_anonymous:
- st.write("**Please Login to generate the `$GOOEY_API_KEY`**")
+ gui.write("**Please Login to generate the `$GOOEY_API_KEY`**")
return
- st.write("---")
- with st.tag("a", id="api-keys"):
- st.write("### ๐ API keys")
+ gui.write("---")
+ with gui.tag("a", id="api-keys"):
+ gui.write("### ๐ API keys")
manage_api_keys(self.request.user)
@@ -2168,7 +2179,7 @@ def get_example_response_body(
include_all: bool = False,
) -> dict:
run_id = get_random_doc_id()
- created_at = st.session_state.get(
+ created_at = gui.session_state.get(
StateKeys.created_at, datetime.datetime.utcnow().isoformat()
)
web_url = self.app_url(
@@ -2182,7 +2193,7 @@ def get_example_response_body(
run_id=run_id,
web_url=web_url,
created_at=created_at,
- run_time_sec=st.session_state.get(StateKeys.run_time, 0),
+ run_time_sec=gui.session_state.get(StateKeys.run_time, 0),
status="completed",
output=output,
)
@@ -2220,12 +2231,12 @@ def is_current_user_owner(self) -> bool:
def started_at_text(dt: datetime.datetime):
- with st.div(className="d-flex"):
+ with gui.div(className="d-flex"):
text = "Started"
- if seed := st.session_state.get("seed"):
+ if seed := gui.session_state.get("seed"):
text += f' with seed
{seed}'
- st.caption(text + " on ", unsafe_allow_html=True)
- st.caption(
+ gui.caption(text + " on ", unsafe_allow_html=True)
+ gui.caption(
"...",
className="text-black",
**render_local_dt_attrs(dt),
@@ -2235,23 +2246,23 @@ def started_at_text(dt: datetime.datetime):
def render_output_caption():
caption = ""
- run_time = st.session_state.get(StateKeys.run_time, 0)
+ run_time = gui.session_state.get(StateKeys.run_time, 0)
if run_time:
caption += f'Generated in
{run_time :.1f}s'
- if seed := st.session_state.get("seed"):
+ if seed := gui.session_state.get("seed"):
caption += f' with seed
{seed} '
- updated_at = st.session_state.get(StateKeys.updated_at, datetime.datetime.today())
+ updated_at = gui.session_state.get(StateKeys.updated_at, datetime.datetime.today())
if updated_at:
if isinstance(updated_at, str):
updated_at = datetime.datetime.fromisoformat(updated_at)
caption += " on "
- with st.div(className="d-flex"):
- st.caption(caption, unsafe_allow_html=True)
+ with gui.div(className="d-flex"):
+ gui.caption(caption, unsafe_allow_html=True)
if updated_at:
- st.caption(
+ gui.caption(
"...",
className="text-black",
**render_local_dt_attrs(
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index 0e16d25db..09b48d375 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -1,18 +1,15 @@
from typing import Literal
+import gooey_gui as gui
import stripe
from django.core.exceptions import ValidationError
-import gooey_ui as st
from app_users.models import AppUser, PaymentProvider
from daras_ai_v2 import icons, settings, paypal
-from daras_ai_v2.base import RedirectException
from daras_ai_v2.fastapi_tricks import get_app_route_url
from daras_ai_v2.grid_layout_widget import grid_layout
from daras_ai_v2.settings import templates
from daras_ai_v2.user_date_widgets import render_local_date_attrs
-from gooey_ui.components.modal import Modal
-from gooey_ui.components.pills import pill
from payments.models import PaymentMethodSummary
from payments.plans import PricingPlan
from scripts.migrate_existing_subscriptions import available_subscriptions
@@ -29,30 +26,30 @@ def billing_page(user: AppUser):
if user.subscription:
render_current_plan(user)
- with st.div(className="my-5"):
+ with gui.div(className="my-5"):
render_credit_balance(user)
- with st.div(className="my-5"):
+ with gui.div(className="my-5"):
selected_payment_provider = render_all_plans(user)
- with st.div(className="my-5"):
+ with gui.div(className="my-5"):
render_addon_section(user, selected_payment_provider)
if user.subscription and user.subscription.payment_provider:
if user.subscription.payment_provider == PaymentProvider.STRIPE:
- with st.div(className="my-5"):
+ with gui.div(className="my-5"):
render_auto_recharge_section(user)
- with st.div(className="my-5"):
+ with gui.div(className="my-5"):
render_payment_information(user)
- with st.div(className="my-5"):
+ with gui.div(className="my-5"):
render_billing_history(user)
def render_payments_setup():
from routers.account import payment_processing_route
- st.html(
+ gui.html(
templates.get_template("payment_setup.html").render(
settings=settings,
payment_processing_url=get_app_route_url(payment_processing_route),
@@ -68,25 +65,25 @@ def render_current_plan(user: AppUser):
else None
)
- with st.div(className=f"{rounded_border} border-dark"):
+ with gui.div(className=f"{rounded_border} border-dark"):
# ROW 1: Plan title and next invoice date
left, right = left_and_right()
with left:
- st.write(f"#### Gooey.AI {plan.title}")
+ gui.write(f"#### Gooey.AI {plan.title}")
if provider:
- st.write(
+ gui.write(
f"[{icons.edit} Manage Subscription](#payment-information)",
unsafe_allow_html=True,
)
- with right, st.div(className="d-flex align-items-center gap-1"):
+ with right, gui.div(className="d-flex align-items-center gap-1"):
if provider and (
- next_invoice_ts := st.run_in_thread(
+ next_invoice_ts := gui.run_in_thread(
user.subscription.get_next_invoice_timestamp, cache=True
)
):
- st.html("Next invoice on ")
- pill(
+ gui.html("Next invoice on ")
+ gui.pill(
"...",
text_bg="dark",
**render_local_date_attrs(
@@ -102,22 +99,22 @@ def render_current_plan(user: AppUser):
# ROW 2: Plan pricing details
left, right = left_and_right(className="mt-5")
with left:
- st.write(f"# {plan.pricing_title()}", className="no-margin")
+ gui.write(f"# {plan.pricing_title()}", className="no-margin")
if plan.monthly_charge:
provider_text = f" **via {provider.label}**" if provider else ""
- st.caption("per month" + provider_text)
+ gui.caption("per month" + provider_text)
- with right, st.div(className="text-end"):
- st.write(f"# {plan.credits:,} credits", className="no-margin")
+ with right, gui.div(className="text-end"):
+ gui.write(f"# {plan.credits:,} credits", className="no-margin")
if plan.monthly_charge:
- st.write(
+ gui.write(
f"**${plan.monthly_charge:,}** monthly renewal for {plan.credits:,} credits"
)
def render_credit_balance(user: AppUser):
- st.write(f"## Credit Balance: {user.balance:,}")
- st.caption(
+ gui.write(f"## Credit Balance: {user.balance:,}")
+ gui.caption(
"Every time you submit a workflow or make an API call, we deduct credits from your account."
)
@@ -130,13 +127,13 @@ def render_all_plans(user: AppUser) -> PaymentProvider:
)
all_plans = [plan for plan in PricingPlan if not plan.deprecated]
- st.write("## All Plans")
- plans_div = st.div(className="mb-1")
+ gui.write("## All Plans")
+ plans_div = gui.div(className="mb-1")
if user.subscription and user.subscription.payment_provider:
selected_payment_provider = None
else:
- with st.div():
+ with gui.div():
selected_payment_provider = PaymentProvider[
payment_provider_radio() or PaymentProvider.STRIPE.name
]
@@ -146,8 +143,8 @@ def _render_plan(plan: PricingPlan):
extra_class = "border-dark"
else:
extra_class = "bg-light"
- with st.div(className="d-flex flex-column h-100"):
- with st.div(
+ with gui.div(className="d-flex flex-column h-100"):
+ with gui.div(
className=f"{rounded_border} flex-grow-1 d-flex flex-column p-3 mb-2 {extra_class}"
):
_render_plan_details(plan)
@@ -158,30 +155,32 @@ def _render_plan(plan: PricingPlan):
with plans_div:
grid_layout(4, all_plans, _render_plan, separator=False)
- with st.div(className="my-2 d-flex justify-content-center"):
- st.caption(f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**")
+ with gui.div(className="my-2 d-flex justify-content-center"):
+ gui.caption(
+ f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**"
+ )
return selected_payment_provider
def _render_plan_details(plan: PricingPlan):
- with st.div(className="flex-grow-1"):
- with st.div(className="mb-4"):
- with st.tag("h4", className="mb-0"):
- st.html(plan.title)
- st.caption(
+ with gui.div(className="flex-grow-1"):
+ with gui.div(className="mb-4"):
+ with gui.tag("h4", className="mb-0"):
+ gui.html(plan.title)
+ gui.caption(
plan.description,
style={
"minHeight": "calc(var(--bs-body-line-height) * 2em)",
"display": "block",
},
)
- with st.div(className="my-3 w-100"):
- with st.tag("h4", className="my-0 d-inline me-2"):
- st.html(plan.pricing_title())
- with st.tag("span", className="text-muted my-0"):
- st.html(plan.pricing_caption())
- st.write(plan.long_description, unsafe_allow_html=True)
+ with gui.div(className="my-3 w-100"):
+ with gui.tag("h4", className="my-0 d-inline me-2"):
+ gui.html(plan.pricing_title())
+ with gui.tag("span", className="text-muted my-0"):
+ gui.html(plan.pricing_caption())
+ gui.write(plan.long_description, unsafe_allow_html=True)
def _render_plan_action_button(
@@ -192,13 +191,13 @@ def _render_plan_action_button(
):
btn_classes = "w-100 mt-3"
if plan == current_plan:
- st.button("Your Plan", className=btn_classes, disabled=True, type="tertiary")
+ gui.button("Your Plan", className=btn_classes, disabled=True, type="tertiary")
elif plan.contact_us_link:
- with st.link(
+ with gui.link(
to=plan.contact_us_link,
className=btn_classes + " btn btn-theme btn-primary",
):
- st.html("Contact Us")
+ gui.html("Contact Us")
elif user.subscription and not user.subscription.payment_provider:
# don't show upgrade/downgrade buttons for enterprise customers
# assumption: anyone without a payment provider attached is admin/enterprise
@@ -262,11 +261,11 @@ def _render_update_subscription_button(
key = f"change-sub-{plan.key}"
match label:
case "Downgrade":
- downgrade_modal = Modal(
+ downgrade_modal = gui.Modal(
"Confirm downgrade",
key=f"downgrade-plan-modal-{plan.key}",
)
- if st.button(
+ if gui.button(
label,
className=className,
key=key,
@@ -275,28 +274,28 @@ def _render_update_subscription_button(
if downgrade_modal.is_open():
with downgrade_modal.container():
- st.write(
+ gui.write(
f"""
Are you sure you want to change from:
**{current_plan.title} ({fmt_price(current_plan)})** to **{plan.title} ({fmt_price(plan)})**?
""",
className="d-block py-4",
)
- with st.div(className="d-flex w-100"):
- if st.button(
+ with gui.div(className="d-flex w-100"):
+ if gui.button(
"Downgrade",
className="btn btn-theme bg-danger border-danger text-white",
key=f"{key}-confirm",
):
change_subscription(user, plan)
- if st.button(
+ if gui.button(
"Cancel",
className="border border-danger text-danger",
key=f"{key}-cancel",
):
downgrade_modal.close()
case _:
- if st.button(label, className=className, key=key):
+ if gui.button(label, className=className, key=key):
change_subscription(
user,
plan,
@@ -319,19 +318,19 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs):
current_plan = PricingPlan.from_sub(user.subscription)
if new_plan == current_plan:
- raise RedirectException(get_app_route_url(account_route), status_code=303)
+ raise gui.RedirectException(get_app_route_url(account_route), status_code=303)
if new_plan == PricingPlan.STARTER:
user.subscription.cancel()
user.subscription.delete()
- raise RedirectException(
+ raise gui.RedirectException(
get_app_route_url(payment_processing_route), status_code=303
)
match user.subscription.payment_provider:
case PaymentProvider.STRIPE:
if not new_plan.supports_stripe():
- st.error(f"Stripe subscription not available for {new_plan}")
+ gui.error(f"Stripe subscription not available for {new_plan}")
subscription = stripe.Subscription.retrieve(user.subscription.external_id)
stripe.Subscription.modify(
@@ -346,13 +345,13 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs):
**kwargs,
proration_behavior="none",
)
- raise RedirectException(
+ raise gui.RedirectException(
get_app_route_url(payment_processing_route), status_code=303
)
case PaymentProvider.PAYPAL:
if not new_plan.supports_paypal():
- st.error(f"Paypal subscription not available for {new_plan}")
+ gui.error(f"Paypal subscription not available for {new_plan}")
subscription = paypal.Subscription.retrieve(user.subscription.external_id)
paypal_plan_info = new_plan.get_paypal_plan()
@@ -360,16 +359,16 @@ def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs):
plan_id=paypal_plan_info["plan_id"],
plan=paypal_plan_info["plan"],
)
- raise RedirectException(approval_url, status_code=303)
+ raise gui.RedirectException(approval_url, status_code=303)
case _:
- st.error("Not implemented for this payment provider")
+ gui.error("Not implemented for this payment provider")
def payment_provider_radio(**props) -> str | None:
- with st.div(className="d-flex"):
- st.write("###### Pay Via", className="d-block me-3")
- return st.radio(
+ with gui.div(className="d-flex"):
+ gui.write("###### Pay Via", className="d-block me-3")
+ return gui.radio(
"",
options=PaymentProvider.names,
format_func=lambda name: f'
{PaymentProvider[name].label}',
@@ -379,10 +378,10 @@ def payment_provider_radio(**props) -> str | None:
def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvider):
if user.subscription:
- st.write("# Purchase More Credits")
+ gui.write("# Purchase More Credits")
else:
- st.write("# Purchase Credits")
- st.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits")
+ gui.write("# Purchase Credits")
+ gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits")
if user.subscription:
provider = PaymentProvider(user.subscription.payment_provider)
@@ -396,22 +395,22 @@ def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvid
def render_paypal_addon_buttons():
- selected_amt = st.horizontal_radio(
+ selected_amt = gui.horizontal_radio(
"",
settings.ADDON_AMOUNT_CHOICES,
format_func=lambda amt: f"${amt:,}",
checked_by_default=False,
)
if selected_amt:
- st.js(
+ gui.js(
f"setPaypalAddonQuantity({int(selected_amt) * settings.ADDON_CREDITS_PER_DOLLAR})"
)
- st.div(
+ gui.div(
id="paypal-addon-buttons",
className="mt-2",
style={"width": "fit-content"},
)
- st.div(id="paypal-result-message")
+ gui.div(id="paypal-result-message")
def render_stripe_addon_buttons(user: AppUser):
@@ -420,10 +419,10 @@ def render_stripe_addon_buttons(user: AppUser):
def render_stripe_addon_button(dollat_amt: int, user: AppUser):
- confirm_purchase_modal = Modal(
+ confirm_purchase_modal = gui.Modal(
"Confirm Purchase", key=f"confirm-purchase-{dollat_amt}"
)
- if st.button(f"${dollat_amt:,}", type="primary"):
+ if gui.button(f"${dollat_amt:,}", type="primary"):
if user.subscription:
confirm_purchase_modal.open()
else:
@@ -432,32 +431,32 @@ def render_stripe_addon_button(dollat_amt: int, user: AppUser):
if not confirm_purchase_modal.is_open():
return
with confirm_purchase_modal.container():
- st.write(
+ gui.write(
f"""
Please confirm your purchase:
**{dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollat_amt}**.
""",
className="py-4 d-block text-center",
)
- with st.div(className="d-flex w-100 justify-content-end"):
- if st.session_state.get("--confirm-purchase"):
- success = st.run_in_thread(
+ with gui.div(className="d-flex w-100 justify-content-end"):
+ if gui.session_state.get("--confirm-purchase"):
+ success = gui.run_in_thread(
user.subscription.stripe_attempt_addon_purchase,
args=[dollat_amt],
placeholder="Processing payment...",
)
if success is None:
return
- st.session_state.pop("--confirm-purchase")
+ gui.session_state.pop("--confirm-purchase")
if success:
confirm_purchase_modal.close()
else:
- st.error("Payment failed... Please try again.")
+ gui.error("Payment failed... Please try again.")
return
- if st.button("Cancel", className="border border-danger text-danger me-2"):
+ if gui.button("Cancel", className="border border-danger text-danger me-2"):
confirm_purchase_modal.close()
- st.button("Buy", type="primary", key="--confirm-purchase")
+ gui.button("Buy", type="primary", key="--confirm-purchase")
def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int):
@@ -478,7 +477,7 @@ def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int):
"payment_method_save": "enabled",
},
)
- raise RedirectException(checkout_session.url, status_code=303)
+ raise gui.RedirectException(checkout_session.url, status_code=303)
def render_stripe_subscription_button(
@@ -490,13 +489,13 @@ def render_stripe_subscription_button(
key: str,
):
if not plan.supports_stripe():
- st.write("Stripe subscription not available")
+ gui.write("Stripe subscription not available")
return
# IMPORTANT: key=... is needed here to maintain uniqueness
# of buttons with the same label. otherwise, all buttons
# will be the same to the server
- if st.button(label, key=key, type=btn_type):
+ if gui.button(label, key=key, type=btn_type):
stripe_subscription_checkout_redirect(user=user, plan=plan)
@@ -522,7 +521,7 @@ def stripe_subscription_checkout_redirect(user: AppUser, plan: PricingPlan):
"payment_method_save": "enabled",
},
)
- raise RedirectException(checkout_session.url, status_code=303)
+ raise gui.RedirectException(checkout_session.url, status_code=303)
def render_paypal_subscription_button(
@@ -530,11 +529,11 @@ def render_paypal_subscription_button(
plan: PricingPlan,
):
if not plan.supports_paypal():
- st.write("Paypal subscription not available")
+ gui.write("Paypal subscription not available")
return
lookup_key = plan.key
- st.html(
+ gui.html(
f"""
str:
@@ -617,8 +616,8 @@ def render_billing_history(user: AppUser, limit: int = 50):
if not txns:
return
- st.write("## Billing History", className="d-block")
- st.table(
+ gui.write("## Billing History", className="d-block")
+ gui.table(
pd.DataFrame.from_records(
[
{
@@ -633,7 +632,7 @@ def render_billing_history(user: AppUser, limit: int = 50):
),
)
if txns.count() > limit:
- st.caption(f"Showing only the most recent {limit} transactions.")
+ gui.caption(f"Showing only the most recent {limit} transactions.")
def render_auto_recharge_section(user: AppUser):
@@ -643,9 +642,9 @@ def render_auto_recharge_section(user: AppUser):
)
subscription = user.subscription
- st.write("## Auto Recharge & Limits")
- with st.div(className="h4"):
- auto_recharge_enabled = st.checkbox(
+ gui.write("## Auto Recharge & Limits")
+ with gui.div(className="h4"):
+ auto_recharge_enabled = gui.checkbox(
"Enable auto recharge",
value=subscription.auto_recharge_enabled,
)
@@ -656,20 +655,20 @@ def render_auto_recharge_section(user: AppUser):
subscription.save(update_fields=["auto_recharge_enabled"])
if not auto_recharge_enabled:
- st.caption(
+ gui.caption(
"Enable auto recharge to automatically keep your credit balance topped up."
)
return
- col1, col2 = st.columns(2)
- with col1, st.div(className="mb-2"):
- subscription.auto_recharge_topup_amount = st.selectbox(
+ col1, col2 = gui.columns(2)
+ with col1, gui.div(className="mb-2"):
+ subscription.auto_recharge_topup_amount = gui.selectbox(
"###### Automatically purchase",
options=settings.ADDON_AMOUNT_CHOICES,
format_func=lambda amt: f"{settings.ADDON_CREDITS_PER_DOLLAR * int(amt):,} credits for ${amt}",
value=subscription.auto_recharge_topup_amount,
)
- subscription.auto_recharge_balance_threshold = st.selectbox(
+ subscription.auto_recharge_balance_threshold = gui.selectbox(
"###### when balance falls below",
options=settings.AUTO_RECHARGE_BALANCE_THRESHOLD_CHOICES,
format_func=lambda c: f"{c:,} credits",
@@ -677,49 +676,51 @@ def render_auto_recharge_section(user: AppUser):
)
with col2:
- st.write("###### Monthly Recharge Budget")
- st.caption(
+ gui.write("###### Monthly Recharge Budget")
+ gui.caption(
"""
If your account exceeds this budget in a given calendar month,
subsequent runs & API requests will be rejected.
""",
)
- with st.div(className="d-flex align-items-center"):
- user.subscription.monthly_spending_budget = st.number_input(
+ with gui.div(className="d-flex align-items-center"):
+ user.subscription.monthly_spending_budget = gui.number_input(
"",
min_value=10,
value=user.subscription.monthly_spending_budget,
key="monthly-spending-budget",
)
- st.write("USD", className="d-block ms-2")
+ gui.write("USD", className="d-block ms-2")
- st.write("###### Email Notification Threshold")
- st.caption(
+ gui.write("###### Email Notification Threshold")
+ gui.caption(
"""
If your account purchases exceed this threshold in a given
calendar month, you will receive an email notification.
"""
)
- with st.div(className="d-flex align-items-center"):
- user.subscription.monthly_spending_notification_threshold = st.number_input(
- "",
- min_value=10,
- value=user.subscription.monthly_spending_notification_threshold,
- key="monthly-spending-notification-threshold",
+ with gui.div(className="d-flex align-items-center"):
+ user.subscription.monthly_spending_notification_threshold = (
+ gui.number_input(
+ "",
+ min_value=10,
+ value=user.subscription.monthly_spending_notification_threshold,
+ key="monthly-spending-notification-threshold",
+ )
)
- st.write("USD", className="d-block ms-2")
+ gui.write("USD", className="d-block ms-2")
- if st.button("Save", type="primary", key="save-auto-recharge-and-limits"):
+ if gui.button("Save", type="primary", key="save-auto-recharge-and-limits"):
try:
subscription.full_clean()
except ValidationError as e:
- st.error(str(e))
+ gui.error(str(e))
else:
subscription.save()
- st.success("Settings saved!")
+ gui.success("Settings saved!")
def left_and_right(*, className: str = "", **props):
className += " d-flex flex-row justify-content-between align-items-center"
- with st.div(className=className, **props):
- return st.div(), st.div()
+ with gui.div(className=className, **props):
+ return gui.div(), gui.div()
diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py
index b6324ce5f..4958f389d 100644
--- a/daras_ai_v2/bot_integration_widgets.py
+++ b/daras_ai_v2/bot_integration_widgets.py
@@ -6,7 +6,7 @@
from django.utils.text import slugify
from furl import furl
-import gooey_ui as st
+import gooey_gui as gui
from app_users.models import AppUser
from bots.models import BotIntegration, BotIntegrationAnalysisRun, Platform
from daras_ai_v2 import settings, icons
@@ -19,52 +19,49 @@
def general_integration_settings(bi: BotIntegration, current_user: AppUser):
- if st.session_state.get(f"_bi_reset_{bi.id}"):
- st.session_state[f"_bi_streaming_enabled_{bi.id}"] = (
+ if gui.session_state.get(f"_bi_reset_{bi.id}"):
+ gui.session_state[f"_bi_streaming_enabled_{bi.id}"] = (
BotIntegration._meta.get_field("streaming_enabled").default
)
- st.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = (
+ gui.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = (
BotIntegration._meta.get_field("show_feedback_buttons").default
)
- st.session_state["analysis_urls"] = []
- st.session_state.pop("--list-view:analysis_urls", None)
-
- bi.streaming_enabled = st.checkbox(
- "**๐ก Streaming Enabled**",
- value=bi.streaming_enabled,
- key=f"_bi_streaming_enabled_{bi.id}",
- )
- st.caption("Responses will be streamed to the user in real-time if enabled.")
- bi.show_feedback_buttons = st.checkbox(
- "**๐๐พ ๐๐ฝ Show Feedback Buttons**",
- value=bi.show_feedback_buttons,
- key=f"_bi_show_feedback_buttons_{bi.id}",
- )
- st.caption(
- "Users can rate and provide feedback on every copilot response if enabled."
- )
-
- st.caption(
- "Please note that this language is distinct from the one provided in the workflow settings. Hence, this allows you to integrate the same bot in many languages."
- )
+ gui.session_state["analysis_urls"] = []
+ gui.session_state.pop("--list-view:analysis_urls", None)
+
+ if bi.platform != Platform.TWILIO:
+ bi.streaming_enabled = gui.checkbox(
+ "**๐ก Streaming Enabled**",
+ value=bi.streaming_enabled,
+ key=f"_bi_streaming_enabled_{bi.id}",
+ )
+ gui.caption("Responses will be streamed to the user in real-time if enabled.")
+ bi.show_feedback_buttons = gui.checkbox(
+ "**๐๐พ ๐๐ฝ Show Feedback Buttons**",
+ value=bi.show_feedback_buttons,
+ key=f"_bi_show_feedback_buttons_{bi.id}",
+ )
+ gui.caption(
+ "Users can rate and provide feedback on every copilot response if enabled."
+ )
- st.write(
+ gui.write(
"""
##### ๐ง Analysis Scripts
Analyze each incoming message and the copilot's response using a Gooey.AI /LLM workflow. Must return a JSON object.
[Learn more](https://gooey.ai/docs/guides/build-your-ai-copilot/conversation-analysis).
"""
)
- if "analysis_urls" not in st.session_state:
- st.session_state["analysis_urls"] = [
+ if "analysis_urls" not in gui.session_state:
+ gui.session_state["analysis_urls"] = [
(anal.published_run or anal.saved_run).get_app_url()
for anal in bi.analysis_runs.all()
]
- if st.session_state.get("analysis_urls"):
+ if gui.session_state.get("analysis_urls"):
from recipes.VideoBots import VideoBotsPage
- st.anchor(
+ gui.anchor(
"๐ View Results",
str(
furl(
@@ -80,7 +77,7 @@ def general_integration_settings(bi: BotIntegration, current_user: AppUser):
input_analysis_runs = []
def render_workflow_url_input(key: str, del_key: str | None, d: dict):
- with st.columns([3, 2])[0]:
+ with gui.columns([3, 2])[0]:
ret = workflow_url_input(
page_cls=CompareLLMPage,
key=key,
@@ -103,10 +100,10 @@ def render_workflow_url_input(key: str, del_key: str | None, d: dict):
flatten_dict_key="url",
)
- with st.center():
- with st.div():
- pressed_update = st.button("โ
Save")
- pressed_reset = st.button(
+ with gui.center():
+ with gui.div():
+ pressed_update = gui.button("โ
Save")
+ pressed_reset = gui.button(
"Reset", key=f"_bi_reset_{bi.id}", type="tertiary"
)
if pressed_update or pressed_reset:
@@ -124,18 +121,61 @@ def render_workflow_url_input(key: str, del_key: str | None, d: dict):
# delete any analysis runs that were removed
bi.analysis_runs.all().exclude(id__in=input_analysis_runs).delete()
except ValidationError as e:
- st.error(str(e))
- st.write("---")
+ gui.error(str(e))
+ gui.write("---")
+
+
+def twilio_specific_settings(bi: BotIntegration):
+ SETTINGS_FIELDS = ["twilio_use_missed_call", "twilio_initial_text", "twilio_initial_audio_url", "twilio_waiting_text", "twilio_waiting_audio_url"] # fmt:skip
+ if gui.session_state.get(f"_bi_reset_{bi.id}"):
+ for field in SETTINGS_FIELDS:
+ gui.session_state[f"_bi_{field}_{bi.id}"] = BotIntegration._meta.get_field(
+ field
+ ).default
+
+ bi.twilio_initial_text = gui.text_area(
+ "###### ๐ Initial Text (said at the beginning of each call)",
+ value=bi.twilio_initial_text,
+ key=f"_bi_twilio_initial_text_{bi.id}",
+ )
+ bi.twilio_initial_audio_url = (
+ gui.file_uploader(
+ "###### ๐ Initial Audio (played at the beginning of each call)",
+ accept=["audio/*"],
+ key=f"_bi_twilio_initial_audio_url_{bi.id}",
+ )
+ or ""
+ )
+ bi.twilio_waiting_audio_url = (
+ gui.file_uploader(
+ "###### ๐ต Waiting Audio (played while waiting for a response -- Voice)",
+ accept=["audio/*"],
+ key=f"_bi_twilio_waiting_audio_url_{bi.id}",
+ )
+ or ""
+ )
+ bi.twilio_waiting_text = gui.text_area(
+ "###### ๐ Waiting Text (texted while waiting for a response -- SMS)",
+ key=f"_bi_twilio_waiting_text_{bi.id}",
+ )
+ bi.twilio_use_missed_call = gui.checkbox(
+ "๐ Use Missed Call",
+ value=bi.twilio_use_missed_call,
+ key=f"_bi_twilio_use_missed_call_{bi.id}",
+ )
+ gui.caption(
+ "When enabled, immediately hangs up incoming calls and calls back the user so they don't incur charges (depending on their carrier/plan)."
+ )
def slack_specific_settings(bi: BotIntegration, default_name: str):
- if st.session_state.get(f"_bi_reset_{bi.id}"):
- st.session_state[f"_bi_name_{bi.id}"] = default_name
- st.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = (
+ if gui.session_state.get(f"_bi_reset_{bi.id}"):
+ gui.session_state[f"_bi_name_{bi.id}"] = default_name
+ gui.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = (
BotIntegration._meta.get_field("slack_read_receipt_msg").default
)
- bi.slack_read_receipt_msg = st.text_input(
+ bi.slack_read_receipt_msg = gui.text_input(
"""
##### โ
Read Receipt
This message is sent immediately after recieving a user message and replaced with the copilot's response once it's ready.
@@ -145,7 +185,7 @@ def slack_specific_settings(bi: BotIntegration, default_name: str):
value=bi.slack_read_receipt_msg,
key=f"_bi_slack_read_receipt_msg_{bi.id}",
)
- bi.name = st.text_input(
+ bi.name = gui.text_input(
"""
##### ๐ชช Channel Specific Bot Name
This is the name the bot will post as in this specific channel (to be displayed in Slack)
@@ -154,7 +194,7 @@ def slack_specific_settings(bi: BotIntegration, default_name: str):
value=bi.name,
key=f"_bi_name_{bi.id}",
)
- st.caption("Enable streaming messages to Slack in real-time.")
+ gui.caption("Enable streaming messages to Slack in real-time.")
def broadcast_input(bi: BotIntegration):
@@ -169,7 +209,7 @@ def broadcast_input(bi: BotIntegration):
)
/ "docs"
)
- text = st.text_area(
+ text = gui.text_area(
f"""
###### Broadcast Message ๐ข
Broadcast a message to all users of this integration using this bot account. \\
@@ -178,40 +218,50 @@ def broadcast_input(bi: BotIntegration):
key=key + ":text",
placeholder="Type your message here...",
)
- audio = st.file_uploader(
+ audio = gui.file_uploader(
"**๐ค Audio**",
key=key + ":audio",
help="Attach a video to this message.",
optional=True,
accept=["audio/*"],
)
- video = st.file_uploader(
- "**๐ฅ Video**",
- key=key + ":video",
- help="Attach a video to this message.",
- optional=True,
- accept=["video/*"],
- )
- documents = st.file_uploader(
- "**๐ Documents**",
- key=key + ":documents",
- help="Attach documents to this message.",
- accept_multiple_files=True,
- optional=True,
- )
+ video = None
+ documents = None
+ medium = "Voice Call"
+ if bi.platform == Platform.TWILIO:
+ medium = gui.selectbox(
+ "###### ๐ฑ Medium",
+ ["Voice Call", "SMS/MMS"],
+ key=key + ":medium",
+ )
+ else:
+ video = gui.file_uploader(
+ "**๐ฅ Video**",
+ key=key + ":video",
+ help="Attach a video to this message.",
+ optional=True,
+ accept=["video/*"],
+ )
+ documents = gui.file_uploader(
+ "**๐ Documents**",
+ key=key + ":documents",
+ help="Attach documents to this message.",
+ accept_multiple_files=True,
+ optional=True,
+ )
should_confirm_key = key + ":should_confirm"
confirmed_send_btn = key + ":confirmed_send"
- if st.button("๐ค Send Broadcast", style=dict(height="3.2rem"), key=key + ":send"):
- st.session_state[should_confirm_key] = True
- if not st.session_state.get(should_confirm_key):
+ if gui.button("๐ค Send Broadcast", style=dict(height="3.2rem"), key=key + ":send"):
+ gui.session_state[should_confirm_key] = True
+ if not gui.session_state.get(should_confirm_key):
return
convos = bi.conversations.all()
- if st.session_state.get(confirmed_send_btn):
- st.success("Started sending broadcast!")
- st.session_state.pop(confirmed_send_btn)
- st.session_state.pop(should_confirm_key)
+ if gui.session_state.get(confirmed_send_btn):
+ gui.success("Started sending broadcast!")
+ gui.session_state.pop(confirmed_send_btn)
+ gui.session_state.pop(should_confirm_key)
send_broadcast_msgs_chunked(
text=text,
audio=audio,
@@ -219,15 +269,16 @@ def broadcast_input(bi: BotIntegration):
documents=documents,
bi=bi,
convo_qs=convos,
+ medium=medium,
)
else:
if not convos.exists():
- st.error("No users have interacted with this bot yet.", icon="โ ๏ธ")
+ gui.error("No users have interacted with this bot yet.", icon="โ ๏ธ")
return
- st.write(
+ gui.write(
f"Are you sure? This will send a message to all {convos.count()} users that have ever interacted with this bot.\n"
)
- st.button("โ
Yes, Send", key=confirmed_send_btn)
+ gui.button("โ
Yes, Send", key=confirmed_send_btn)
def get_bot_test_link(bi: BotIntegration) -> str | None:
@@ -254,6 +305,8 @@ def get_bot_test_link(bi: BotIntegration) -> str | None:
integration_name=slugify(bi.name) or "untitled",
),
)
+ elif bi.twilio_phone_number:
+ return str(furl("tel:") / bi.twilio_phone_number.as_e164)
else:
return None
@@ -265,7 +318,7 @@ def get_web_widget_embed_code(bi: BotIntegration) -> str:
integration_id=bi.api_integration_id(),
integration_name=slugify(bi.name) or "untitled",
),
- )
+ ).rstrip("/")
return dedent(
f"""
@@ -275,43 +328,43 @@ def get_web_widget_embed_code(bi: BotIntegration) -> str:
def web_widget_config(bi: BotIntegration, user: AppUser | None):
- with st.div(style={"width": "100%", "textAlign": "left"}):
- col1, col2 = st.columns(2)
+ with gui.div(style={"width": "100%", "textAlign": "left"}):
+ col1, col2 = gui.columns(2)
with col1:
- if st.session_state.get("--update-display-picture"):
- display_pic = st.file_uploader(
+ if gui.session_state.get("--update-display-picture"):
+ display_pic = gui.file_uploader(
label="###### Display Picture",
accept=["image/*"],
)
if display_pic:
bi.photo_url = display_pic
else:
- if st.button(f"{icons.camera} Change Photo"):
- st.session_state["--update-display-picture"] = True
- st.experimental_rerun()
- bi.name = st.text_input("###### Name", value=bi.name)
- bi.descripton = st.text_area(
+ if gui.button(f"{icons.camera} Change Photo"):
+ gui.session_state["--update-display-picture"] = True
+ gui.rerun()
+ bi.name = gui.text_input("###### Name", value=bi.name)
+ bi.descripton = gui.text_area(
"###### Description",
value=bi.descripton,
)
- scol1, scol2 = st.columns(2)
+ scol1, scol2 = gui.columns(2)
with scol1:
- bi.by_line = st.text_input(
+ bi.by_line = gui.text_input(
"###### By Line",
value=bi.by_line or (user and f"By {user.display_name}"),
)
with scol2:
- bi.website_url = st.text_input(
+ bi.website_url = gui.text_input(
"###### Website Link",
value=bi.website_url or (user and user.website_url),
)
- st.write("###### Conversation Starters")
+ gui.write("###### Conversation Starters")
bi.conversation_starters = list(
filter(
None,
[
- st.text_input("", key=f"--question-{i}", value=value)
+ gui.text_input("", key=f"--question-{i}", value=value)
for i, value in zip_longest(range(4), bi.conversation_starters)
],
)
@@ -332,39 +385,39 @@ def web_widget_config(bi: BotIntegration, user: AppUser | None):
| bi.web_config_extras
)
- scol1, scol2 = st.columns(2)
+ scol1, scol2 = gui.columns(2)
with scol1:
- config["showSources"] = st.checkbox(
+ config["showSources"] = gui.checkbox(
"Show Sources", value=config["showSources"]
)
- config["enablePhotoUpload"] = st.checkbox(
+ config["enablePhotoUpload"] = gui.checkbox(
"Allow Photo Upload", value=config["enablePhotoUpload"]
)
with scol2:
- config["enableAudioMessage"] = st.checkbox(
+ config["enableAudioMessage"] = gui.checkbox(
"Enable Audio Message", value=config["enableAudioMessage"]
)
- config["enableLipsyncVideo"] = st.checkbox(
+ config["enableLipsyncVideo"] = gui.checkbox(
"Enable Lipsync Video", value=config["enableLipsyncVideo"]
)
- # config["branding"]["showPoweredByGooey"] = st.checkbox(
+ # config["branding"]["showPoweredByGooey"] = gui.checkbox(
# "Show Powered By Gooey", value=config["branding"]["showPoweredByGooey"]
# )
- with st.expander("Embed Settings"):
- st.caption(
+ with gui.expander("Embed Settings"):
+ gui.caption(
"These settings will take effect when you embed the widget on your website."
)
- scol1, scol2 = st.columns(2)
+ scol1, scol2 = gui.columns(2)
with scol1:
- config["mode"] = st.selectbox(
+ config["mode"] = gui.selectbox(
"###### Mode",
["popup", "inline", "fullscreen"],
value=config["mode"],
format_func=lambda x: x.capitalize(),
)
if config["mode"] == "popup":
- config["branding"]["fabLabel"] = st.text_input(
+ config["branding"]["fabLabel"] = gui.text_input(
"###### Label",
value=config["branding"].get("fabLabel", "Help"),
)
@@ -374,28 +427,28 @@ def web_widget_config(bi: BotIntegration, user: AppUser | None):
# remove defaults
bi.web_config_extras = config
- with st.div(className="d-flex justify-content-end"):
- if st.button(
+ with gui.div(className="d-flex justify-content-end"):
+ if gui.button(
f"{icons.save} Update Web Preview",
type="primary",
className="align-right",
):
bi.save()
- st.experimental_rerun()
+ gui.rerun()
with col2:
- with st.center(), st.div():
+ with gui.center(), gui.div():
web_preview_tab = f"{icons.chat} Web Preview"
api_tab = f"{icons.api} API"
- selected = st.horizontal_radio("", [web_preview_tab, api_tab])
+ selected = gui.horizontal_radio("", [web_preview_tab, api_tab])
if selected == web_preview_tab:
- st.html(
+ gui.html(
# language=html
f"""
"""
)
- st.js(
+ gui.js(
# language=javascript
"""
async function loadGooeyEmbed() {
diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py
index e00805385..58a154590 100644
--- a/daras_ai_v2/bots.py
+++ b/daras_ai_v2/bots.py
@@ -2,6 +2,7 @@
import typing
from datetime import datetime
+import gooey_gui as gui
from django.db import transaction
from django.utils import timezone
from fastapi import HTTPException
@@ -24,7 +25,6 @@
from daras_ai_v2.base import BasePage, RecipeRunState, StateKeys
from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT
from daras_ai_v2.vector_search import doc_url_to_file_metadata
-from gooey_ui.pubsub import realtime_subscribe
from gooeysite.bg_db_conn import db_middleware, get_celery_result_db_safe
from recipes.VideoBots import VideoBotsPage, ReplyButton
from routers.api import submit_api_call
@@ -402,7 +402,7 @@ def _process_and_send_msg(
if bot.streaming_enabled:
# subscribe to the realtime channel for updates
channel = page.realtime_channel_name(run_id, uid)
- with realtime_subscribe(channel) as realtime_gen:
+ with gui.realtime_subscribe(channel) as realtime_gen:
for state in realtime_gen:
bot.recipe_run_state = page.get_run_state(state)
bot.run_status = state.get(StateKeys.run_status) or ""
@@ -497,7 +497,7 @@ def build_run_vars(convo: Conversation, user_msg_id: str):
bi = convo.bot_integration
if bi.platform == Platform.WEB:
- user_msg_id = user_msg_id.lstrip(MSG_ID_PREFIX)
+ user_msg_id = user_msg_id.removeprefix(MSG_ID_PREFIX)
variables = dict(
platform=Platform(bi.platform).name,
integration_id=bi.api_integration_id(),
diff --git a/daras_ai_v2/breadcrumbs.py b/daras_ai_v2/breadcrumbs.py
index c6d389b24..a0d06101e 100644
--- a/daras_ai_v2/breadcrumbs.py
+++ b/daras_ai_v2/breadcrumbs.py
@@ -1,6 +1,6 @@
import typing
-import gooey_ui as st
+import gooey_gui as gui
from bots.models import (
SavedRun,
PublishedRun,
@@ -32,7 +32,7 @@ def has_breadcrumbs(self):
def render_breadcrumbs(breadcrumbs: TitleBreadCrumbs, *, is_api_call: bool = False):
- st.html(
+ gui.html(
"""