diff --git a/cerulean_cloud/cloud_run_orchestrator/handler.py b/cerulean_cloud/cloud_run_orchestrator/handler.py index 8cc2adc8..54b88826 100644 --- a/cerulean_cloud/cloud_run_orchestrator/handler.py +++ b/cerulean_cloud/cloud_run_orchestrator/handler.py @@ -10,6 +10,7 @@ """ import asyncio import os +import traceback import urllib.parse as urlparse from base64 import b64decode # , b64encode from datetime import datetime, timedelta @@ -245,6 +246,42 @@ def flatten_feature_list( return flat_list +async def perform_inference(tiles, inference_func, semaphore_value, description): + """ + Perform inference on a set of tiles asynchronously. + + Parameters: + - tiles (list): List of tiles to perform inference on. + - inference_func (function): Asynchronous function to call for inference. + - semaphore_value (int): Maximum number of concurrent tasks. + - description (str): Description of the inference task for logging. + + Returns: + - list: List of inference results. Exceptions, if any, are filtered out. + + Side Effects: + - Prints log messages and warnings to the console. + - Prints traceback of exceptions to the console. + """ + print(f"Inference on {description}!") + semaphore = asyncio.Semaphore(value=semaphore_value) + inferences = await asyncio.gather( + *[ + inference_func(tile=tile, rescale=(0, 255), semaphore=semaphore) + for tile in tiles + ], + return_exceptions=True, + ) + clean_inferences = [] + for res in inferences: + if isinstance(res, Exception): + print(f"WARNING: Exception occurred during {description} inference: {res}") + traceback.print_tb(res.__traceback__) + else: + clean_inferences.append(res) + return clean_inferences + + async def _orchestrate( payload, tiler, titiler_client, roda_sentinelhub_client, db_engine ): @@ -357,35 +394,21 @@ async def _orchestrate( inference_parms=inference_parms, ) - print("Inference on base tiles!") - base_tile_semaphore = asyncio.Semaphore(value=20) - base_tiles_inference = await asyncio.gather( - *[ - cloud_run_inference.get_base_tile_inference( - tile=base_tile, - rescale=(0, 255), - semaphore=base_tile_semaphore, - ) - for base_tile in base_tiles - ], - return_exceptions=True, + base_tiles_inference = await perform_inference( + base_tiles, + cloud_run_inference.get_base_tile_inference, + 20, + "base tiles", ) - print("Inference on offset tiles!") - offset_tile_semaphore = asyncio.Semaphore(value=20) - offset_tiles_inference = await asyncio.gather( - *[ - cloud_run_inference.get_offset_tile_inference( - bounds=offset_tile_bounds, - rescale=(0, 255), - semaphore=offset_tile_semaphore, - ) - for offset_tile_bounds in offset_tiles_bounds - ], - return_exceptions=True, + offset_tiles_inference = await perform_inference( + offset_tiles_bounds, + cloud_run_inference.get_offset_tile_inference, + 20, + "offset tiles", ) - if base_tiles_inference[0].stack[0].dict().get("classes"): + if model.type == "MASKRCNN": print("Loading all tiles into memory for merge!") ds_base_tiles = [] for base_tile_inference in base_tiles_inference: @@ -436,24 +459,16 @@ async def _orchestrate( dst.write(ar) out_fc_offset = get_fc_from_raster(offset_tile_inference_file) - + elif model.type == "UNET": + # XXX UNTESTED PATHWAY + out_fc = geojson.FeatureCollection( + features=flatten_feature_list(base_tiles_inference) + ) + out_fc_offset = geojson.FeatureCollection( + features=flatten_feature_list(offset_tiles_inference) + ) else: - try: - out_fc = geojson.FeatureCollection( - features=flatten_feature_list(base_tiles_inference) - ) - out_fc_offset = geojson.FeatureCollection( - features=flatten_feature_list(offset_tiles_inference) - ) - except AttributeError as e: - print(f"YYY error details: {e}") - print(f"YYY base_tiles_inference: {base_tiles_inference}") - print(f"YYY offset_tiles_inference: {offset_tiles_inference}") - for r in base_tiles_inference: - print(f"YYY [r for r in base_tiles_inference]: {r}") - for r in offset_tiles_inference: - print(f"YYY [r for r in offset_tiles_inference]: {r}") - raise e + raise Exception("Unrecognized model type") # XXXBUG ValueError: Cannot determine common CRS for concatenation inputs, got ['WGS 84 / UTM zone 28N', 'WGS 84 / UTM zone 29N']. Use `to_crs()` to transform geometries to the same CRS before merging." # Example: S1A_IW_GRDH_1SDV_20230727T185101_20230727T185126_049613_05F744_1E56