Skip to content

Commit

Permalink
Add UserError exception type to separate server vs user errors
Browse files Browse the repository at this point in the history
  • Loading branch information
nikochiko committed Dec 30, 2023
1 parent 93715d3 commit d816193
Show file tree
Hide file tree
Showing 12 changed files with 179 additions and 45 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Generated by Django 4.2.7 on 2023-12-30 00:52

from django.db import migrations, models
from bots.models import FinishReason


def forwards_func(apps, schema_editor):
saved_run = apps.get_model("bots", "SavedRun")
saved_run.objects.filter(run_status="", error_msg="").update(
finish_reason=FinishReason.DONE,
)
saved_run.objects.exclude(error_msg="").update(
finish_reason=FinishReason.SERVER_ERROR,
)


def backwards_func(apps, schema_editor):
pass


class Migration(migrations.Migration):
dependencies = [
("bots", "0053_alter_publishedrun_workflow_alter_savedrun_workflow_and_more"),
]

operations = [
migrations.AddField(
model_name="savedrun",
name="finish_reason",
field=models.IntegerField(
choices=[(1, "User Error"), (2, "Server Error"), (3, "Done")],
default=None,
null=True,
),
),
migrations.RunPython(forwards_func, backwards_func),
]
14 changes: 14 additions & 0 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
EPOCH = datetime.datetime.utcfromtimestamp(0)


class FinishReason(models.IntegerChoices):
USER_ERROR = 1
SERVER_ERROR = 2
DONE = 3


class PublishedRunVisibility(models.IntegerChoices):
UNLISTED = 1
PUBLIC = 2
Expand Down Expand Up @@ -157,6 +163,11 @@ class SavedRun(models.Model):

state = models.JSONField(default=dict, blank=True, encoder=PostgresJSONEncoder)

finish_reason = models.IntegerField(
choices=FinishReason.choices,
null=True,
default=None,
)
error_msg = models.TextField(default="", blank=True)
run_time = models.DurationField(default=datetime.timedelta, blank=True)
run_status = models.TextField(default="", blank=True)
Expand Down Expand Up @@ -225,6 +236,8 @@ def to_dict(self) -> dict:
ret[StateKeys.created_at] = self.created_at
if self.error_msg:
ret[StateKeys.error_msg] = self.error_msg
if self.finish_reason:
ret[StateKeys.finish_reason] = self.finish_reason
if self.run_time:
ret[StateKeys.run_time] = self.run_time.total_seconds()
if self.run_status:
Expand Down Expand Up @@ -255,6 +268,7 @@ def copy_from_firebase_state(self, state: dict) -> "SavedRun":
if created_at:
self.created_at = created_at
self.error_msg = state.pop(StateKeys.error_msg, None) or ""
self.finish_reason = state.pop(StateKeys.finish_reason, None)
self.run_time = datetime.timedelta(
seconds=state.pop(StateKeys.run_time, None) or 0
)
Expand Down
62 changes: 36 additions & 26 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
from types import SimpleNamespace

import sentry_sdk
import pydantic

import gooey_ui as st
from app_users.models import AppUser
from bots.models import SavedRun
from bots.models import SavedRun, FinishReason
from celeryapp.celeryconfig import app
from daras_ai.image_input import truncate_text_words
from daras_ai_v2 import settings
from daras_ai_v2.base import StateKeys, err_msg_for_exc, BasePage
from daras_ai_v2.base import BasePage, StateKeys, err_msg_for_exc
from daras_ai_v2.exceptions import UserError, raise_as_user_error
from daras_ai_v2.send_email import send_email_via_postmark
from daras_ai_v2.settings import templates
from gooey_ui.pubsub import realtime_push
Expand All @@ -27,15 +29,16 @@ def gui_runner(
uid: str,
state: dict,
channel: str,
query_params: dict = None,
query_params: dict | None = None,
is_api_call: bool = False,
):
page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id)))
sr = page.run_doc_sr(run_id, uid)

st.set_session_state(state)
run_time = 0
run_time = 0.0
yield_val = None
finish_reason = None
error_msg = None
set_query_params(query_params or {})

