Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add client retry support to .map #2571

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,3 +727,72 @@ async def async_chain(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T,
logger.exception(f"Error closing async generator: {e}")
if first_exception is not None:
raise first_exception


class TimedPriorityQueue(asyncio.PriorityQueue[tuple[float, Union[T, None]]]):
    """
    A priority queue that schedules items to be processed at specific timestamps.

    Entries are `(timestamp, item)` tuples ordered by timestamp; timestamps are
    wall-clock seconds as returned by `time.time()`. An internal
    `asyncio.Condition` lets consumers sleeping until a future timestamp be
    woken early when a new (possibly earlier) item arrives. A `None` item is
    used by `batch()` as an end-of-stream sentinel.
    """

    def __init__(self, maxsize: int = 0):
        # maxsize is forwarded to asyncio.PriorityQueue; 0 means unbounded.
        super().__init__(maxsize=maxsize)
        # Guards wake-ups: producers notify, consumers wait on this condition.
        self.condition = asyncio.Condition()

    async def put_with_timestamp(self, timestamp: float, item: Union[T, None]) -> None:
        """
        Add an item to the queue to be processed at a specific timestamp.

        :param timestamp: wall-clock time (time.time() seconds) at which the
            item becomes ready for consumption by `get_next()`.
        :param item: the payload; `None` signals end-of-stream to `batch()`.
        """
        # NOTE(review): `super().put()` is awaited while holding the condition
        # lock; with a bounded maxsize a full queue would block here while a
        # consumer needs the same lock to drain — presumably only used
        # unbounded. TODO confirm.
        async with self.condition:
            await super().put((timestamp, item))
            self.condition.notify_all()  # notify any waiting coroutines

    async def get_next(self) -> Union[T, None]:
        """
        Get the next item from the queue that is ready to be processed.

        Blocks until the earliest-scheduled item's timestamp has passed. If the
        head item is not yet due, it is pushed back and we sleep until either
        the due time or a producer's notification (which may have inserted an
        earlier item), then re-evaluate from the top.
        """
        while True:
            async with self.condition:
                while self.empty():
                    await self.condition.wait()

                # peek at the next item
                # (asyncio.PriorityQueue has no peek; pop and re-insert below
                # if it is not ready yet)
                timestamp, item = await super().get()
                now = time.time()

                if timestamp > now:
                    # not ready yet, calculate sleep time
                    sleep_time = timestamp - now
                    self.put_nowait((timestamp, item))  # put it back

                    # wait until either the timeout or a new item is added
                    # NOTE(review): wait_for cancels condition.wait() on
                    # timeout; Condition.wait re-acquires the lock before the
                    # cancellation propagates, so the lock state stays
                    # consistent here.
                    try:
                        await asyncio.wait_for(self.condition.wait(), timeout=sleep_time)
                    except asyncio.TimeoutError:
                        continue
                    # If notified before the timeout, loop and re-examine the
                    # head — a newly added item may now be due earlier.
                else:
                    return item

    async def batch(self, max_batch_size: int = 100, debounce_time: float = 0.015) -> AsyncGenerator[list[T], None]:
        """
        Read from the queue but return lists of items when queue is large.

        Treats a None value as the end of queue items.

        :param max_batch_size: flush the accumulated batch once it reaches
            this many items.
        :param debounce_time: seconds to wait for the next ready item before
            flushing whatever has accumulated so far.
        """
        batch: list[T] = []
        while True:
            try:
                item: Union[T, None] = await asyncio.wait_for(self.get_next(), timeout=debounce_time)

                if item is None:
                    # End-of-stream sentinel: flush any remainder, then stop.
                    if batch:
                        yield batch
                    return
                batch.append(item)

                if len(batch) >= max_batch_size:
                    yield batch
                    batch = []
            except asyncio.TimeoutError:
                # Debounce window elapsed with no new ready item: flush.
                if batch:
                    yield batch
                    batch = []
1 change: 1 addition & 0 deletions modal/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class _Setting(typing.NamedTuple):
"image_builder_version": _Setting(),
"strict_parameters": _Setting(False, transform=_to_boolean), # For internal/experimental use
"snapshot_debug": _Setting(False, transform=_to_boolean),
"client_retries": _Setting(False, transform=_to_boolean), # For internal testing.
}


Expand Down
93 changes: 87 additions & 6 deletions modal/functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright Modal Labs 2023
import asyncio
import dataclasses
import inspect
import textwrap
import time
Expand Down Expand Up @@ -26,6 +28,7 @@
from google.protobuf.message import Message
from grpclib import GRPCError, Status
from synchronicity.combined_types import MethodWithAio
from synchronicity.exceptions import UserCodeException

from modal._utils.async_utils import aclosing
from modal_proto import api_pb2
Expand Down Expand Up @@ -64,6 +67,7 @@
from .config import config
from .exception import (
ExecutionError,
FunctionTimeoutError,
InvalidError,
NotFoundError,
OutputExpiredError,
Expand All @@ -86,7 +90,7 @@
_SynchronizedQueue,
)
from .proxy import _Proxy
from .retries import Retries
from .retries import Retries, RetryManager
from .schedule import Schedule
from .scheduler_placement import SchedulerPlacement
from .secret import _Secret
Expand All @@ -98,15 +102,32 @@
import modal.partial_function


