diff --git a/cerulean_cloud/cloud_run_orchestrator/clients.py b/cerulean_cloud/cloud_run_orchestrator/clients.py index 21087357..e245b95a 100644 --- a/cerulean_cloud/cloud_run_orchestrator/clients.py +++ b/cerulean_cloud/cloud_run_orchestrator/clients.py @@ -16,6 +16,7 @@ from rasterio.io import MemoryFile from rasterio.plot import reshape_as_raster from rio_tiler.io import COGReader +from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential from cerulean_cloud.cloud_run_offset_tiles.schema import ( InferenceInput, @@ -42,6 +43,13 @@ def img_array_to_b64_image(img_array: np.ndarray) -> str: return b64encode(img_bytes).decode("ascii") +def report_inference_retry(retry_state: RetryCallState): + """report retry""" + print( + f"Retrying {retry_state.fn.__name__} for Scene ID {retry_state.args[0].sceneid} due to {retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number}. Retrying in {retry_state.next_action.sleep} seconds." + ) + + class CloudRunInferenceClient: """Client for inference cloud run""" @@ -71,6 +79,13 @@ def __init__( self.inference_parms = inference_parms self.jitter = jitter + @retry( + stop=stop_after_attempt(3), + wait=wait_random_exponential(multiplier=1, max=10), + retry_if_exception_type=httpx.TransportError, + reraise=True, + before_sleep=report_inference_retry, + ) async def get_base_tile_inference( self, tile: morecantile.Tile, @@ -105,9 +120,11 @@ async def get_base_tile_inference( jit = random.uniform(0, self.jitter) print(f"Jittering by {jit} seconds") await asyncio.sleep(jit) + res = await http_client.post( self.url + "/predict", json=payload.dict(), timeout=None ) + if res.status_code == 200: return InferenceResultStack(**res.json()) else: @@ -116,6 +133,13 @@ async def get_base_tile_inference( f"XXX Received unexpected status code: {res.status_code} {res.content}" ) + @retry( + stop=stop_after_attempt(3), + wait=wait_random_exponential(multiplier=1, max=10), + retry_if_exception_type=httpx.TransportError, + reraise=True, + before_sleep=report_inference_retry, + ) async def get_offset_tile_inference( self, bounds: List[float], http_client: httpx.AsyncClient, rescale=(0, 255) ) -> InferenceResultStack: @@ -148,6 +172,7 @@ async def get_offset_tile_inference( jit = random.uniform(0, self.jitter) print(f"Jittering by {jit} seconds") await asyncio.sleep(jit) + res = await http_client.post( self.url + "/predict", json=payload.dict(), timeout=None ) diff --git a/cerulean_cloud/cloud_run_orchestrator/requirements.txt b/cerulean_cloud/cloud_run_orchestrator/requirements.txt index 3423ef02..2b0e9dba 100644 --- a/cerulean_cloud/cloud_run_orchestrator/requirements.txt +++ b/cerulean_cloud/cloud_run_orchestrator/requirements.txt @@ -24,4 +24,5 @@ geopandas pygeos networkx google-cloud-tasks -protobuf \ No newline at end of file +protobuf +tenacity==8.2.3