Skip to content

Commit

Permalink
Try to improve error handling during inference
Browse files Browse the repository at this point in the history
  • Loading branch information
jonaraphael committed Oct 30, 2023
1 parent 7ed66cd commit 9833ea8
Showing 1 changed file with 57 additions and 42 deletions.
99 changes: 57 additions & 42 deletions cerulean_cloud/cloud_run_orchestrator/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""
import asyncio
import os
import traceback
import urllib.parse as urlparse
from base64 import b64decode # , b64encode
from datetime import datetime, timedelta
Expand Down Expand Up @@ -245,6 +246,42 @@ def flatten_feature_list(
return flat_list


async def perform_inference(tiles, inference_func, semaphore_value, description):
"""
Perform inference on a set of tiles asynchronously.
Parameters:
- tiles (list): List of tiles to perform inference on.
- inference_func (function): Asynchronous function to call for inference.
- semaphore_value (int): Maximum number of concurrent tasks.
- description (str): Description of the inference task for logging.
Returns:
- list: List of inference results. Exceptions, if any, are filtered out.
Side Effects:
- Prints log messages and warnings to the console.
- Prints traceback of exceptions to the console.
"""
print(f"Inference on {description}!")
semaphore = asyncio.Semaphore(value=semaphore_value)
inferences = await asyncio.gather(
*[
inference_func(tile=tile, rescale=(0, 255), semaphore=semaphore)
for tile in tiles
],
return_exceptions=True,
)
clean_inferences = []
for res in inferences:
if isinstance(res, Exception):
print(f"WARNING: Exception occurred during {description} inference: {res}")
traceback.print_tb(res.__traceback__)
else:
clean_inferences.append(res)
return clean_inferences


async def _orchestrate(
payload, tiler, titiler_client, roda_sentinelhub_client, db_engine
):
Expand Down Expand Up @@ -357,35 +394,21 @@ async def _orchestrate(
inference_parms=inference_parms,
)

print("Inference on base tiles!")
base_tile_semaphore = asyncio.Semaphore(value=20)
base_tiles_inference = await asyncio.gather(
*[
cloud_run_inference.get_base_tile_inference(
tile=base_tile,
rescale=(0, 255),
semaphore=base_tile_semaphore,
)
for base_tile in base_tiles
],
return_exceptions=True,
base_tiles_inference = await perform_inference(
base_tiles,
cloud_run_inference.get_base_tile_inference,
20,
"base tiles",
)

print("Inference on offset tiles!")
offset_tile_semaphore = asyncio.Semaphore(value=20)
offset_tiles_inference = await asyncio.gather(
*[
cloud_run_inference.get_offset_tile_inference(
bounds=offset_tile_bounds,
rescale=(0, 255),
semaphore=offset_tile_semaphore,
)
for offset_tile_bounds in offset_tiles_bounds
],
return_exceptions=True,
offset_tiles_inference = await perform_inference(
offset_tiles_bounds,
cloud_run_inference.get_offset_tile_inference,
20,
"offset tiles",
)

if base_tiles_inference[0].stack[0].dict().get("classes"):
if model.type == "MASKRCNN":
print("Loading all tiles into memory for merge!")
ds_base_tiles = []
for base_tile_inference in base_tiles_inference:
Expand Down Expand Up @@ -436,24 +459,16 @@ async def _orchestrate(
dst.write(ar)

out_fc_offset = get_fc_from_raster(offset_tile_inference_file)

elif model.type == "UNET":
# XXX UNTESTED PATHWAY
out_fc = geojson.FeatureCollection(
features=flatten_feature_list(base_tiles_inference)
)
out_fc_offset = geojson.FeatureCollection(
features=flatten_feature_list(offset_tiles_inference)
)
else:
try:
out_fc = geojson.FeatureCollection(
features=flatten_feature_list(base_tiles_inference)
)
out_fc_offset = geojson.FeatureCollection(
features=flatten_feature_list(offset_tiles_inference)
)
except AttributeError as e:
print(f"YYY error details: {e}")
print(f"YYY base_tiles_inference: {base_tiles_inference}")
print(f"YYY offset_tiles_inference: {offset_tiles_inference}")
for r in base_tiles_inference:
print(f"YYY [r for r in base_tiles_inference]: {r}")
for r in offset_tiles_inference:
print(f"YYY [r for r in offset_tiles_inference]: {r}")
raise e
raise Exception("Unrecognized model type")

# XXXBUG ValueError: Cannot determine common CRS for concatenation inputs, got ['WGS 84 / UTM zone 28N', 'WGS 84 / UTM zone 29N']. Use `to_crs()` to transform geometries to the same CRS before merging."
# Example: S1A_IW_GRDH_1SDV_20230727T185101_20230727T185126_049613_05F744_1E56
Expand Down

0 comments on commit 9833ea8

Please sign in to comment.