@dataclasses.dataclass
class _RetryContext:
    # State captured at input-creation time that is needed to resubmit the
    # same input for a client-side retry (see _Invocation._retry_input).
    function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"  # SYNC vs legacy; retries apply to SYNC only
    retry_policy: api_pb2.FunctionRetryPolicy  # server-provided policy driving RetryManager delays
    function_call_jwt: str  # authorizes FunctionRetryInputs for this function call
    input_jwt: str  # authorizes retrying this specific input
    input_id: str
    item: api_pb2.FunctionPutInputsItem  # original input payload, resubmitted verbatim on retry


class _Invocation:
"""Internal client representation of a single-input call to a Modal Function or Generator"""

stub: ModalClientModal

def __init__(self, stub: ModalClientModal, function_call_id: str, client: _Client):
    def __init__(
        self,
        stub: ModalClientModal,
        function_call_id: str,
        client: _Client,
        retry_context: Optional[_RetryContext] = None,
    ):
        """Wrap a single in-flight function call.

        :param stub: gRPC stub used for output polling and retries.
        :param function_call_id: server-assigned id of this function call.
        :param client: Modal client; also needed when deserializing results.
        :param retry_context: when present (and the invocation type is SYNC),
            enables client-side retries of the input on user errors/timeouts.
        """
        self.stub = stub
        self.client = client  # Used by the deserializer.
        self.function_call_id = function_call_id  # TODO: remove and use only input_id
        self._retry_context = retry_context

@staticmethod
async def create(
Expand All @@ -132,7 +153,17 @@ async def create(
function_call_id = response.function_call_id

if response.pipelined_inputs:
return _Invocation(client.stub, function_call_id, client)
assert len(response.pipelined_inputs) == 1
input = response.pipelined_inputs[0]
retry_context = _RetryContext(
function_call_invocation_type=function_call_invocation_type,
retry_policy=response.retry_policy,
function_call_jwt=response.function_call_jwt,
input_jwt=input.input_jwt,
input_id=input.input_id,
item=item,
)
return _Invocation(client.stub, function_call_id, client, retry_context)

request_put = api_pb2.FunctionPutInputsRequest(
function_id=function_id, inputs=[item], function_call_id=function_call_id
Expand All @@ -144,7 +175,16 @@ async def create(
processed_inputs = inputs_response.inputs
if not processed_inputs:
raise Exception("Could not create function call - the input queue seems to be full")
return _Invocation(client.stub, function_call_id, client)
input = inputs_response.inputs[0]
retry_context = _RetryContext(
function_call_invocation_type=function_call_invocation_type,
retry_policy=response.retry_policy,
function_call_jwt=response.function_call_jwt,
input_jwt=input.input_jwt,
input_id=input.input_id,
item=item,
)
return _Invocation(client.stub, function_call_id, client, retry_context)

async def pop_function_call_outputs(
self, timeout: Optional[float], clear_on_success: bool
Expand Down Expand Up @@ -180,13 +220,49 @@ async def pop_function_call_outputs(
# return the last response to check for state of num_unfinished_inputs
return response

async def run_function(self) -> Any:
async def _retry_input(self) -> None:
ctx = self._retry_context
if not ctx:
raise ValueError("Cannot retry input when _retry_context is empty.")

item = api_pb2.FunctionRetryInputsItem(input_jwt=ctx.input_jwt, input=ctx.item.input)
request = api_pb2.FunctionRetryInputsRequest(function_call_jwt=ctx.function_call_jwt, inputs=[item])
await retry_transient_errors(
self.client.stub.FunctionRetryInputs,
request,
)

async def _get_single_output(self) -> Any:
# waits indefinitely for a single result for the function, and clear the outputs buffer after
item: api_pb2.FunctionGetOutputsItem = (
await self.pop_function_call_outputs(timeout=None, clear_on_success=True)
).outputs[0]
return await _process_result(item.result, item.data_format, self.stub, self.client)

async def run_function(self) -> Any:
# Use retry logic only if retry policy is specified and
ctx = self._retry_context
if (
not ctx
or not ctx.retry_policy
or ctx.retry_policy.retries == 0
or ctx.function_call_invocation_type != api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
):
return await self._get_single_output()

# User errors including timeouts are managed by the user specified retry policy.
user_retry_manager = RetryManager(ctx.retry_policy)

while True:
try:
return await self._get_single_output()
except (UserCodeException, FunctionTimeoutError) as exc:
delay_ms = user_retry_manager.get_delay_ms()
if delay_ms is None:
raise exc
await asyncio.sleep(delay_ms / 1000)
await self._retry_input()

async def poll_function(self, timeout: Optional[float] = None):
"""Waits up to timeout for a result from a function.

Expand Down Expand Up @@ -1323,13 +1399,18 @@ async def _map(
yield item

async def _call_function(self, args, kwargs) -> ReturnType:
if config.get("client_retries"):
function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
else:
function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY
invocation = await _Invocation.create(
self,
args,
kwargs,
client=self._client,
function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY,
function_call_invocation_type=function_call_invocation_type,
)

return await invocation.run_function()

async def _call_function_nowait(
Expand Down
Loading
Loading