From d81619368a9ee23f083d3a28d6080b506a1ea16b Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Sat, 30 Dec 2023 18:43:41 +0530 Subject: [PATCH] Add UserError exception type to separate server vs user errors --- ...ason_alter_savedrun_example_id_and_more.py | 37 +++++++++++ bots/models.py | 14 +++++ celeryapp/tasks.py | 62 +++++++++++-------- daras_ai_v2/asr.py | 3 +- daras_ai_v2/base.py | 25 +++++++- daras_ai_v2/exceptions.py | 14 +++++ daras_ai_v2/safety_checker.py | 5 +- gooey_ui/components/__init__.py | 19 ++++++ recipes/EmailFaceInpainting.py | 3 +- recipes/LetterWriter.py | 11 ++-- recipes/QRCodeGenerator.py | 28 +++++++-- recipes/VideoBots.py | 3 +- 12 files changed, 179 insertions(+), 45 deletions(-) create mode 100644 bots/migrations/0054_savedrun_finish_reason_alter_savedrun_example_id_and_more.py create mode 100644 daras_ai_v2/exceptions.py diff --git a/bots/migrations/0054_savedrun_finish_reason_alter_savedrun_example_id_and_more.py b/bots/migrations/0054_savedrun_finish_reason_alter_savedrun_example_id_and_more.py new file mode 100644 index 000000000..bc096462e --- /dev/null +++ b/bots/migrations/0054_savedrun_finish_reason_alter_savedrun_example_id_and_more.py @@ -0,0 +1,37 @@ +# Generated by Django 4.2.7 on 2023-12-30 00:52 + +from django.db import migrations, models +from bots.models import FinishReason + + +def forwards_func(apps, schema_editor): + saved_run = apps.get_model("bots", "SavedRun") + saved_run.objects.filter(run_status="", error_msg="").update( + finish_reason=FinishReason.DONE, + ) + saved_run.objects.exclude(error_msg="").update( + finish_reason=FinishReason.SERVER_ERROR, + ) + + +def backwards_func(apps, schema_editor): + pass + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0053_alter_publishedrun_workflow_alter_savedrun_workflow_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="savedrun", + name="finish_reason", + field=models.IntegerField( + choices=[(1, "User Error"), (2, "Server Error"), (3, "Done")], + default=None, + null=True, + ), + ), + migrations.RunPython(forwards_func, backwards_func), + ] diff --git a/bots/models.py b/bots/models.py index f44b64bb9..642046940 100644 --- a/bots/models.py +++ b/bots/models.py @@ -31,6 +31,12 @@ EPOCH = datetime.datetime.utcfromtimestamp(0) +class FinishReason(models.IntegerChoices): + USER_ERROR = 1 + SERVER_ERROR = 2 + DONE = 3 + + class PublishedRunVisibility(models.IntegerChoices): UNLISTED = 1 PUBLIC = 2 @@ -157,6 +163,11 @@ class SavedRun(models.Model): state = models.JSONField(default=dict, blank=True, encoder=PostgresJSONEncoder) + finish_reason = models.IntegerField( + choices=FinishReason.choices, + null=True, + default=None, + ) error_msg = models.TextField(default="", blank=True) run_time = models.DurationField(default=datetime.timedelta, blank=True) run_status = models.TextField(default="", blank=True) @@ -225,6 +236,8 @@ def to_dict(self) -> dict: ret[StateKeys.created_at] = self.created_at if self.error_msg: ret[StateKeys.error_msg] = self.error_msg + if self.finish_reason: + ret[StateKeys.finish_reason] = self.finish_reason if self.run_time: ret[StateKeys.run_time] = self.run_time.total_seconds() if self.run_status: @@ -255,6 +268,7 @@ def copy_from_firebase_state(self, state: dict) -> "SavedRun": if created_at: self.created_at = created_at self.error_msg = state.pop(StateKeys.error_msg, None) or "" + self.finish_reason = state.pop(StateKeys.finish_reason, None) self.run_time = datetime.timedelta( seconds=state.pop(StateKeys.run_time, None) or 0 ) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 1868c7a18..76b7b652c 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -4,14 +4,16 @@ from types import SimpleNamespace import sentry_sdk +import pydantic import gooey_ui as st from app_users.models import AppUser -from bots.models import SavedRun +from bots.models import SavedRun, FinishReason from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings -from daras_ai_v2.base import StateKeys, err_msg_for_exc, BasePage +from daras_ai_v2.base import BasePage, StateKeys, err_msg_for_exc +from daras_ai_v2.exceptions import UserError, raise_as_user_error from daras_ai_v2.send_email import send_email_via_postmark from daras_ai_v2.settings import templates from gooey_ui.pubsub import realtime_push @@ -27,15 +29,16 @@ def gui_runner( uid: str, state: dict, channel: str, - query_params: dict = None, + query_params: dict | None = None, is_api_call: bool = False, ): page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id))) sr = page.run_doc_sr(run_id, uid) st.set_session_state(state) - run_time = 0 + run_time = 0.0 yield_val = None + finish_reason = None error_msg = None set_query_params(query_params or {}) @@ -53,6 +56,7 @@ def save(done=False): # set run status and run time status = { StateKeys.run_time: run_time, + StateKeys.finish_reason: finish_reason, StateKeys.error_msg: error_msg, StateKeys.run_status: run_status, } @@ -75,30 +79,36 @@ def save(done=False): try: gen = page.run(st.session_state) save() - while True: - # record time - start_time = time() - try: - # advance the generator (to further progress of run()) - yield_val = next(gen) - # increment total time taken after every iteration - run_time += time() - start_time - continue - # run completed - except StopIteration: - run_time += time() - start_time - sr.transaction, sr.price = page.deduct_credits(st.session_state) - break + start_time = time() + try: + with raise_as_user_error([pydantic.ValidationError]): + for yield_val in gen: + # increment total time taken after every iteration + run_time += time() - start_time + save() + except UserError as e: + # handled errors caused due to user input + run_time += time() - start_time + finish_reason = FinishReason.USER_ERROR + traceback.print_exc() + sentry_sdk.capture_exception(e) + error_msg = err_msg_for_exc(e) + except Exception as e: # render errors nicely - except Exception as e: - run_time += time() - start_time - traceback.print_exc() - sentry_sdk.capture_exception(e) - error_msg = err_msg_for_exc(e) - break - finally: - save() + run_time += time() - start_time + finish_reason = FinishReason.SERVER_ERROR + traceback.print_exc() + sentry_sdk.capture_exception(e) + error_msg = err_msg_for_exc(e) + else: + # run completed + run_time += time() - start_time + finish_reason = FinishReason.DONE + sr.transaction, sr.price = page.deduct_credits(st.session_state) finally: + if not finish_reason: + finish_reason = FinishReason.SERVER_ERROR + error_msg = "Something went wrong. Please try again later." save(done=True) if not is_api_call: send_email_on_completion(page, sr) diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index 898ef61be..214483fa8 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -14,6 +14,7 @@ import gooey_ui as st from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri from daras_ai_v2 import settings +from daras_ai_v2.exceptions import UserError from daras_ai_v2.functional import map_parallel from daras_ai_v2.gdrive_downloader import ( is_gdrive_url, @@ -570,7 +571,7 @@ def run_asr( assert data.get("chunks"), f"{selected_model.value} can't generate VTT" return generate_vtt(data["chunks"]) case _: - raise ValueError(f"Invalid output format: {output_format}") + raise UserError(f"Invalid output format: {output_format}") def _get_or_create_recognizer( diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 3368a36dd..46d9a8ca7 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -28,6 +28,7 @@ import gooey_ui as st from app_users.models import AppUser, AppUserTransaction from bots.models import ( + FinishReason, SavedRun, PublishedRun, PublishedRunVersion, @@ -46,6 +47,7 @@ from daras_ai_v2.db import ( ANONYMOUS_USER_COOKIE, ) +from daras_ai_v2.exceptions import UserError from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.html_spinner_widget import html_spinner from daras_ai_v2.manage_api_keys_widget import manage_api_keys @@ -95,6 +97,7 @@ class StateKeys: created_at = "created_at" updated_at = "updated_at" + finish_reason = "__finish_reason" error_msg = "__error_msg" run_time = "__run_time" run_status = "__run_status" @@ -1381,7 +1384,17 @@ def _render_completed_output(self): def _render_failed_output(self): err_msg = st.session_state.get(StateKeys.error_msg) - st.error(err_msg, unsafe_allow_html=True) + finish_reason = st.session_state.get(StateKeys.finish_reason) + self._render_error(err_msg, finish_reason) + + def _render_error(self, error_msg: str, finish_reason: FinishReason): + match finish_reason: + case FinishReason.USER_ERROR: + st.warning(error_msg, unsafe_allow_html=True) + case FinishReason.SERVER_ERROR: + st.error(error_msg, unsafe_allow_html=True) + case _: + raise ValueError(f"invalid finish reason for error: {finish_reason}") def _render_running_output(self): run_status = st.session_state.get(StateKeys.run_status) @@ -1421,6 +1434,7 @@ def on_submit(self): st.session_state[StateKeys.error_msg] = self.generate_credit_error_message( example_id, run_id, uid ) + st.session_state[StateKeys.finish_reason] = FinishReason.USER_ERROR self.run_doc_sr(run_id, uid).set(self.state_to_doc(st.session_state)) else: self.call_runner_task(example_id, run_id, uid) @@ -1439,6 +1453,7 @@ def should_submit_after_login(self) -> bool: def create_new_run(self): st.session_state[StateKeys.run_status] = "Starting..." st.session_state.pop(StateKeys.error_msg, None) + st.session_state.pop(StateKeys.finish_reason, None) st.session_state.pop(StateKeys.run_time, None) self._setup_rng_seed() self.clear_outputs() @@ -1529,6 +1544,7 @@ def _setup_rng_seed(self): def clear_outputs(self): # clear error msg st.session_state.pop(StateKeys.error_msg, None) + st.session_state.pop(StateKeys.finish_reason, None) # clear outputs for field_name in self.ResponseModel.__fields__: st.session_state.pop(field_name, None) @@ -1608,6 +1624,7 @@ def fields_to_save(self) -> [str]: for field_name in model.__fields__ ] + [ StateKeys.error_msg, + StateKeys.finish_reason, StateKeys.run_status, StateKeys.run_time, ] @@ -1711,7 +1728,9 @@ def _render_run_preview(self, saved_run: SavedRun): if saved_run.run_status: html_spinner(saved_run.run_status) elif saved_run.error_msg: - st.error(saved_run.error_msg, unsafe_allow_html=True) + self._render_error( + error_msg=saved_run.error_msg, finish_reason=saved_run.finish_reason + ) return self.render_example(saved_run.to_dict()) @@ -1963,6 +1982,8 @@ def err_msg_for_exc(e): return f"(GPU) {err_type}: {err_str}" err_str = str(err_body) return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}" + elif isinstance(e, UserError): + return str(e) else: return f"{type(e).__name__}: {e}" diff --git a/daras_ai_v2/exceptions.py b/daras_ai_v2/exceptions.py new file mode 100644 index 000000000..8691dd72e --- /dev/null +++ b/daras_ai_v2/exceptions.py @@ -0,0 +1,14 @@ +from contextlib import contextmanager +from typing import Type + + +class UserError(Exception): + pass + + +@contextmanager +def raise_as_user_error(excs: list[Type[Exception]]): + try: + yield + except tuple(excs) as e: + raise UserError(str(e)) from e diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 541e41b97..99aae693b 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -1,5 +1,6 @@ from app_users.models import AppUser from daras_ai_v2.azure_image_moderation import is_image_nsfw +from daras_ai_v2.exceptions import UserError from daras_ai_v2.functional import flatten from daras_ai_v2 import settings from recipes.CompareLLM import CompareLLMPage @@ -43,14 +44,14 @@ def safety_checker_text(text_input: str): if not lines: continue if lines[-1].upper().endswith("FLAGGED"): - raise ValueError( + raise UserError( "Your request was rejected as a result of our safety system. Your prompt may contain text that is not allowed by our safety system." ) def safety_checker_image(image_url: str, cache: bool = False) -> None: if is_image_nsfw(image_url=image_url, cache=cache): - raise ValueError( + raise UserError( "Your request was rejected as a result of our safety system. " "Your input image may contain contents that are not allowed " "by our safety system." diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py index 207620c1f..f26217bf9 100644 --- a/gooey_ui/components/__init__.py +++ b/gooey_ui/components/__init__.py @@ -144,6 +144,25 @@ def success(body: str, icon: str = "✅", *, unsafe_allow_html=False): markdown(dedent(body), unsafe_allow_html=unsafe_allow_html) +def warning(body: str, icon: str = "⚠️", *, unsafe_allow_html=False): + if not isinstance(body, str): + body = repr(body) + with div( + style=dict( + backgroundColor="rgba(238, 210, 2, 0.2)", + padding="1rem", + paddingBottom="0", + marginBottom="0.5rem", + borderRadius="0.25rem", + display="flex", + gap="0.5rem", + ) + ): + markdown(icon) + with div(): + markdown(dedent(body), unsafe_allow_html=unsafe_allow_html) + + def caption(body: str, **props): style = props.setdefault("style", {"fontSize": "0.9rem"}) markdown(body, className="text-muted", **props) diff --git a/recipes/EmailFaceInpainting.py b/recipes/EmailFaceInpainting.py index 86428c4ae..d68b8dba9 100644 --- a/recipes/EmailFaceInpainting.py +++ b/recipes/EmailFaceInpainting.py @@ -9,6 +9,7 @@ from bots.models import Workflow from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import db, settings +from daras_ai_v2.exceptions import UserError from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.send_email import send_email_via_postmark from daras_ai_v2.stable_diffusion import InpaintingModels @@ -170,7 +171,7 @@ def validate_form_v2(self): email_regex, email_address ), "Please provide a valid Email Address" else: - raise AssertionError("Please provide an Email Address or Twitter Handle") + raise UserError("Please provide an Email Address or Twitter Handle") from_email = st.session_state.get("email_from") email_subject = st.session_state.get("email_subject") diff --git a/recipes/LetterWriter.py b/recipes/LetterWriter.py index 87e0aa550..84cf4486b 100644 --- a/recipes/LetterWriter.py +++ b/recipes/LetterWriter.py @@ -8,6 +8,7 @@ from bots.models import Workflow from daras_ai.text_format import daras_ai_format_str from daras_ai_v2.base import BasePage +from daras_ai_v2.exceptions import UserError from daras_ai_v2.language_model import run_language_model from daras_ai_v2.text_training_data_widget import text_training_data, TrainingDataModel @@ -232,7 +233,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: json_body = request.api_json_body.replace("{{ action_id }}", request.action_id) if not (url and method): - raise ValueError("HTTP method / URL is empty. Please check your settings.") + raise UserError("HTTP method / URL is empty. Please check your settings.") if headers: headers = json.loads(headers) @@ -252,7 +253,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: yield "Generating Prompt..." if not request.input_prompt: - raise ValueError("Input prompt is Empty. Please check your settings.") + raise UserError("Input prompt is Empty. Please check your settings.") input_prompt = daras_ai_format_str( format_str=request.input_prompt, @@ -263,13 +264,11 @@ def run(self, state: dict) -> typing.Iterator[str | None]: yield "Generating Prompt..." if not request.prompt_header: - raise ValueError( + raise UserError( "Task description not provided. Please check your settings." ) if not request.example_letters: - raise ValueError( - "Example letters not provided. Please check your settings." - ) + raise UserError("Example letters not provided. Please check your settings.") prompt_prefix = "TalkingPoints:" completion_prefix = "Letter:" diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 30e109994..2382caa94 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -20,6 +20,7 @@ ) from daras_ai_v2.base import BasePage from daras_ai_v2.descriptions import prompting101 +from daras_ai_v2.exceptions import UserError from daras_ai_v2.img_model_settings_widgets import ( output_resolution_setting, img_model_settings, @@ -556,10 +557,13 @@ def is_url(url: str) -> bool: return True -def generate_and_upload_qr_code( +def get_qr_code_data( request: QRCodeGeneratorPage.RequestModel, user: AppUser, -) -> tuple[str, str, bool]: +) -> tuple[str, bool]: + """ + Returns tuple for the QR code data and whether the input URL was shortened + """ if request.qr_code_vcard: vcf_str = request.qr_code_vcard.to_vcf_str() qr_code_data = ShortenedURL.objects.get_or_create_for_workflow( @@ -568,7 +572,7 @@ def generate_and_upload_qr_code( user=user, workflow=Workflow.QR_CODE, )[0].shortened_url() - using_shortened_url = True + return qr_code_data, True else: if request.qr_code_file: qr_code_data = request.qr_code_file @@ -579,17 +583,29 @@ def generate_and_upload_qr_code( if isinstance(qr_code_data, str): qr_code_data = qr_code_data.strip() if not qr_code_data: - raise ValueError("Please provide QR Code URL, text content, or an image") - using_shortened_url = request.use_url_shortener and is_url(qr_code_data) + raise UserError("Please provide QR Code URL, text content, or an image") + using_shortened_url = ( + request.use_url_shortener and is_url(qr_code_data) or False + ) if using_shortened_url: qr_code_data = ShortenedURL.objects.get_or_create_for_workflow( url=qr_code_data, user=user, workflow=Workflow.QR_CODE, )[0].shortened_url() + return qr_code_data, using_shortened_url - img_cv2 = generate_qr_code(qr_code_data) +def generate_and_upload_qr_code( + request: QRCodeGeneratorPage.RequestModel, + user: AppUser, +) -> tuple[str, str, bool]: + try: + qr_code_data, using_shortened_url = get_qr_code_data(request, user) + except InvalidQRCode as e: + raise UserError from e + + img_cv2 = generate_qr_code(qr_code_data) img_cv2, _ = reposition_object( orig_img=img_cv2, orig_mask=img_cv2, diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 952df781c..f5e9abf0a 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -26,6 +26,7 @@ document_uploader, ) from daras_ai_v2.enum_selector_widget import enum_multiselect +from daras_ai_v2.exceptions import UserError from daras_ai_v2.field_render import field_title_desc from daras_ai_v2.functions import LLMTools from daras_ai_v2.glossary import glossary_input @@ -739,7 +740,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) max_allowed_tokens = min(max_allowed_tokens, request.max_tokens) if max_allowed_tokens < 0: - raise ValueError("Input Script is too long! Please reduce the script size.") + raise UserError("Input Script is too long! Please reduce the script size.") yield f"Running {model.value}..." if is_chat_model: