From 695117f4b91c2728728187812fa1a672f050ed7f Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Tue, 29 Oct 2024 15:07:25 +0000
Subject: [PATCH 1/2] Added retry logic to deployment requests.

---
 aana/deployments/aana_deployment_handle.py | 93 +++++++++++++++++++---
 aana/tests/units/test_deployment_retry.py  | 68 ++++++++++++++++
 aana/utils/core.py                         |  7 +-
 3 files changed, 155 insertions(+), 13 deletions(-)
 create mode 100644 aana/tests/units/test_deployment_retry.py

diff --git a/aana/deployments/aana_deployment_handle.py b/aana/deployments/aana_deployment_handle.py
index 0551ca84..3affb0c9 100644
--- a/aana/deployments/aana_deployment_handle.py
+++ b/aana/deployments/aana_deployment_handle.py
@@ -1,5 +1,6 @@
 from ray import serve
 
+from aana.utils.core import sleep_exponential_backoff
 from aana.utils.typing import is_async_generator
 
 
@@ -17,17 +18,32 @@ class AanaDeploymentHandle:
         deployment_name (str): The name of the deployment.
     """
 
-    def __init__(self, deployment_name: str):
+    def __init__(
+        self,
+        deployment_name: str,
+        num_retries: int = 3,
+        retry_exceptions: bool | list[type[Exception]] = False,
+        retry_delay: float = 0.2,
+        retry_max_delay: float = 2.0,
+    ):
         """A handle to interact with a deployed Aana deployment.
 
         Args:
             deployment_name (str): The name of the deployment.
+            num_retries (int): The maximum number of retries for the method.
+            retry_exceptions (bool | list[type[Exception]]): Whether to retry on application-level errors or a list of exception classes to retry on.
+            retry_delay (float): The initial delay between retries.
+            retry_max_delay (float): The maximum delay between retries.
         """
         self.handle = serve.get_app_handle(deployment_name)
         self.deployment_name = deployment_name
         self.__methods = None
+        self.num_retries = num_retries
+        self.retry_exceptions = retry_exceptions
+        self.retry_delay = retry_delay
+        self.retry_max_delay = retry_max_delay
 
-    def __create_async_method(self, name: str):
+    def __create_async_method(self, name: str):  # noqa: C901
         """Create an method to interact with the deployment.
 
         Args:
@@ -40,16 +56,54 @@ def __create_async_method(self, name: str):
         if is_async_generator(return_type):
 
             async def method(*args, **kwargs):
-                async for item in self.handle.options(
-                    method_name=name, stream=True
-                ).remote(*args, **kwargs):
-                    yield item
+                retries = 0
+                while retries <= self.num_retries:
+                    try:
+                        async for item in self.handle.options(
+                            method_name=name, stream=True
+                        ).remote(*args, **kwargs):
+                            yield item
+                        break
+                    except Exception as e:
+                        is_retryable = self.retry_exceptions is True or (
+                            isinstance(self.retry_exceptions, list)
+                            and isinstance(
+                                getattr(e, "cause", e), tuple(self.retry_exceptions)
+                            )
+                        )
+                        if not is_retryable or retries >= self.num_retries:
+                            raise
+                        await sleep_exponential_backoff(
+                            initial_delay=self.retry_delay,
+                            max_delay=self.retry_max_delay,
+                            attempts=retries,
+                        )
+                        retries += 1
+
         else:
 
             async def method(*args, **kwargs):
-                return await self.handle.options(method_name=name).remote(
-                    *args, **kwargs
-                )
+                retries = 0
+                while retries <= self.num_retries:
+                    try:
+                        return await self.handle.options(method_name=name).remote(
+                            *args, **kwargs
+                        )
+                    except Exception as e:  # noqa: PERF203
+                        is_retryable = self.retry_exceptions is True or (
+                            isinstance(self.retry_exceptions, list)
+                            and isinstance(
+                                getattr(e, "cause", e), tuple(self.retry_exceptions)
+                            )
+                        )
+                        if not is_retryable or retries >= self.num_retries:
+                            raise
+                        await sleep_exponential_backoff(
+                            initial_delay=self.retry_delay,
+                            max_delay=self.retry_max_delay,
+                            attempts=retries,
+                        )
+                        retries += 1
 
         if "annotations" in self.__methods[name]:
             method.__annotations__ = self.__methods[name]["annotations"]
@@ -64,12 +118,29 @@ async def __load_methods(self):
             setattr(self, name, self.__create_async_method(name))
 
     @classmethod
-    async def create(cls, deployment_name: str):
+    async def create(
+        cls,
+        deployment_name: str,
+        num_retries: int = 3,
+        retry_exceptions: bool | list[type[Exception]] = False,
+        retry_delay: float = 0.2,
+        retry_max_delay: float = 2.0,
+    ):
         """Create a deployment handle.
 
         Args:
             deployment_name (str): The name of the deployment to interact with.
+            num_retries (int): The maximum number of retries for the method.
+            retry_exceptions (bool | list[type[Exception]]): Whether to retry on application-level errors or a list of exception classes to retry on.
+            retry_delay (float): The initial delay between retries.
+            retry_max_delay (float): The maximum delay between retries.
         """
