Skip to content

Commit

Permalink
Added retries to inference requests using tenacity
Browse files Browse the repository at this point in the history
  • Loading branch information
Edward Keeble committed Nov 23, 2023
1 parent 647e4fb commit bc1c419
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
25 changes: 25 additions & 0 deletions cerulean_cloud/cloud_run_orchestrator/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from rasterio.io import MemoryFile
from rasterio.plot import reshape_as_raster
from rio_tiler.io import COGReader
from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential

from cerulean_cloud.cloud_run_offset_tiles.schema import (
InferenceInput,
Expand All @@ -42,6 +43,13 @@ def img_array_to_b64_image(img_array: np.ndarray) -> str:
return b64encode(img_bytes).decode("ascii")


def report_inference_retry(retry_state: RetryCallState):
"""report retry"""
print(
f"Retrying {retry_state.fn.__name__} for Scene ID {retry_state.args[0].sceneid} due to {retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number}. Retrying in {retry_state.next_action.sleep} seconds."
)


class CloudRunInferenceClient:
"""Client for inference cloud run"""

Expand Down Expand Up @@ -71,6 +79,13 @@ def __init__(
self.inference_parms = inference_parms
self.jitter = jitter

@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=10),
retry_if_exception_type=httpx.TransportError,
reraise=True,
before_sleep=report_inference_retry,
)
async def get_base_tile_inference(
self,
tile: morecantile.Tile,
Expand Down Expand Up @@ -105,9 +120,11 @@ async def get_base_tile_inference(
jit = random.uniform(0, self.jitter)
print(f"Jittering by {jit} seconds")
await asyncio.sleep(jit)

res = await http_client.post(
self.url + "/predict", json=payload.dict(), timeout=None
)

if res.status_code == 200:
return InferenceResultStack(**res.json())
else:
Expand All @@ -116,6 +133,13 @@ async def get_base_tile_inference(
f"XXX Received unexpected status code: {res.status_code} {res.content}"
)

@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=10),
retry_if_exception_type=httpx.TransportError,
reraise=True,
before_sleep=report_inference_retry,
)
async def get_offset_tile_inference(
self, bounds: List[float], http_client: httpx.AsyncClient, rescale=(0, 255)
) -> InferenceResultStack:
Expand Down Expand Up @@ -148,6 +172,7 @@ async def get_offset_tile_inference(
jit = random.uniform(0, self.jitter)
print(f"Jittering by {jit} seconds")
await asyncio.sleep(jit)

res = await http_client.post(
self.url + "/predict", json=payload.dict(), timeout=None
)
Expand Down
3 changes: 2 additions & 1 deletion cerulean_cloud/cloud_run_orchestrator/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ geopandas
pygeos
networkx
google-cloud-tasks
protobuf
protobuf
tenacity==8.2.3

0 comments on commit bc1c419

Please sign in to comment.