Bug/memory loss #126

Closed

46 commits
0eff4aa
singleton engine, and garbage collection
jonaraphael Sep 18, 2024
b2924db
Clean up some engine use in orchestrator
jonaraphael Sep 18, 2024
b112974
oops
jonaraphael Sep 19, 2024
8004447
Clean up singleton use
jonaraphael Sep 19, 2024
dff7c4c
Was closing the problem?
jonaraphael Sep 20, 2024
e5c1522
orchestrator object needed to persist across multiple db connections
jonaraphael Sep 20, 2024
7bcf88e
oops--forgot to flush
jonaraphael Sep 20, 2024
e888ced
Swap engine maintenance to fastapi's app lifetime.
jonaraphael Sep 24, 2024
88672a5
Garbage collection
jonaraphael Sep 24, 2024
fc2bd8e
update AIS as well
jonaraphael Sep 24, 2024
554a94c
Merge branch 'bug/engine-maintenance' into bug/memory-loss
jonaraphael Sep 24, 2024
cf254f1
Oops! Deleted before assignment...
jonaraphael Sep 24, 2024
8bf5803
more debug statement
jonaraphael Sep 25, 2024
7f24c5f
Merge branch 'main' into bug/memory-loss
jonaraphael Nov 26, 2024
06d8ea1
remove duplicate stitching/ensembling
sstill88 Dec 11, 2024
673d687
updated logging, added retries to get_bounds
sstill88 Dec 13, 2024
885734e
more logs
sstill88 Dec 16, 2024
712325f
added retries and cleared some MemoryFiles
sstill88 Dec 16, 2024
2d3035c
reformat
sstill88 Dec 16, 2024
b03b482
moved more print --> logger
sstill88 Dec 16, 2024
426df03
oops, revert one change to MemoryFile back
sstill88 Dec 17, 2024
af03c6c
clear memory when run_parallel_inference fails
sstill88 Dec 17, 2024
17f6d94
return None if fail inf fetch_and_process_image
sstill88 Dec 17, 2024
0860f3d
error logs
sstill88 Dec 17, 2024
29c228f
undo memfile clear; returns the memfile
sstill88 Dec 17, 2024
0118820
converted log text payloads to json payloads
sstill88 Dec 18, 2024
5e12ef7
make logs json serializable
sstill88 Dec 18, 2024
2c553e9
removed log formatting (return jsonPayload)
sstill88 Dec 18, 2024
0b2b72c
update logs
sstill88 Dec 18, 2024
5a50854
clearer logs for if image is empty vs there is no imagery
sstill88 Dec 18, 2024
377bd34
exception --> error
sstill88 Dec 18, 2024
070b736
moved structured_log to utils and added severity
sstill88 Dec 18, 2024
c4dd899
add more retries
sstill88 Dec 18, 2024
1b54756
raise error instead of return None
sstill88 Dec 18, 2024
6c4dd26
import logging correctly...
sstill88 Dec 18, 2024
766a4eb
add method=get to test
sstill88 Dec 19, 2024
9b444b1
add cloud_run_orchestrator/utils.py to ais cloud build
sstill88 Dec 19, 2024
c3d5293
typo...
sstill88 Dec 19, 2024
9287563
make cloud_run_orchestrator directory in action yaml
sstill88 Dec 19, 2024
32628c9
bugfix and add separate try/except for roda (after titiler metadata)
sstill88 Dec 20, 2024
ab93827
added tracebacks to errors
sstill88 Dec 20, 2024
30e874c
oops. removed some nested try/except/if weirdness
sstill88 Dec 20, 2024
d21a9ad
create separate error for if asyncio.gather fails
sstill88 Dec 20, 2024
a34d604
typo - make sure aux_datasets are cleared
sstill88 Dec 20, 2024
be9cfe9
reduced log size when possible
sstill88 Dec 20, 2024
56ccaca
return RuntimeErrors instead of None
sstill88 Dec 20, 2024
9 changes: 8 additions & 1 deletion cerulean_cloud/cloud_function_ais_analysis/main.py
@@ -13,6 +13,10 @@

from cerulean_cloud.database_client import DatabaseClient, get_engine

# Initialize the database engine globally to reuse across requests
# This improves performance by avoiding the overhead of creating a new engine for each request
DB_ENGINE = get_engine()


def verify_api_key(request):
"""Function to verify API key"""
@@ -57,7 +61,10 @@ def main(request):
verify_api_key(request)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
res = loop.run_until_complete(handle_asa_request(request))
try:
res = loop.run_until_complete(handle_asa_request(request))
finally:
loop.close() # Ensure the event loop is properly closed
return res


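For context, a minimal sketch of the pattern the main.py changes above apply: a module-level engine created once at import time and reused across invocations, plus a per-request event loop that is always closed in a finally block. get_engine and handle_asa_request come from the diff; the handler body below is a placeholder, not the real implementation.

import asyncio

from cerulean_cloud.database_client import get_engine

# Created once at import time so every invocation reuses the same engine
# (and its connection pool) instead of building a new one per request.
DB_ENGINE = get_engine()


async def handle_asa_request(request):
    """Placeholder for the real async handler shown in the diff; it can reuse DB_ENGINE."""
    ...


def main(request):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        res = loop.run_until_complete(handle_asa_request(request))
    finally:
        loop.close()  # Release the loop even if the handler raises
    return res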
54 changes: 40 additions & 14 deletions cerulean_cloud/cloud_run_orchestrator/clients.py
@@ -1,6 +1,7 @@
"""Clients for other cloud run functions"""

import asyncio
import gc # Import garbage collection module
import json
import os
import zipfile
@@ -26,7 +27,7 @@


def img_array_to_b64_image(img_array: np.ndarray, to_uint8=False) -> str:
"""convert input b64image to torch tensor"""
"""Convert input image array to base64-encoded image."""
if to_uint8 and not img_array.dtype == np.uint8:
print(
f"WARNING: changing from dtype {img_array.dtype} to uint8 without scaling!"
@@ -47,7 +48,7 @@ def img_array_to_b64_image(img_array: np.ndarray, to_uint8=False) -> str:


class CloudRunInferenceClient:
"""Client for inference cloud run"""
"""Client for inference cloud run."""

def __init__(
self,
@@ -60,15 +61,17 @@ def __init__(
scale: int,
model_dict,
):
"""init"""
"""Initialize the inference client."""
self.url = url
self.titiler_client = titiler_client
self.sceneid = sceneid
self.scale = scale # 1=256, 2=512, 3=...
self.model_dict = model_dict

# Handle auxiliary datasets and ensure they are properly managed
self.aux_datasets = handle_aux_datasets(
layers, self.sceneid, tileset_envelope_bounds, image_hw_pixels
)
self.scale = scale # 1=256, 2=512, 3=...
self.model_dict = model_dict

async def fetch_and_process_image(
self, tile_bounds, rescale=(0, 255), num_channels=1
@@ -117,6 +120,9 @@ async def process_auxiliary_datasets(self, img_array, tile_bounds):
window = rasterio.windows.from_bounds(*tile_bounds, transform=src.transform)
height, width = img_array.shape[1:]
aux_ds = src.read(window=window, out_shape=(src.count, height, width))

del src
gc.collect()
return np.concatenate([img_array, aux_ds], axis=0)

async def send_inference_request_and_handle_response(self, http_client, img_array):
@@ -134,9 +140,6 @@

Raises:
- Exception: If the request fails or the service returns an unexpected status code, with details provided in the exception message.

Note:
- This function constructs the inference payload by encoding the image and specifying the geographic bounds and any additional inference parameters through `self.model_dict`.
"""

encoded = img_array_to_b64_image(img_array, to_uint8=True)
@@ -190,9 +193,12 @@ async def get_tile_inference(self, http_client, tile_bounds, rescale=(0, 255)):
return InferenceResultStack(stack=[])
if self.aux_datasets:
img_array = await self.process_auxiliary_datasets(img_array, tile_bounds)
return await self.send_inference_request_and_handle_response(
res = await self.send_inference_request_and_handle_response(
http_client, img_array
)
del img_array
gc.collect()
return res

async def run_parallel_inference(self, tileset):
"""
@@ -216,12 +222,17 @@
inferences = await asyncio.gather(*tasks, return_exceptions=False)
# False means this process will error out if any subtask errors out
# True means this process will return a list including errors if any subtask errors out

# After processing, close and delete aux_datasets
if self.aux_datasets:
del self.aux_datasets
gc.collect()

return inferences


def get_scene_date_month(scene_id: str) -> str:
"""From a scene id, fetch the month of the scene"""
# i.e. S1A_IW_GRDH_1SDV_20200802T141646_20200802T141711_033729_03E8C7_E4F5
"""From a scene id, fetch the month of the scene."""
date_time_str = scene_id[17:32]
date_time_obj = datetime.strptime(date_time_str, "%Y%m%dT%H%M%S")
date_time_obj = date_time_obj.replace(day=1, hour=0, minute=0, second=0)
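As an aside on the run_parallel_inference comments above: a small, self-contained illustration of how asyncio.gather behaves with return_exceptions=False (the first failing subtask propagates) versus True (exceptions come back as list entries). The coroutine names here are made up for the example and are not part of this PR.

import asyncio


async def ok(i):
    return i


async def boom():
    raise RuntimeError("subtask failed")


async def demo():
    # return_exceptions=True: failures are returned as values in the result list
    results = await asyncio.gather(ok(1), boom(), ok(2), return_exceptions=True)
    print(results)  # [1, RuntimeError('subtask failed'), 2]

    # return_exceptions=False (the default): the first failure is raised here
    try:
        await asyncio.gather(ok(1), boom(), ok(2))
    except RuntimeError as err:
        print(f"gather raised: {err}")


asyncio.run(demo())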
@@ -235,7 +246,7 @@ def get_ship_density(
max_dens=100,
url="http://gmtds.maplarge.com/Api/ProcessDirect?",
) -> np.ndarray:
"""fetch ship density from gmtds service"""
"""Fetch ship density from gmtds service."""
h, w = img_shape
bbox_wms = bounds[0], bounds[2], bounds[1], bounds[-1]

@@ -322,6 +333,10 @@ def get_query(bbox_wms, scene_date_month):

dens_array = ar / (max_dens / 255)
dens_array[dens_array >= 255] = 255

del tempbuf, zipfile_ob, cont, r, ar
gc.collect()

return np.squeeze(dens_array.astype("uint8"))


@@ -331,7 +346,7 @@ def get_dist_array(
raster_ds: str,
max_distance: int = 60000,
):
"""fetch distance array from pre computed distance raster dataset"""
"""Fetch distance array from pre-computed distance raster dataset."""
with COGReader(raster_ds) as image:
height, width = img_shape[0:2]
img = image.part(
@@ -340,6 +355,7 @@
width=width,
)
data = img.data_as_image()

if (data == 0).all():
data = np.ones(img_shape) * 255
else:
@@ -350,13 +366,17 @@
data, (*img_shape[0:2], 1), order=1, preserve_range=True
) # resampling interpolation must match training data prep
upsampled = np.squeeze(upsampled)

del data, img
gc.collect()

return upsampled.astype(np.uint8)


def handle_aux_datasets(
layers, scene_id, tileset_envelope_bounds, image_hw_pixels, **kwargs
):
"""handle aux datasets"""
"""Handle auxiliary datasets."""
if layers[0].short_name != "VV":
raise NotImplementedError(
f"VV Layer must come first. Instead found: {layers[0].short_name}"
@@ -389,6 +409,9 @@
[aux_dataset_channels, ar], axis=2
)

del ar
gc.collect()

aux_memfile = MemoryFile()
if aux_dataset_channels is not None:
height, width = aux_dataset_channels.shape[0:2]
@@ -408,6 +431,9 @@
) as dataset:
dataset.write(reshape_as_raster(aux_dataset_channels))

del aux_dataset_channels
gc.collect()

return aux_memfile
else:
return None
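Finally, a condensed sketch of the memory-management pattern these clients.py changes apply around large intermediate arrays and rasterio MemoryFile objects: write the stacked channels into an in-memory GeoTIFF, drop the local reference, call gc.collect(), and return only the MemoryFile. build_aux_memfile is a hypothetical helper name and the profile is simplified (no CRS or transform), so treat it as an illustration rather than the code in this PR.

import gc

import numpy as np
from rasterio.io import MemoryFile
from rasterio.plot import reshape_as_raster


def build_aux_memfile(aux_dataset_channels: np.ndarray) -> MemoryFile:
    """Write stacked auxiliary channels (height, width, channels) to an in-memory GeoTIFF."""
    height, width, count = aux_dataset_channels.shape
    aux_memfile = MemoryFile()
    with aux_memfile.open(
        driver="GTiff",
        count=count,
        height=height,
        width=width,
        dtype=aux_dataset_channels.dtype.name,
    ) as dataset:
        dataset.write(reshape_as_raster(aux_dataset_channels))

    # Drop this function's reference and nudge the collector, mirroring the diff;
    # the array is only freed once no other references remain.
    del aux_dataset_channels
    gc.collect()
    return aux_memfile


# Example: three 512x512 uint8 channels
channels = np.zeros((512, 512, 3), dtype="uint8")
memfile = build_aux_memfile(channels)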