Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] Add per-request replica failure cache in LB to reduce redundant retries #3916

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion sky/serve/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# time the load balancer syncs with controller, it will update all available
# replica ips for each service, also send the number of requests in last query
# interval.
LB_CONTROLLER_SYNC_INTERVAL_SECONDS = 20
LB_CONTROLLER_SYNC_INTERVAL_SECONDS = 10

# The maximum retry times for load balancer for each request. After changing to
# proxy implementation, we do retry for failed requests.
Expand Down
7 changes: 5 additions & 2 deletions sky/serve/load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
import logging
import threading
from typing import Dict, Union
from typing import Dict, List, Union

import aiohttp
import fastapi
Expand Down Expand Up @@ -160,11 +160,12 @@ async def _proxy_with_retries(
# SkyServe supports serving on Spot Instances. To avoid preemptions
# during request handling, we add a retry here.
retry_cnt = 0
failed_replica_urls: List[str] = []
andylizf marked this conversation as resolved.
Show resolved Hide resolved
while True:
retry_cnt += 1
with self._client_pool_lock:
ready_replica_url = self._load_balancing_policy.select_replica(
request)
request, failed_replica_urls)
andylizf marked this conversation as resolved.
Show resolved Hide resolved
if ready_replica_url is None:
response_or_exception = fastapi.HTTPException(
# 503 means that the server is currently
Expand All @@ -184,6 +185,8 @@ async def _proxy_with_retries(
# 499 means a client terminates the connection
# before the server is able to respond.
return fastapi.responses.Response(status_code=499)
assert ready_replica_url is not None
failed_replica_urls.append(ready_replica_url)
# TODO(tian): Fail fast for errors like 404 not found.
if retry_cnt == constants.LB_MAX_RETRY:
if isinstance(response_or_exception, fastapi.HTTPException):
Expand Down
22 changes: 15 additions & 7 deletions sky/serve/load_balancing_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ def __init__(self) -> None:
def set_ready_replicas(self, ready_replicas: List[str]) -> None:
raise NotImplementedError

def select_replica(self, request: 'fastapi.Request') -> Optional[str]:
replica = self._select_replica(request)
def select_replica(self, request: 'fastapi.Request',
disabled_urls: List[str]) -> Optional[str]:
replica = self._select_replica(request, disabled_urls)
if replica is not None:
logger.info(f'Selected replica {replica} '
f'for request {_request_repr(request)}')
Expand All @@ -40,7 +41,8 @@ def select_replica(self, request: 'fastapi.Request') -> Optional[str]:

# TODO(tian): We should have an abstract class for Request to
# compatible with all frameworks.
def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
def _select_replica(self, request: 'fastapi.Request',
disabled_urls: List[str]) -> Optional[str]:
raise NotImplementedError


Expand All @@ -61,10 +63,16 @@ def set_ready_replicas(self, ready_replicas: List[str]) -> None:
self.ready_replicas = ready_replicas
self.index = 0

def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
def _select_replica(self, request: 'fastapi.Request',
disabled_urls: List[str]) -> Optional[str]:
del request # Unused.
if not self.ready_replicas:
return None
ready_replica_url = self.ready_replicas[self.index]
self.index = (self.index + 1) % len(self.ready_replicas)
return ready_replica_url
check_disable = True
if all(url in disabled_urls for url in self.ready_replicas):
check_disable = False
while True:
ready_replica_url = self.ready_replicas[self.index]
self.index = (self.index + 1) % len(self.ready_replicas)
if not check_disable or ready_replica_url not in disabled_urls:
return ready_replica_url
Loading