Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retry Mechanism for Deployment Requests #193

Merged
merged 2 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 82 additions & 11 deletions aana/deployments/aana_deployment_handle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ray import serve

from aana.utils.core import sleep_exponential_backoff
from aana.utils.typing import is_async_generator


Expand All @@ -17,17 +18,32 @@ class AanaDeploymentHandle:
deployment_name (str): The name of the deployment.
"""

def __init__(
    self,
    deployment_name: str,
    num_retries: int = 3,
    retry_exceptions: bool | list[type[Exception]] = False,
    retry_delay: float = 0.2,
    retry_max_delay: float = 2.0,
):
    """A handle to interact with a deployed Aana deployment.

    Args:
        deployment_name (str): The name of the deployment.
        num_retries (int): The maximum number of retries for the method.
        retry_exceptions (bool | list[type[Exception]]): Whether to retry on
            application-level errors (True/False) or a list of exception
            classes to retry on.
        retry_delay (float): The initial delay between retries.
        retry_max_delay (float): The maximum delay between retries.
    """
    self.handle = serve.get_app_handle(deployment_name)
    self.deployment_name = deployment_name
    # Method metadata is not known at construction time; it is filled in
    # later (presumably by __load_methods -- confirm against the full file).
    self.__methods = None
    # Retry policy consulted by the generated async methods.
    self.num_retries = num_retries
    self.retry_exceptions = retry_exceptions
    self.retry_delay = retry_delay
    self.retry_max_delay = retry_max_delay

def __create_async_method(self, name: str):
def __create_async_method(self, name: str): # noqa: C901
"""Create an method to interact with the deployment.

Args:
Expand All @@ -40,16 +56,54 @@ def __create_async_method(self, name: str):
if is_async_generator(return_type):

async def method(*args, **kwargs):
    # Streaming variant: forward every item produced by the remote
    # generator and retry the whole call on retryable failures.
    # NOTE(review): a retry restarts the remote stream from scratch, so
    # items already yielded before the failure are yielded again.
    retries = 0
    while retries <= self.num_retries:
        try:
            async for item in self.handle.options(
                method_name=name, stream=True
            ).remote(*args, **kwargs):
                yield item
            break
        except Exception as e:
            # The application error may be wrapped; the wrapper exposes
            # the original exception as ``cause`` (presumably Ray's task
            # error -- confirm). Fall back to ``e`` itself when there is
            # no ``cause`` so the handler never raises AttributeError.
            cause = getattr(e, "cause", None)
            is_retryable = self.retry_exceptions is True
            if not is_retryable and isinstance(
                self.retry_exceptions, (list, tuple)
            ):
                retry_types = tuple(self.retry_exceptions)
                # Compare exception *instances* against the configured
                # classes (checking ``cause.__class__`` never matches).
                is_retryable = isinstance(e, retry_types) or isinstance(
                    cause, retry_types
                )
            if not is_retryable or retries >= self.num_retries:
                raise
            await sleep_exponential_backoff(
                initial_delay=self.retry_delay,
                max_delay=self.retry_max_delay,
                attempts=retries,
            )
            retries += 1

else:

async def method(*args, **kwargs):
    # Unary variant: await the remote call and return its result,
    # retrying with exponential backoff on retryable failures.
    retries = 0
    while retries <= self.num_retries:
        try:
            return await self.handle.options(method_name=name).remote(
                *args, **kwargs
            )
        except Exception as e:  # noqa: PERF203
            # The application error may be wrapped; the wrapper exposes
            # the original exception as ``cause`` (presumably Ray's task
            # error -- confirm). Fall back to ``e`` itself when there is
            # no ``cause`` so the handler never raises AttributeError.
            cause = getattr(e, "cause", None)
            is_retryable = self.retry_exceptions is True
            if not is_retryable and isinstance(
                self.retry_exceptions, (list, tuple)
            ):
                retry_types = tuple(self.retry_exceptions)
                # Compare exception *instances* against the configured
                # classes (checking ``cause.__class__`` never matches).
                is_retryable = isinstance(e, retry_types) or isinstance(
                    cause, retry_types
                )
            if not is_retryable or retries >= self.num_retries:
                raise
            await sleep_exponential_backoff(
                initial_delay=self.retry_delay,
                max_delay=self.retry_max_delay,
                attempts=retries,
            )
            retries += 1