-        handle = cls(deployment_name)
+        handle = cls(
+            deployment_name=deployment_name,
+            num_retries=num_retries,
+            retry_exceptions=retry_exceptions,
+            retry_delay=retry_delay,
+            retry_max_delay=retry_max_delay,
+        )
         await handle.__load_methods()
         return handle
diff --git a/aana/tests/units/test_deployment_retry.py b/aana/tests/units/test_deployment_retry.py
new file mode 100644
index 00000000..37b9f2c9
--- /dev/null
+++ b/aana/tests/units/test_deployment_retry.py
@@ -0,0 +1,68 @@
+# ruff: noqa: S101, S113
+
+import pytest
+from ray import serve
+
+from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
+from aana.deployments.base_deployment import BaseDeployment, exception_handler
+
+
+@serve.deployment(health_check_period_s=1, health_check_timeout_s=30)
+class Lowercase(BaseDeployment):
+    """Ray deployment that returns the lowercase version of a text."""
+
+    def __init__(self):
+        """Initialize the deployment."""
+        super().__init__()
+        self.num_requests = 0
+
+    @exception_handler
+    async def lower(self, text: str) -> dict:
+        """Lowercase the text.
+
+        Args:
+            text (str): The text to lowercase
+
+        Returns:
+            dict: The lowercase text
+        """
+        # Only every 3rd request should be successful
+        self.num_requests += 1
+        if self.num_requests % 3 != 0:
+            raise Exception("Random exception")  # noqa: TRY002, TRY003
+
+        return {"text": text.lower()}
+
+
+deployments = [
+    {
+        "name": "lowercase_deployment",
+        "instance": Lowercase,
+    }
+]
+
+
+@pytest.mark.asyncio
+async def test_deployment_retry(create_app):
+    """Test the Ray Serve app."""
+    create_app(deployments, [])
+
+    text = "Hello, World!"
+
+    # Get deployment handle without retries
+    handle = await AanaDeploymentHandle.create(
+        "lowercase_deployment", retry_exceptions=False
+    )
+
+    # test the lowercase deployment fails
+    with pytest.raises(Exception):  # noqa: B017
+        await handle.lower(text=text)
+
+    # Get deployment handle with retries
+    handle = await AanaDeploymentHandle.create(
+        "lowercase_deployment", retry_exceptions=True
+    )
+
+    # test the lowercase deployment works
+    response = await handle.lower(text=text)
+    assert response == {"text": text.lower()}
diff --git a/aana/utils/core.py b/aana/utils/core.py
index 87a12e38..baab95f3 100644
--- a/aana/utils/core.py
+++ b/aana/utils/core.py
@@ -1,6 +1,7 @@
 import asyncio
 import hashlib
 import importlib
+import random
 from pathlib import Path
 from typing import Any
 
@@ -84,7 +85,7 @@ def get_module_dir(module_name: str) -> Path:
 async def sleep_exponential_backoff(
     initial_delay: float, max_delay: float, attempts: int
 ):
-    """Sleep for an exponentially increasing amount of time.
+    """Sleep for an exponentially increasing amount of time with jitter.
 
     Args:
         initial_delay (float): The initial delay in seconds.
@@ -92,4 +93,6 @@ async def sleep_exponential_backoff(
         attempts (int): The number of attempts so far.
     """
     delay = min(initial_delay * (2**attempts), max_delay)
-    await asyncio.sleep(delay)
+    # Full jitter
+    delay_with_jitter = random.uniform(0, delay)  # noqa: S311
+    await asyncio.sleep(delay_with_jitter)

From c41c7c98060628fee37b1a74bd2cf069c282afb1 Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Wed, 30 Oct 2024 11:43:24 +0000
Subject: [PATCH 2/2] Refactor sleep_exponential_backoff to add optional jitter parameter

---
 aana/utils/core.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/aana/utils/core.py b/aana/utils/core.py
index baab95f3..31d89236 100644
--- a/aana/utils/core.py
+++ b/aana/utils/core.py
@@ -83,7 +83,7 @@ def get_module_dir(module_name: str) -> Path:
 
 
 async def sleep_exponential_backoff(
-    initial_delay: float, max_delay: float, attempts: int
+    initial_delay: float, max_delay: float, attempts: int, jitter: bool = True
 ):
     """Sleep for an exponentially increasing amount of time with jitter.
 
@@ -91,8 +91,9 @@ async def sleep_exponential_backoff(
         initial_delay (float): The initial delay in seconds.
         max_delay (float): The maximum delay in seconds.
         attempts (int): The number of attempts so far.
+        jitter (bool): Whether to add jitter to the delay. Default is True.
     """
     delay = min(initial_delay * (2**attempts), max_delay)
     # Full jitter
-    delay_with_jitter = random.uniform(0, delay)  # noqa: S311
+    delay_with_jitter = random.uniform(0, delay) if jitter else delay  # noqa: S311
     await asyncio.sleep(delay_with_jitter)