From eadda21b55862616e53eace2c597834aab73f33e Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 28 Feb 2024 22:03:48 +0530 Subject: [PATCH] catch gpu task errors as user errors --- celeryapp/tasks.py | 1 + daras_ai_v2/base.py | 4 ++-- daras_ai_v2/exceptions.py | 4 ++++ daras_ai_v2/gpu_server.py | 8 ++++++-- gooeysite/bg_db_conn.py | 6 ++++-- 5 files changed, 17 insertions(+), 6 deletions(-) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 2cf56fcdc..ca71f2f4c 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -40,6 +40,7 @@ def gui_runner( is_api_call: bool = False, ): page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id))) + page.setup_sentry() sr = page.run_doc_sr(run_id, uid) sr.is_api_call = is_api_call diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 26825a087..ecc415d6b 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -179,7 +179,7 @@ def get_tab_url(self, tab: str) -> str: tab_name=MenuTabs.paths[tab], ) - def setup_render(self): + def setup_sentry(self): with sentry_sdk.configure_scope() as scope: scope.set_extra("base_url", self.app_url()) scope.set_transaction_name( @@ -194,7 +194,7 @@ def refresh_state(self): st.session_state.update(output) def render(self): - self.setup_render() + self.setup_sentry() if self.get_run_state(st.session_state) == RecipeRunState.running: self.refresh_state() diff --git a/daras_ai_v2/exceptions.py b/daras_ai_v2/exceptions.py index 329e80ba4..e97ac7190 100644 --- a/daras_ai_v2/exceptions.py +++ b/daras_ai_v2/exceptions.py @@ -45,6 +45,10 @@ def __init__(self, message: str, sentry_level: str = "info"): super().__init__(message) +class GPUError(UserError): + pass + + FFMPEG_ERR_MSG = ( "Unsupported File Format\n\n" "We encountered an issue processing your file as it appears to be in a format not supported by our system or may be corrupted. " diff --git a/daras_ai_v2/gpu_server.py b/daras_ai_v2/gpu_server.py index 329a33fbf..33b6ad935 100644 --- a/daras_ai_v2/gpu_server.py +++ b/daras_ai_v2/gpu_server.py @@ -9,7 +9,7 @@ from daras_ai.image_input import storage_blob_for from daras_ai_v2 import settings -from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.exceptions import raise_for_status, GPUError from gooeysite.bg_db_conn import get_celery_result_db_safe @@ -160,7 +160,11 @@ def call_celery_task( task_name, kwargs=dict(pipeline=pipeline, inputs=inputs), queue=queue ) s = time() - ret = get_celery_result_db_safe(result) + ret = get_celery_result_db_safe(result, propagate=False) + try: + result.maybe_throw() + except Exception as e: + raise GPUError(f"Error in GPU Task {queue}:{task_name} - {e}") from e record_cost_auto( model=queue, sku=ModelSku.gpu_ms, quantity=int((time() - s) * 1000) ) diff --git a/gooeysite/bg_db_conn.py b/gooeysite/bg_db_conn.py index 9c7680df9..0c36daca8 100644 --- a/gooeysite/bg_db_conn.py +++ b/gooeysite/bg_db_conn.py @@ -31,5 +31,7 @@ def wrapper(*args, **kwargs): @db_middleware -def get_celery_result_db_safe(result: "celery.result.AsyncResult") -> typing.Any: - return result.get(disable_sync_subtasks=False) +def get_celery_result_db_safe( + result: "celery.result.AsyncResult", **kwargs +) -> typing.Any: + return result.get(disable_sync_subtasks=False, **kwargs)