if "annotations" in self.__methods[name]:
method.__annotations__ = self.__methods[name]["annotations"]
Expand All @@ -64,12 +118,29 @@ async def __load_methods(self):
setattr(self, name, self.__create_async_method(name))

@classmethod
async def create(
    cls,
    deployment_name: str,
    num_retries: int = 3,
    retry_exceptions: bool | list[type[Exception]] = False,
    retry_delay: float = 0.2,
    retry_max_delay: float = 2.0,
):
    """Create a deployment handle.

    Builds the handle with the given retry policy and then awaits the
    loading of the deployment's methods before returning it.

    Args:
        deployment_name (str): The name of the deployment to interact with.
        num_retries (int): The maximum number of retries for the method.
        retry_exceptions (bool | list[type[Exception]]): Whether to retry on
            application-level errors (True/False) or a list of exception
            classes to retry on.
        retry_delay (float): The initial delay between retries.
        retry_max_delay (float): The maximum delay between retries.

    Returns:
        The handle instance with its methods loaded.
    """
    handle = cls(
        deployment_name=deployment_name,
        num_retries=num_retries,
        retry_exceptions=retry_exceptions,
        retry_delay=retry_delay,
        retry_max_delay=retry_max_delay,
    )
    await handle.__load_methods()
    return handle
68 changes: 68 additions & 0 deletions aana/tests/units/test_deployment_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# ruff: noqa: S101, S113

import pytest
from ray import serve

from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.deployments.base_deployment import BaseDeployment, exception_handler


@serve.deployment(health_check_period_s=1, health_check_timeout_s=30)
class Lowercase(BaseDeployment):
    """Ray deployment that lowercases text but fails two of every three calls."""

    def __init__(self):
        """Set up the deployment with a zeroed request counter."""
        super().__init__()
        self.num_requests = 0

    @exception_handler
    async def lower(self, text: str) -> dict:
        """Return the lowercased text, succeeding only on every third call.

        Args:
            text (str): The text to lowercase

        Returns:
            dict: The lowercase text under the "text" key

        Raises:
            Exception: On every call whose ordinal is not a multiple of three.
        """
        # Count the request first so failures also advance the counter.
        self.num_requests += 1
        is_third_request = self.num_requests % 3 == 0
        if not is_third_request:
            raise Exception("Random exception")  # noqa: TRY002, TRY003

        lowered = text.lower()
        return {"text": lowered}


# Deployment specs handed to the create_app fixture in the test below.
deployments = [{"name": "lowercase_deployment", "instance": Lowercase}]


@pytest.mark.asyncio
async def test_deployment_retry(create_app):
    """Check that a handle fails without retries and succeeds with them."""
    create_app(deployments, [])

    text = "Hello, World!"
    expected = {"text": text.lower()}

    # Without retries the first (failing) request must surface its error.
    no_retry_handle = await AanaDeploymentHandle.create(
        "lowercase_deployment", retry_exceptions=False
    )
    with pytest.raises(Exception):  # noqa: B017
        await no_retry_handle.lower(text=text)

    # With retries enabled the handle keeps calling until a request succeeds.
    retry_handle = await AanaDeploymentHandle.create(
        "lowercase_deployment", retry_exceptions=True
    )
    response = await retry_handle.lower(text=text)
    assert response == expected
7 changes: 5 additions & 2 deletions aana/utils/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import hashlib
import importlib
import random
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -84,12 +85,14 @@ def get_module_dir(module_name: str) -> Path:
async def sleep_exponential_backoff(
    initial_delay: float, max_delay: float, attempts: int
):
    """Sleep for an exponentially increasing amount of time with jitter.

    The upper bound of the delay is ``initial_delay * 2**attempts`` capped at
    ``max_delay``; the actual sleep is drawn uniformly from ``[0, bound]``.

    Args:
        initial_delay (float): The initial delay in seconds.
        max_delay (float): The maximum delay in seconds.
        attempts (int): The number of attempts so far.
    """
    try:
        delay = min(initial_delay * (2**attempts), max_delay)
    except OverflowError:
        # 2**attempts no longer fits in a float for very large attempt
        # counts; the cap is the only meaningful bound in that case.
        delay = max_delay
    # Full jitter
    delay_with_jitter = random.uniform(0, delay)  # noqa: S311
    await asyncio.sleep(delay_with_jitter)
Loading