diff --git a/cerulean_cloud/cloud_run_orchestrator/handler.py b/cerulean_cloud/cloud_run_orchestrator/handler.py index 4bc899bd..22d22ef8 100644 --- a/cerulean_cloud/cloud_run_orchestrator/handler.py +++ b/cerulean_cloud/cloud_run_orchestrator/handler.py @@ -26,7 +26,8 @@ from fastapi.middleware.cors import CORSMiddleware from global_land_mask import globe from rasterio.io import MemoryFile -from rasterio.merge import merge + +# from rasterio.merge import merge from shapely.geometry import shape from cerulean_cloud.auth import api_key_auth @@ -420,64 +421,64 @@ async def _orchestrate( base_tiles_inference[0].stack[0].dict().get("classes"), ) - if model.type == "UNET": - print("Loading all tiles into memory for merge!") - ds_base_tiles = [] - for base_tile_inference in base_tiles_inference: - ds_base_tiles.append( - *[ - create_dataset_from_inference_result(b) - for b in base_tile_inference.stack - ] - ) - - ds_offset_tiles = [] - for offset_tile_inference in offset_tiles_inference: - ds_offset_tiles.append( - *[ - create_dataset_from_inference_result(b) - for b in offset_tile_inference.stack - ] - ) - - print("Merging base tiles!") - base_tile_inference_file = MemoryFile() - ar, transform = merge(ds_base_tiles) - with base_tile_inference_file.open( - driver="GTiff", - height=ar.shape[1], - width=ar.shape[2], - count=ar.shape[0], - dtype=ar.dtype, - transform=transform, - crs="EPSG:4326", - ) as dst: - dst.write(ar) - - out_fc = get_fc_from_raster(base_tile_inference_file) - - print("Merging offset tiles!") - offset_tile_inference_file = MemoryFile() - ar, transform = merge(ds_offset_tiles) - with offset_tile_inference_file.open( - driver="GTiff", - height=ar.shape[1], - width=ar.shape[2], - count=ar.shape[0], - dtype=ar.dtype, - transform=transform, - crs="EPSG:4326", - ) as dst: - dst.write(ar) - - out_fc_offset = get_fc_from_raster(offset_tile_inference_file) - elif model.type == "MASKRCNN": - # out_fc = geojson.FeatureCollection( - # features=flatten_feature_list(base_tiles_inference) - # ) - # out_fc_offset = geojson.FeatureCollection( - # features=flatten_feature_list(offset_tiles_inference) - # ) + if model.type == "MASKRCNN": + out_fc = geojson.FeatureCollection( + features=flatten_feature_list(base_tiles_inference) + ) + out_fc_offset = geojson.FeatureCollection( + features=flatten_feature_list(offset_tiles_inference) + ) + elif model.type == "UNET": + # print("Loading all tiles into memory for merge!") + # ds_base_tiles = [] + # for base_tile_inference in base_tiles_inference: + # ds_base_tiles.append( + # *[ + # create_dataset_from_inference_result(b) + # for b in base_tile_inference.stack + # ] + # ) + + # ds_offset_tiles = [] + # for offset_tile_inference in offset_tiles_inference: + # ds_offset_tiles.append( + # *[ + # create_dataset_from_inference_result(b) + # for b in offset_tile_inference.stack + # ] + # ) + + # print("Merging base tiles!") + # base_tile_inference_file = MemoryFile() + # ar, transform = merge(ds_base_tiles) + # with base_tile_inference_file.open( + # driver="GTiff", + # height=ar.shape[1], + # width=ar.shape[2], + # count=ar.shape[0], + # dtype=ar.dtype, + # transform=transform, + # crs="EPSG:4326", + # ) as dst: + # dst.write(ar) + + # out_fc = get_fc_from_raster(base_tile_inference_file) + + # print("Merging offset tiles!") + # offset_tile_inference_file = MemoryFile() + # ar, transform = merge(ds_offset_tiles) + # with offset_tile_inference_file.open( + # driver="GTiff", + # height=ar.shape[1], + # width=ar.shape[2], + # count=ar.shape[0], + # dtype=ar.dtype, + # transform=transform, + # crs="EPSG:4326", + # ) as dst: + # dst.write(ar) + + # out_fc_offset = get_fc_from_raster(offset_tile_inference_file) raise NotImplementedError("UNET pathway isn't well defined") else: raise NotImplementedError(