
Change flow parameters: max instance count, max concurrency, max dispatch, max connections

Load the model ONCE per container, and make it accessible to multiple concurrent requests
Attempt some cleanup to avoid a memory leak in the orchestrator
jonaraphael committed Nov 4, 2023
1 parent c6227d7 commit 18d5795
Showing 11 changed files with 345 additions and 91 deletions.
26 changes: 12 additions & 14 deletions cerulean_cloud/cloud_run_offset_tiles/handler.py
@@ -2,7 +2,6 @@
Ref: https://github.com/python-engineer/ml-deployment/tree/main/google-cloud-run
"""
from base64 import b64decode, b64encode
-from functools import lru_cache
from typing import Dict, List, Tuple, Union

import geojson
@@ -33,17 +32,17 @@
app.add_middleware(CORSMiddleware, allow_origins=["*"])
add_timing_middleware(app, prefix="app")

+MODEL = None

-def load_tracing_model(savepath):
-    """load tracing model. a tracing model must be applied to the same batch dimensions the model was trained on."""
-    tracing_model = torch.jit.load(savepath, map_location="cpu")
-    return tracing_model
-
-
-@lru_cache()
-def get_model():
-    """load model"""
-    return load_tracing_model("cerulean_cloud/cloud_run_offset_tiles/model/model.pt")
+
+def load_model():
+    """Load the model into the global variable."""
+    global MODEL
+    if MODEL is None:
+        # You should specify the correct path to your model file
+        model_path = "cerulean_cloud/cloud_run_offset_tiles/model/model.pt"
+        MODEL = torch.jit.load(model_path, map_location="cpu")
+    return MODEL


def logits_to_classes(out_batch_logits):
@@ -164,11 +163,10 @@ def _predict(
tags=["Run inference"],
response_model=InferenceResultStack,
)
-def predict(
-    request: Request, payload: PredictPayload, model=Depends(get_model)
-) -> Dict:
-    """predict"""
+def predict(request: Request, payload: PredictPayload) -> Dict:
+    """Run prediction using the loaded model."""
    record_timing(request, note="Started")
+    model = load_model()
results = _predict(payload.inf_stack, model, payload.inf_parms)
record_timing(request, note="Finished inference")

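For context, the handler change above swaps an `lru_cache`-wrapped loader for a module-level global, so each container deserializes the TorchScript model once and every request served by that container reuses it. A minimal standalone sketch of the pattern (the path and the lock are illustrative; the commit itself does not use a lock):

```python
# Sketch of the lazy "load once per container" pattern; hypothetical names.
import threading

import torch

MODEL = None
_MODEL_LOCK = threading.Lock()  # guards the first load if workers are threaded


def load_model(model_path: str = "model/model.pt"):
    """Return the process-wide model, loading it on the first call only."""
    global MODEL
    if MODEL is None:
        with _MODEL_LOCK:
            if MODEL is None:  # double-checked: another thread may have loaded it
                MODEL = torch.jit.load(model_path, map_location="cpu")
    return MODEL
```

The lock only matters if the server dispatches requests across threads; with a single-threaded async worker, the plain `if MODEL is None` guard used in the commit is sufficient.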
87 changes: 40 additions & 47 deletions cerulean_cloud/cloud_run_orchestrator/clients.py
@@ -1,5 +1,4 @@
"""Clients for other cloud run functions"""
-import asyncio
import json
import os
import zipfile
@@ -70,76 +69,70 @@ def __init__(
        self.inference_parms = inference_parms

    async def get_base_tile_inference(
-        self, tile: morecantile.Tile, semaphore: asyncio.Semaphore, rescale=(0, 255)
+        self, tile: morecantile.Tile, rescale=(0, 255)
    ) -> InferenceResultStack:
        """fetch inference for base tiles"""
-        async with semaphore:
-            img_array = await self.titiler_client.get_base_tile(
-                sceneid=self.sceneid, tile=tile, scale=self.scale, rescale=rescale
-            )
+        img_array = await self.titiler_client.get_base_tile(
+            sceneid=self.sceneid, tile=tile, scale=self.scale, rescale=rescale
+        )

-            img_array = reshape_as_raster(img_array)
+        img_array = reshape_as_raster(img_array)

-            bounds = list(TMS.bounds(tile))
+        bounds = list(TMS.bounds(tile))

-            with self.aux_datasets.open() as src:
-                window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
-                height, width = img_array.shape[1:]
-                aux_ds = src.read(window=window, out_shape=(height, width))
+        with self.aux_datasets.open() as src:
+            window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
+            height, width = img_array.shape[1:]
+            aux_ds = src.read(window=window, out_shape=(height, width))

-            img_array = np.concatenate([img_array[0:1, :, :], aux_ds], axis=0)
+        img_array = np.concatenate([img_array[0:1, :, :], aux_ds], axis=0)

-            encoded = img_array_to_b64_image(img_array)
+        encoded = img_array_to_b64_image(img_array)

-            inf_stack = [InferenceInput(image=encoded, bounds=TMS.bounds(tile))]
-            payload = PredictPayload(
-                inf_stack=inf_stack, inf_parms=self.inference_parms
-            )
-            res = await self.client.post(
-                self.url + "/predict", json=payload.dict(), timeout=None
-            )
+        inf_stack = [InferenceInput(image=encoded, bounds=TMS.bounds(tile))]
+        payload = PredictPayload(inf_stack=inf_stack, inf_parms=self.inference_parms)
+        res = await self.client.post(
+            self.url + "/predict", json=payload.dict(), timeout=None
+        )
        if res.status_code == 200:
            return InferenceResultStack(**res.json())
        else:
-            print(f"XXX Received unexpected status code: {res.status_code}")
-            print(f"XXX HTTP content: {res.content}")
-            print(f"XXX Issue was found in: {self.sceneid}")
-            return InferenceResultStack(stack=[])
+            raise Exception(
+                f"XXX Received unexpected status code: {res.status_code} {res.content}"
+            )

    async def get_offset_tile_inference(
-        self, bounds: List[float], semaphore: asyncio.Semaphore, rescale=(0, 255)
+        self, bounds: List[float], rescale=(0, 255)
    ) -> InferenceResultStack:
        """fetch inference for offset tiles"""
-        async with semaphore:
-            hw = self.scale * 256
-            img_array = await self.titiler_client.get_offset_tile(
-                self.sceneid, *bounds, width=hw, height=hw, rescale=rescale
-            )
-            img_array = reshape_as_raster(img_array)
-            with self.aux_datasets.open() as src:
-                window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
-                height, width = img_array.shape[1:]
-                aux_ds = src.read(window=window, out_shape=(height, width))
+        hw = self.scale * 256
+        img_array = await self.titiler_client.get_offset_tile(
+            self.sceneid, *bounds, width=hw, height=hw, rescale=rescale
+        )
+        img_array = reshape_as_raster(img_array)
+        with self.aux_datasets.open() as src:
+            window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
+            height, width = img_array.shape[1:]
+            aux_ds = src.read(window=window, out_shape=(height, width))

-            img_array = np.concatenate([img_array[0:1, :, :], aux_ds], axis=0)
+        img_array = np.concatenate([img_array[0:1, :, :], aux_ds], axis=0)

-            encoded = img_array_to_b64_image(img_array)
+        encoded = img_array_to_b64_image(img_array)

-            inf_stack = [InferenceInput(image=encoded, bounds=bounds)]
+        inf_stack = [InferenceInput(image=encoded, bounds=bounds)]

-            payload = PredictPayload(
-                inf_stack=inf_stack, inf_parms=self.inference_parms
-            )
-            res = await self.client.post(
-                self.url + "/predict", json=payload.dict(), timeout=None
-            )
+        payload = PredictPayload(inf_stack=inf_stack, inf_parms=self.inference_parms)
+        res = await self.client.post(
+            self.url + "/predict", json=payload.dict(), timeout=None
+        )
        if res.status_code == 200:
            return InferenceResultStack(**res.json())
        else:
-            print(f"XXX Received unexpected status code: {res.status_code}")
-            print(f"XXX HTTP content: {res.content}")
-            print(f"XXX Issue was found in: {self.sceneid}")
-            return InferenceResultStack(stack=[])
+            raise Exception(
+                f"XXX Received unexpected status code: {res.status_code} {res.content}"
+            )


def get_scene_date_month(scene_id: str) -> str:
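Note the pattern in this file: both inference methods drop their `asyncio.Semaphore` argument and now raise on any non-200 response instead of returning an empty `InferenceResultStack`. Per the commit title, throttling presumably moves to platform flow parameters (max instance count, max concurrency, max connections) rather than application code. A hedged sketch of capping fan-out at the HTTP client instead, assuming an `httpx.AsyncClient` (the methods here post via `self.client.post`; the numbers are illustrative, not the commit's values):

```python
# Sketch: a connection-pool cap as an alternative to an app-level semaphore.
import httpx

limits = httpx.Limits(max_connections=100, max_keepalive_connections=20)
client = httpx.AsyncClient(limits=limits, timeout=None)
# Requests beyond max_connections wait for a free connection inside the
# client, giving a global cap comparable to a semaphore of the same size.
```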
43 changes: 20 additions & 23 deletions cerulean_cloud/cloud_run_orchestrator/handler.py
@@ -10,7 +10,6 @@
"""
import asyncio
import os
-import traceback
import urllib.parse as urlparse
from base64 import b64decode # , b64encode
from datetime import datetime, timedelta
@@ -245,14 +244,13 @@ def flatten_feature_list(
    return flat_list


-async def perform_inference(tiles, inference_func, semaphore_value, description):
+async def perform_inference(tiles, inference_func, description):
    """
    Perform inference on a set of tiles asynchronously.
    Parameters:
    - tiles or bounds (list): List of tiles to perform inference on. (depends on inference_func)
    - inference_func (function): Asynchronous function to call for inference.
-    - semaphore_value (int): Maximum number of concurrent tasks.
    - description (str): Description of the inference task for logging.
    Returns:
@@ -263,22 +261,11 @@ async def perform_inference(tiles, inference_func, semaphore_value, description)
    - Prints traceback of exceptions to the console.
    """
    print(f"Inference on {description}!")
-    semaphore = asyncio.Semaphore(value=semaphore_value)
    inferences = await asyncio.gather(
-        *[
-            inference_func(tile, rescale=(0, 255), semaphore=semaphore)
-            for tile in tiles
-        ],
-        return_exceptions=True,
+        *[inference_func(tile, rescale=(0, 255)) for tile in tiles],
+        return_exceptions=False,  # This raises exceptions
    )
-    clean_inferences = []
-    for res in inferences:
-        if isinstance(res, Exception):
-            print(f"WARNING: Exception occurred during {description} inference: {res}")
-            traceback.print_tb(res.__traceback__)
-        else:
-            clean_inferences.append(res)
-    return clean_inferences
+    return inferences


async def _orchestrate(
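A consequence of `return_exceptions=False` worth noting: `asyncio.gather` now propagates the first tile failure to the caller instead of logging and filtering failed tiles out. A small self-contained illustration (toy names, not repo code):

```python
# Sketch of the new fail-fast behavior of asyncio.gather.
import asyncio


async def work(i: int) -> int:
    if i == 2:
        raise ValueError(f"tile {i} failed")
    return i


async def main() -> None:
    try:
        await asyncio.gather(*[work(i) for i in range(4)], return_exceptions=False)
    except ValueError as err:
        # With return_exceptions=False the first failure propagates here,
        # so a bad tile aborts the scene instead of being silently skipped.
        print(f"orchestration aborted: {err}")


asyncio.run(main())
```

This trades partial results for fail-fast behavior, which pairs with the clients above now raising on non-200 responses.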
@@ -396,24 +383,32 @@ async def _orchestrate(
        base_tiles_inference = await perform_inference(
            base_tiles,
            cloud_run_inference.get_base_tile_inference,
-            20,
+            100,
            "base tiles",
        )

        offset_tiles_inference = await perform_inference(
            offset_tiles_bounds,
            cloud_run_inference.get_offset_tile_inference,
-            20,
+            100,
            "offset tiles",
        )

+        # Clean up potentially memory heavy assets
+        del base_tiles
+        del offset_tiles_bounds
+
        if model.type == "MASKRCNN":
            out_fc = geojson.FeatureCollection(
                features=flatten_feature_list(base_tiles_inference)
            )
            out_fc_offset = geojson.FeatureCollection(
                features=flatten_feature_list(offset_tiles_inference)
            )
+
+            # Clean up potentially memory heavy assets
+            del base_tiles_inference
+            del offset_tiles_inference
        elif model.type == "UNET":
            # print("Loading all tiles into memory for merge!")
            # ds_base_tiles = []
@@ -495,11 +490,9 @@ async def _orchestrate(
                    buffered_gdf["geometry"] = buffered_gdf.to_crs(
                        "EPSG:3857"
                    ).buffer(LAND_MASK_BUFFER_M)
+                    landmask = get_landmask_gdf()
                    intersecting_land = gpd.sjoin(
-                        get_landmask_gdf(),
-                        buffered_gdf,
-                        how="inner",
-                        predicate="intersects",
+                        landmask, buffered_gdf, how="inner", predicate="intersects"
                    )
                    if not intersecting_land.empty:
                        feat["properties"]["inf_idx"] = 0
Expand Down Expand Up @@ -531,6 +524,10 @@ async def _orchestrate(
            ntiles=ntiles,
            noffsettiles=noffsettiles,
        )
+
+        # Clean up potentially memory heavy assets
+        del out_fc
+        del out_fc_offset
    else:
        print("WARNING: Operating as a DRY RUN!!")
        orchestrator_result = OrchestratorResult(
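The land-mask refactor above binds `get_landmask_gdf()` to a local before the spatial join. For readers unfamiliar with `gpd.sjoin`, a toy example of the same intersection test (synthetic geometries; the real mask comes from `get_landmask_gdf()`):

```python
# Sketch of the land-intersection test used above, with toy data.
import geopandas as gpd
from shapely.geometry import box

land = gpd.GeoDataFrame(geometry=[box(0, 0, 1, 1)], crs="EPSG:4326")
detections = gpd.GeoDataFrame(geometry=[box(0.5, 0.5, 2, 2)], crs="EPSG:4326")

# Inner spatial join keeps only land rows that intersect a detection.
hits = gpd.sjoin(land, detections, how="inner", predicate="intersects")
print(not hits.empty)  # True -> this detection would be flagged as onshore
```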
8 changes: 8 additions & 0 deletions cerulean_cloud/cloud_run_orchestrator/merging.py
@@ -105,6 +105,14 @@ def merge_inferences(
        # Reproject the GeoDataFrame back to WGS 84 CRS
        result = dissolved_gdf.to_crs(crs=4326)

+        # Clean up potentially memory heavy assets
+        del dissolved_gdf
+        del concat_gdf
+        del final_gdf
+        del joined
+        del base_gdf
+        del offset_gdf
+
        return result.__geo_interface__
    else:
        # If one of the FeatureCollections is empty, return an empty FeatureCollection
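On the `del` cleanups added here and in the orchestrator: `del` only unbinds a name; CPython frees the memory when the last reference drops, and cyclic garbage waits for the collector. A small sketch of the distinction (illustrative, not repo code):

```python
# Why `del` alone may not release memory immediately.
import gc

big = [bytearray(1024) for _ in range(1000)]
alias = big          # a second reference keeps the data alive
del big              # unbinds one name; the memory is NOT freed yet
del alias            # refcount reaches zero; CPython frees it here
gc.collect()         # only needed for reference cycles, but cheap insurance
```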
