Skip to content

Commit

Permalink
Merge pull request #103 from SkyTruth/feature-concurrency-planning
Browse files Browse the repository at this point in the history
Feature concurrency planning
  • Loading branch information
jonaraphael authored Nov 7, 2023
2 parents bd2c3ad + 97f4f51 commit 9490f6c
Show file tree
Hide file tree
Showing 15 changed files with 435 additions and 97 deletions.
2 changes: 1 addition & 1 deletion cerulean_cloud/cloud_function_scene_relevancy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def handle_notification(request_json, ocean_poly):
"""handle notification"""
filtered_scenes = []
for r in request_json.get("Records"):
sns = request_json["Records"][0]["Sns"]
sns = r["Sns"]
msg = json.loads(sns["Message"])
scene_poly = sh.polygon.Polygon(msg["footprint"]["coordinates"][0][0])

Expand Down
26 changes: 12 additions & 14 deletions cerulean_cloud/cloud_run_offset_tiles/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Ref: https://github.com/python-engineer/ml-deployment/tree/main/google-cloud-run
"""
from base64 import b64decode, b64encode
from functools import lru_cache
from typing import Dict, List, Tuple, Union

import geojson
Expand Down Expand Up @@ -33,17 +32,17 @@
app.add_middleware(CORSMiddleware, allow_origins=["*"])
add_timing_middleware(app, prefix="app")

MODEL = None

def load_tracing_model(savepath):
"""load tracing model. a tracing model must be applied to the same batch dimensions the model was trained on."""
tracing_model = torch.jit.load(savepath, map_location="cpu")
return tracing_model


@lru_cache()
def get_model():
"""load model"""
return load_tracing_model("cerulean_cloud/cloud_run_offset_tiles/model/model.pt")
def load_model():
    """Return the inference model, loading it from disk on first use.

    The model is read with ``torch.jit.load`` onto the CPU and stored in the
    module-level ``MODEL`` global; once that global is set, later calls
    return the cached object without touching the file again.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    MODEL = torch.jit.load(
        "cerulean_cloud/cloud_run_offset_tiles/model/model.pt",
        map_location="cpu",
    )
    return MODEL


def logits_to_classes(out_batch_logits):
Expand Down Expand Up @@ -164,11 +163,10 @@ def _predict(
tags=["Run inference"],
response_model=InferenceResultStack,
)
def predict(
request: Request, payload: PredictPayload, model=Depends(get_model)
) -> Dict:
"""predict"""
def predict(request: Request, payload: PredictPayload) -> Dict:
"""Run prediction using the loaded model."""
record_timing(request, note="Started")
model = load_model()
results = _predict(payload.inf_stack, model, payload.inf_parms)
record_timing(request, note="Finished inference")