Expand All @@ -53,6 +56,7 @@ def save(done=False):
# set run status and run time
status = {
StateKeys.run_time: run_time,
StateKeys.finish_reason: finish_reason,
StateKeys.error_msg: error_msg,
StateKeys.run_status: run_status,
}
Expand All @@ -75,30 +79,36 @@ def save(done=False):
try:
gen = page.run(st.session_state)
save()
while True:
# record time
start_time = time()
try:
# advance the generator (to further progress of run())
yield_val = next(gen)
# increment total time taken after every iteration
run_time += time() - start_time
continue
# run completed
except StopIteration:
run_time += time() - start_time
sr.transaction, sr.price = page.deduct_credits(st.session_state)
break
start_time = time()
try:
with raise_as_user_error([pydantic.ValidationError]):
for yield_val in gen:
# increment total time taken after every iteration
run_time += time() - start_time
save()
except UserError as e:
# handled errors caused due to user input
run_time += time() - start_time
finish_reason = FinishReason.USER_ERROR
traceback.print_exc()
sentry_sdk.capture_exception(e)
error_msg = err_msg_for_exc(e)
except Exception as e:
# render errors nicely
except Exception as e:
run_time += time() - start_time
traceback.print_exc()
sentry_sdk.capture_exception(e)
error_msg = err_msg_for_exc(e)
break
finally:
save()
run_time += time() - start_time
finish_reason = FinishReason.SERVER_ERROR
traceback.print_exc()
sentry_sdk.capture_exception(e)
error_msg = err_msg_for_exc(e)
else:
# run completed
run_time += time() - start_time
finish_reason = FinishReason.DONE
sr.transaction, sr.price = page.deduct_credits(st.session_state)
finally:
if not finish_reason:
finish_reason = FinishReason.SERVER_ERROR
error_msg = "Something went wrong. Please try again later."
save(done=True)
if not is_api_call:
send_email_on_completion(page, sr)
Expand Down
3 changes: 2 additions & 1 deletion daras_ai_v2/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import gooey_ui as st
from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri
from daras_ai_v2 import settings
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.functional import map_parallel
from daras_ai_v2.gdrive_downloader import (
is_gdrive_url,
Expand Down Expand Up @@ -570,7 +571,7 @@ def run_asr(
assert data.get("chunks"), f"{selected_model.value} can't generate VTT"
return generate_vtt(data["chunks"])
case _:
raise ValueError(f"Invalid output format: {output_format}")
raise UserError(f"Invalid output format: {output_format}")


def _get_or_create_recognizer(
Expand Down
25 changes: 23 additions & 2 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import gooey_ui as st
from app_users.models import AppUser, AppUserTransaction
from bots.models import (
FinishReason,
SavedRun,
PublishedRun,
PublishedRunVersion,
Expand All @@ -46,6 +47,7 @@
from daras_ai_v2.db import (
ANONYMOUS_USER_COOKIE,
)
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.grid_layout_widget import grid_layout
from daras_ai_v2.html_spinner_widget import html_spinner
from daras_ai_v2.manage_api_keys_widget import manage_api_keys
Expand Down Expand Up @@ -95,6 +97,7 @@ class StateKeys:
created_at = "created_at"
updated_at = "updated_at"

finish_reason = "__finish_reason"
error_msg = "__error_msg"
run_time = "__run_time"
run_status = "__run_status"
Expand Down Expand Up @@ -1381,7 +1384,17 @@ def _render_completed_output(self):

def _render_failed_output(self):
err_msg = st.session_state.get(StateKeys.error_msg)
st.error(err_msg, unsafe_allow_html=True)
finish_reason = st.session_state.get(StateKeys.finish_reason)
self._render_error(err_msg, finish_reason)

def _render_error(self, error_msg: str, finish_reason: FinishReason):
match finish_reason:
case FinishReason.USER_ERROR:
st.warning(error_msg, unsafe_allow_html=True)
case FinishReason.SERVER_ERROR:
st.error(error_msg, unsafe_allow_html=True)
case _:
raise ValueError(f"invalid finish reason for error: {finish_reason}")

def _render_running_output(self):
run_status = st.session_state.get(StateKeys.run_status)
Expand Down Expand Up @@ -1421,6 +1434,7 @@ def on_submit(self):
st.session_state[StateKeys.error_msg] = self.generate_credit_error_message(
example_id, run_id, uid
)
st.session_state[StateKeys.finish_reason] = FinishReason.USER_ERROR
self.run_doc_sr(run_id, uid).set(self.state_to_doc(st.session_state))
else:
self.call_runner_task(example_id, run_id, uid)
Expand All @@ -1439,6 +1453,7 @@ def should_submit_after_login(self) -> bool:
def create_new_run(self):
st.session_state[StateKeys.run_status] = "Starting..."
st.session_state.pop(StateKeys.error_msg, None)
st.session_state.pop(StateKeys.finish_reason, None)
st.session_state.pop(StateKeys.run_time, None)
self._setup_rng_seed()
self.clear_outputs()
Expand Down Expand Up @@ -1529,6 +1544,7 @@ def _setup_rng_seed(self):
def clear_outputs(self):
# clear error msg
st.session_state.pop(StateKeys.error_msg, None)
st.session_state.pop(StateKeys.finish_reason, None)
# clear outputs
for field_name in self.ResponseModel.__fields__:
st.session_state.pop(field_name, None)
Expand Down Expand Up @@ -1608,6 +1624,7 @@ def fields_to_save(self) -> [str]:
for field_name in model.__fields__
] + [
StateKeys.error_msg,
StateKeys.finish_reason,
StateKeys.run_status,
StateKeys.run_time,
]
Expand Down Expand Up @@ -1711,7 +1728,9 @@ def _render_run_preview(self, saved_run: SavedRun):
if saved_run.run_status:
html_spinner(saved_run.run_status)
elif saved_run.error_msg:
st.error(saved_run.error_msg, unsafe_allow_html=True)
self._render_error(
error_msg=saved_run.error_msg, finish_reason=saved_run.finish_reason
)

return self.render_example(saved_run.to_dict())

Expand Down Expand Up @@ -1963,6 +1982,8 @@ def err_msg_for_exc(e):
return f"(GPU) {err_type}: {err_str}"
err_str = str(err_body)
return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}"
elif isinstance(e, UserError):
return str(e)
else:
return f"{type(e).__name__}: {e}"

Expand Down
14 changes: 14 additions & 0 deletions daras_ai_v2/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from contextlib import contextmanager
from typing import Type


class UserError(Exception):
pass


@contextmanager
def raise_as_user_error(excs: list[Type[Exception]]):
try:
yield
except tuple(excs) as e:
raise UserError(str(e)) from e
5 changes: 3 additions & 2 deletions daras_ai_v2/safety_checker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from app_users.models import AppUser
from daras_ai_v2.azure_image_moderation import is_image_nsfw
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.functional import flatten
from daras_ai_v2 import settings
from recipes.CompareLLM import CompareLLMPage
Expand Down Expand Up @@ -43,14 +44,14 @@ def safety_checker_text(text_input: str):
if not lines:
continue
if lines[-1].upper().endswith("FLAGGED"):
raise ValueError(
raise UserError(
"Your request was rejected as a result of our safety system. Your prompt may contain text that is not allowed by our safety system."
)


def safety_checker_image(image_url: str, cache: bool = False) -> None:
if is_image_nsfw(image_url=image_url, cache=cache):
raise ValueError(
raise UserError(
"Your request was rejected as a result of our safety system. "
"Your input image may contain contents that are not allowed "
"by our safety system."
Expand Down
19 changes: 19 additions & 0 deletions gooey_ui/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,25 @@ def success(body: str, icon: str = "✅", *, unsafe_allow_html=False):
markdown(dedent(body), unsafe_allow_html=unsafe_allow_html)


def warning(body: str, icon: str = "⚠️", *, unsafe_allow_html=False):
if not isinstance(body, str):
body = repr(body)
with div(
style=dict(
backgroundColor="rgba(238, 210, 2, 0.2)",
padding="1rem",
paddingBottom="0",
marginBottom="0.5rem",
borderRadius="0.25rem",
display="flex",
gap="0.5rem",
)
):
markdown(icon)
with div():
markdown(dedent(body), unsafe_allow_html=unsafe_allow_html)


def caption(body: str, **props):
style = props.setdefault("style", {"fontSize": "0.9rem"})
markdown(body, className="text-muted", **props)
Expand Down
3 changes: 2 additions & 1 deletion recipes/EmailFaceInpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from bots.models import Workflow
from daras_ai.image_input import upload_file_from_bytes
from daras_ai_v2 import db, settings
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.loom_video_widget import youtube_video
from daras_ai_v2.send_email import send_email_via_postmark
from daras_ai_v2.stable_diffusion import InpaintingModels
Expand Down Expand Up @@ -170,7 +171,7 @@ def validate_form_v2(self):
email_regex, email_address
), "Please provide a valid Email Address"
else:
raise AssertionError("Please provide an Email Address or Twitter Handle")
raise UserError("Please provide an Email Address or Twitter Handle")

from_email = st.session_state.get("email_from")
email_subject = st.session_state.get("email_subject")
Expand Down
Loading

0 comments on commit d816193

Please sign in to comment.