Skip to content

Commit

Permalink
Fix: handle uncaught exception only for Serverless workers (#388)
Browse files Browse the repository at this point in the history
* refactor: moved handle_uncaught_exception to rp_scale
* refactor: bind handle_uncaught_exception on JobScaler init
* fix: python <3.11 compatibility
  • Loading branch information
deanq authored Jan 2, 2025
1 parent d7a2131 commit 3f78233
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 7 deletions.
7 changes: 0 additions & 7 deletions runpod/serverless/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import signal
import sys
import time
import typing
from typing import Any, Dict

from runpod.serverless import core
Expand All @@ -24,12 +23,6 @@
log = RunPodLogger()


def handle_uncaught_exception(exc_type, exc_value, exc_traceback):
log.error(f"Uncaught exception | {exc_type}; {exc_value}; {exc_traceback};")

sys.excepthook = handle_uncaught_exception


# ---------------------------------------------------------------------------- #
# Run Time Arguments #
# ---------------------------------------------------------------------------- #
Expand Down
9 changes: 9 additions & 0 deletions runpod/serverless/modules/rp_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import asyncio
import signal
import sys
import traceback
from typing import Any, Dict

from ...http_client import AsyncClientSession, ClientSession, TooManyRequests
Expand All @@ -16,6 +18,11 @@
job_progress = JobsProgress()


def _handle_uncaught_exception(exc_type, exc_value, exc_traceback):
exc = traceback.format_exception(exc_type, exc_value, exc_traceback)
log.error(f"Uncaught exception | {exc}")


def _default_concurrency_modifier(current_concurrency: int) -> int:
"""
Default concurrency modifier.
Expand Down Expand Up @@ -87,6 +94,8 @@ def start(self):
when the user sends a SIGTERM or SIGINT signal. This is typically
the case when the worker is running in a container.
"""
sys.excepthook = _handle_uncaught_exception

try:
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, self.handle_shutdown)
Expand Down
54 changes: 54 additions & 0 deletions tests/test_serverless/test_modules/test_scale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import sys
import traceback
from unittest import TestCase
from unittest.mock import patch

from runpod.serverless.modules.rp_scale import _handle_uncaught_exception


class TestHandleUncaughtException(TestCase):
def setUp(self):
sys.excepthook = sys.__excepthook__

@patch("runpod.serverless.modules.rp_scale.log")
def test_handle_uncaught_exception(self, mock_logger):
exc_type = ValueError
exc_value = ValueError("This is a test error")
exc_traceback = None # No traceback for simplicity

_handle_uncaught_exception(exc_type, exc_value, exc_traceback)

formatted_exception = traceback.format_exception(exc_type, exc_value, exc_traceback)

mock_logger.error.assert_called_once()
log_message = mock_logger.error.call_args[0][0]
assert "Uncaught exception | " in log_message
assert str(formatted_exception) in log_message

@patch("runpod.serverless.modules.rp_scale.log")
def test_handle_uncaught_exception_with_traceback(self, mock_logger):
try:
raise RuntimeError("This is a runtime error")
except RuntimeError:
exc_type, exc_value, exc_traceback = sys.exc_info()

_handle_uncaught_exception(exc_type, exc_value, exc_traceback)

formatted_exception = traceback.format_exception(exc_type, exc_value, exc_traceback)

mock_logger.error.assert_called_once()
log_message = mock_logger.error.call_args[0][0]
assert "Uncaught exception | " in log_message
assert str(formatted_exception) in log_message

@patch("runpod.serverless.modules.rp_scale.log")
def test_handle_uncaught_exception_with_no_exception(self, mock_logger):
_handle_uncaught_exception(None, None, None)

mock_logger.error.assert_called_once()
log_message = mock_logger.error.call_args[0][0]
assert "Uncaught exception | " in log_message

def test_excepthook_not_set_when_start_not_invoked(self):
assert sys.excepthook == sys.__excepthook__
assert sys.excepthook != _handle_uncaught_exception
26 changes: 26 additions & 0 deletions tests/test_serverless/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import argparse
import os
import sys
from unittest import mock
from unittest.mock import patch, mock_open, Mock, MagicMock

Expand All @@ -12,6 +13,7 @@

import runpod
from runpod.serverless.modules.rp_logger import RunPodLogger
from runpod.serverless.modules.rp_scale import _handle_uncaught_exception
from runpod.serverless import _signal_handler

nest_asyncio.apply()
Expand Down Expand Up @@ -187,6 +189,9 @@ async def asyncSetUp(self):
"rp_args": {"rp_debugger": True, "rp_log_level": "DEBUG"},
}

async def asyncTearDown(self):
sys.excepthook = sys.__excepthook__

@patch("runpod.serverless.modules.rp_scale.AsyncClientSession")
@patch("runpod.serverless.modules.rp_scale.get_job")
@patch("runpod.serverless.modules.rp_job.run_job")
Expand Down Expand Up @@ -543,3 +548,24 @@ async def test_run_worker_with_sls_core(self):
os.environ.pop("RUNPOD_USE_CORE")

assert mock_main.called

@patch("runpod.serverless.signal.signal")
@patch("runpod.serverless.worker.rp_scale.JobScaler.run")
def test_start_sets_excepthook(self, _, __):
runpod.serverless.start({})
assert sys.excepthook == _handle_uncaught_exception

@patch("runpod.serverless.signal.signal")
@patch("runpod.serverless.rp_fastapi.WorkerAPI.start_uvicorn")
@patch("runpod.serverless._set_config_args")
def test_start_does_not_set_excepthook(self, mock_set_config_args, _, __):
mock_set_config_args.return_value = self.config
self.config.update({"rp_args": {
"rp_serve_api": True,
"rp_api_host": "localhost",
"rp_api_port": 8000,
"rp_api_concurrency": 1,
}})

runpod.serverless.start(self.config)
assert sys.excepthook != _handle_uncaught_exception

0 comments on commit 3f78233

Please sign in to comment.