Expand Down
95 changes: 48 additions & 47 deletions cerulean_cloud/cloud_run_orchestrator/clients.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Clients for other cloud run functions"""
import asyncio
import json
import os
import zipfile
Expand Down Expand Up @@ -70,76 +69,78 @@ def __init__(
self.inference_parms = inference_parms

async def get_base_tile_inference(
self, tile: morecantile.Tile, semaphore: asyncio.Semaphore, rescale=(0, 255)
self, tile: morecantile.Tile, rescale=(0, 255)
) -> InferenceResultStack:
"""fetch inference for base tiles"""
async with semaphore:
img_array = await self.titiler_client.get_base_tile(
sceneid=self.sceneid, tile=tile, scale=self.scale, rescale=rescale
)
img_array = await self.titiler_client.get_base_tile(
sceneid=self.sceneid,
tile=tile,
scale=self.scale,
rescale=rescale,
)

img_array = reshape_as_raster(img_array)
img_array = reshape_as_raster(img_array)

bounds = list(TMS.bounds(tile))
bounds = list(TMS.bounds(tile))

with self.aux_datasets.open() as src:
window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
height, width = img_array.shape[1:]
aux_ds = src.read(window=window, out_shape=(height, width))
with self.aux_datasets.open() as src:
window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
height, width = img_array.shape[1:]
aux_ds = src.read(window=window, out_shape=(height, width))

img_array = np.concatenate([img_array[0:1, :, :], aux_ds], axis=0)
img_array = np.concatenate([img_array[0:1, :, :], aux_ds], axis=0)

encoded = img_array_to_b64_image(img_array)
encoded = img_array_to_b64_image(img_array)

inf_stack = [InferenceInput(image=encoded, bounds=TMS.bounds(tile))]
payload = PredictPayload(
inf_stack=inf_stack, inf_parms=self.inference_parms
)
res = await self.client.post(
self.url + "/predict", json=payload.dict(), timeout=None
)
inf_stack = [InferenceInput(image=encoded, bounds=TMS.bounds(tile))]
payload = PredictPayload(inf_stack=inf_stack, inf_parms=self.inference_parms)
res = await self.client.post(
self.url + "/predict", json=payload.dict(), timeout=None
)
if res.status_code == 200:
return InferenceResultStack(**res.json())
else:
print(f"XXX Received unexpected status code: {res.status_code}")
print(f"XXX HTTP content: {res.content}")
print(f"XXX Issue was found in: {self.sceneid}")
return InferenceResultStack(stack=[])
raise Exception(
f"XXX Received unexpected status code: {res.status_code} {res.content}"
)

async def get_offset_tile_inference(
self, bounds: List[float], semaphore: asyncio.Semaphore, rescale=(0, 255)
self, bounds: List[float], rescale=(0, 255)
) -> InferenceResultStack:
"""fetch inference for offset tiles"""
async with semaphore:
hw = self.scale * 256
img_array = await self.titiler_client.get_offset_tile(
self.sceneid, *bounds, width=hw, height=hw, rescale=rescale
)
img_array = reshape_as_raster(img_array)
with self.aux_datasets.open() as src:
window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
height, width = img_array.shape[1:]
aux_ds = src.read(window=window, out_shape=(height, width))
hw = self.scale * 256
img_array = await self.titiler_client.get_offset_tile(
self.sceneid,
*bounds,
width=hw,
height=hw,
scale=self.scale,
rescale=rescale,
)
img_array = reshape_as_raster(img_array)
with self.aux_datasets.open() as src:
window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
height, width = img_array.shape[1:]
aux_ds = src.read(window=window, out_shape=(height, width))

img_array = np.concatenate([img_array[0:1, :, :], aux_ds], axis=0)
img_array = np.concatenate([img_array[0:1, :, :], aux_ds], axis=0)

encoded = img_array_to_b64_image(img_array)
encoded = img_array_to_b64_image(img_array)

inf_stack = [InferenceInput(image=encoded, bounds=bounds)]
inf_stack = [InferenceInput(image=encoded, bounds=bounds)]

payload = PredictPayload(
inf_stack=inf_stack, inf_parms=self.inference_parms
)
res = await self.client.post(
self.url + "/predict", json=payload.dict(), timeout=None
)
payload = PredictPayload(inf_stack=inf_stack, inf_parms=self.inference_parms)
res = await self.client.post(
self.url + "/predict", json=payload.dict(), timeout=None
)
if res.status_code == 200:
return InferenceResultStack(**res.json())
else:
print(f"XXX Received unexpected status code: {res.status_code}")
print(f"XXX HTTP content: {res.content}")
print(f"XXX Issue was found in: {self.sceneid}")
return InferenceResultStack(stack=[])
raise Exception(
f"XXX Received unexpected status code: {res.status_code} {res.content}"
)


def get_scene_date_month(scene_id: str) -> str:
Expand Down
40 changes: 14 additions & 26 deletions cerulean_cloud/cloud_run_orchestrator/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""
import asyncio
import os
import traceback
import urllib.parse as urlparse
from base64 import b64decode # , b64encode
from datetime import datetime, timedelta
Expand Down Expand Up @@ -245,14 +244,13 @@ def flatten_feature_list(
return flat_list


async def perform_inference(tiles, inference_func, semaphore_value, description):
async def perform_inference(tiles, inference_func, description):
"""
Perform inference on a set of tiles asynchronously.
Parameters:
- tiles or bounds (list): List of tiles to perform inference on. (depends on inference_func)
- inference_func (function): Asynchronous function to call for inference.
- semaphore_value (int): Maximum number of concurrent tasks.
- description (str): Description of the inference task for logging.
Returns:
Expand All @@ -263,22 +261,11 @@ async def perform_inference(tiles, inference_func, semaphore_value, description)
- Prints traceback of exceptions to the console.
"""
print(f"Inference on {description}!")
semaphore = asyncio.Semaphore(value=semaphore_value)
inferences = await asyncio.gather(
*[
inference_func(tile, rescale=(0, 255), semaphore=semaphore)
for tile in tiles
],
return_exceptions=True,
*[inference_func(tile, rescale=(0, 255)) for tile in tiles],
return_exceptions=False, # This raises exceptions
)
clean_inferences = []
for res in inferences:
if isinstance(res, Exception):
print(f"WARNING: Exception occurred during {description} inference: {res}")
traceback.print_tb(res.__traceback__)
else:
clean_inferences.append(res)
return clean_inferences
return inferences


async def _orchestrate(
Expand Down Expand Up @@ -412,36 +399,35 @@ async def _orchestrate(
base_tiles_inference = await perform_inference(
base_tiles,
cloud_run_inference.get_base_tile_inference,
20,
"base tiles",
)

offset_tiles_inference = await perform_inference(
offset_tiles_bounds,
cloud_run_inference.get_offset_tile_inference,
20,
"offset tiles",
)

offset_2_tiles_inference = await perform_inference(
offset_2_tiles_bounds,
cloud_run_inference.get_offset_tile_inference,
20,
"offset2 tiles",
)
del base_tiles
del offset_tiles_bounds

if model.type == "MASKRCNN":
out_fc = geojson.FeatureCollection(
features=flatten_feature_list(base_tiles_inference)
)

out_fc_offset = geojson.FeatureCollection(
features=flatten_feature_list(offset_tiles_inference)
)

out_fc_offset_2 = geojson.FeatureCollection(
features=flatten_feature_list(offset_2_tiles_inference)
)
del base_tiles_inference
del offset_tiles_inference
elif model.type == "UNET":
# print("Loading all tiles into memory for merge!")
# ds_base_tiles = []
Expand Down Expand Up @@ -523,11 +509,9 @@ async def _orchestrate(
buffered_gdf["geometry"] = buffered_gdf.to_crs(
"EPSG:3857"
).buffer(LAND_MASK_BUFFER_M)
landmask = get_landmask_gdf()
intersecting_land = gpd.sjoin(
get_landmask_gdf(),
buffered_gdf,
how="inner",
predicate="intersects",
landmask, buffered_gdf, how="inner", predicate="intersects"
)
if not intersecting_land.empty:
feat["properties"]["inf_idx"] = 0
Expand Down Expand Up @@ -559,6 +543,10 @@ async def _orchestrate(
ntiles=ntiles,
noffsettiles=noffsettiles,
)

# Clean up potentially memory heavy assets
del out_fc
del out_fc_offset
else:
print("WARNING: Operating as a DRY RUN!!")
orchestrator_result = OrchestratorResult(
Expand Down
8 changes: 8 additions & 0 deletions cerulean_cloud/cloud_run_orchestrator/merging.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def merge_inferences(
return geojson.FeatureCollection(features=[])

# Concat the GeoDataFrames
target_crs = gdfs_for_processing[0].crs
gdfs_for_processing = [gdf.to_crs(target_crs) for gdf in gdfs_for_processing]
concat_gdf = pd.concat(gdfs_for_processing, ignore_index=True)
final_gdf = concat_gdf.copy()

Expand Down Expand Up @@ -107,4 +109,10 @@ def merge_inferences(
# Reproject the GeoDataFrame back to WGS 84 CRS
result = dissolved_gdf.to_crs(crs=4326)

# Clean up potentially memory heavy assets
del dissolved_gdf
del concat_gdf
del final_gdf
del joined

return result.__geo_interface__
2 changes: 2 additions & 0 deletions cerulean_cloud/titiler_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ async def get_offset_tile(
height: int = 256,
band: str = "vv",
img_format: str = "png",
scale: int = 1,
rescale: Tuple[int, int] = (0, 255),
) -> np.ndarray:
"""get offset tile as numpy array (with bounds)
Expand All @@ -169,6 +170,7 @@ async def get_offset_tile(
)
url += f"?sceneid={sceneid}"
url += f"&bands={band}"
url += f"&scale={scale}"
url += f"&rescale={','.join([str(r) for r in rescale])}"
resp = await self.client.get(url, timeout=self.timeout)

Expand Down
Loading

0 comments on commit 9490f6c

Please sign in to comment.