Bug/memory loss #126

Closed
wants to merge 46 commits into from
46 commits
0eff4aa
singleton engine, and garbage collection
jonaraphael Sep 18, 2024
b2924db
Clean up some engine use in orchestrator
jonaraphael Sep 18, 2024
b112974
oops
jonaraphael Sep 19, 2024
8004447
Clean up singleton use
jonaraphael Sep 19, 2024
dff7c4c
Was closing the problem?
jonaraphael Sep 20, 2024
e5c1522
orchestrator object needed to persist across multiple db connections
jonaraphael Sep 20, 2024
7bcf88e
oops--forgot to flush
jonaraphael Sep 20, 2024
e888ced
Swap engine maintenance to fastapi's app lifetime.
jonaraphael Sep 24, 2024
88672a5
Garbage collection
jonaraphael Sep 24, 2024
fc2bd8e
update AIS as well
jonaraphael Sep 24, 2024
554a94c
Merge branch 'bug/engine-maintenance' into bug/memory-loss
jonaraphael Sep 24, 2024
cf254f1
Oops! Deleted before assignment...
jonaraphael Sep 24, 2024
8bf5803
more debug statement
jonaraphael Sep 25, 2024
7f24c5f
Merge branch 'main' into bug/memory-loss
jonaraphael Nov 26, 2024
06d8ea1
remove duplicate stitching/ensembling
sstill88 Dec 11, 2024
673d687
updated logging, added retries to get_bounds
sstill88 Dec 13, 2024
885734e
more logs
sstill88 Dec 16, 2024
712325f
added retries and cleared some MemoryFiles
sstill88 Dec 16, 2024
2d3035c
reformat
sstill88 Dec 16, 2024
b03b482
moved more print --> logger
sstill88 Dec 16, 2024
426df03
oops, revert one change to MemoryFile back
sstill88 Dec 17, 2024
af03c6c
clear memory when run_parallel_inference fails
sstill88 Dec 17, 2024
17f6d94
return None if fail inf fetch_and_process_image
sstill88 Dec 17, 2024
0860f3d
error logs
sstill88 Dec 17, 2024
29c228f
undo memfile clear; returns the memfile
sstill88 Dec 17, 2024
0118820
converted log text payloads to json payloads
sstill88 Dec 18, 2024
5e12ef7
make logs json serializable
sstill88 Dec 18, 2024
2c553e9
removed log formatting (return jsonPayload)
sstill88 Dec 18, 2024
0b2b72c
update logs
sstill88 Dec 18, 2024
5a50854
clearer logs for if image is empty vs there is no imagery
sstill88 Dec 18, 2024
377bd34
exception --> error
sstill88 Dec 18, 2024
070b736
moved structured_log to utils and added severity
sstill88 Dec 18, 2024
c4dd899
add more retries
sstill88 Dec 18, 2024
1b54756
raise error instead of return None
sstill88 Dec 18, 2024
6c4dd26
import logging correctly...
sstill88 Dec 18, 2024
766a4eb
add method=get to test
sstill88 Dec 19, 2024
9b444b1
add cloud_run_orchestrator/utils.py to ais cloud build
sstill88 Dec 19, 2024
c3d5293
typo...
sstill88 Dec 19, 2024
9287563
make cloud_run_orchestrator directory in action yaml
sstill88 Dec 19, 2024
32628c9
bugfix and add separate try/except for roda (after titiler metadata)
sstill88 Dec 20, 2024
ab93827
added tracebacks to errors
sstill88 Dec 20, 2024
30e874c
oops. removed some nested try/except/if weirdness
sstill88 Dec 20, 2024
d21a9ad
create separate error for if asyncio.gather fails
sstill88 Dec 20, 2024
a34d604
typo - make sure aux_datasets are cleared
sstill88 Dec 20, 2024
be9cfe9
reduced log size when possible
sstill88 Dec 20, 2024
56ccaca
return RuntimeErrors instead of None
sstill88 Dec 20, 2024
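
Note: commit e888ced above ("Swap engine maintenance to fastapi's app lifetime.") refers to orchestrator changes that are not among the file diffs shown below. As a rough illustration only, tying engine creation and disposal to the FastAPI application lifetime typically looks like the sketch below, assuming get_engine() returns an async SQLAlchemy engine (the actual orchestrator code may differ):

from contextlib import asynccontextmanager

from fastapi import FastAPI

from cerulean_cloud.database_client import get_engine


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Create the engine once per application lifetime and reuse it across requests.
    app.state.db_engine = get_engine()
    yield
    # Release pooled connections when the application shuts down.
    await app.state.db_engine.dispose()


app = FastAPI(lifespan=lifespan)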
2 changes: 2 additions & 0 deletions .github/actions/deploy_infrastructure/action.yml
@@ -81,8 +81,10 @@ runs:
shell: bash
run: |
mkdir -p cerulean_cloud/cloud_function_ais_analysis/cerulean_cloud/
mkdir -p cerulean_cloud/cloud_function_ais_analysis/cerulean_cloud/cloud_run_orchestrator/
cp cerulean_cloud/database_client.py cerulean_cloud/cloud_function_ais_analysis/cerulean_cloud/database_client.py
cp cerulean_cloud/database_schema.py cerulean_cloud/cloud_function_ais_analysis/cerulean_cloud/database_schema.py
cp cerulean_cloud/cloud_run_orchestrator/utils.py cerulean_cloud/cloud_function_ais_analysis/cerulean_cloud/cloud_run_orchestrator/utils.py
cp cerulean_cloud/__init__.py cerulean_cloud/cloud_function_ais_analysis/cerulean_cloud/__init__.py

- name: Deploy Infrastructure
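For reference, the mkdir and cp steps above vendor a minimal cerulean_cloud package into the AIS cloud function bundle so that modules such as cerulean_cloud.cloud_run_orchestrator.utils can be imported there. The resulting layout, derived from those commands, is roughly:

cerulean_cloud/cloud_function_ais_analysis/
    cerulean_cloud/
        __init__.py
        database_client.py
        database_schema.py
        cloud_run_orchestrator/
            utils.py
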
9 changes: 8 additions & 1 deletion cerulean_cloud/cloud_function_ais_analysis/main.py
@@ -13,6 +13,10 @@

from cerulean_cloud.database_client import DatabaseClient, get_engine

# Initialize the database engine globally to reuse across requests
# This improves performance by avoiding the overhead of creating a new engine for each request
DB_ENGINE = get_engine()


def verify_api_key(request):
"""Function to verify API key"""
@@ -57,7 +61,10 @@ def main(request):
verify_api_key(request)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
res = loop.run_until_complete(handle_asa_request(request))
try:
res = loop.run_until_complete(handle_asa_request(request))
finally:
loop.close() # Ensure the event loop is properly closed
return res


190 changes: 140 additions & 50 deletions cerulean_cloud/cloud_run_orchestrator/clients.py
@@ -1,8 +1,12 @@
"""Clients for other cloud run functions"""

import asyncio
import gc # Import garbage collection module
import json
import logging
import os
import sys
import traceback
import zipfile
from base64 import b64encode
from datetime import datetime
@@ -23,10 +27,11 @@
InferenceResultStack,
PredictPayload,
)
from cerulean_cloud.cloud_run_orchestrator.utils import structured_log


def img_array_to_b64_image(img_array: np.ndarray, to_uint8=False) -> str:
"""convert input b64image to torch tensor"""
"""Convert input image array to base64-encoded image."""
if to_uint8 and not img_array.dtype == np.uint8:
print(
f"WARNING: changing from dtype {img_array.dtype} to uint8 without scaling!"
@@ -47,7 +52,7 @@ def img_array_to_b64_image(img_array: np.ndarray, to_uint8=False) -> str:


class CloudRunInferenceClient:
"""Client for inference cloud run"""
"""Client for inference cloud run."""

def __init__(
self,
@@ -60,15 +65,23 @@ def __init__(
scale: int,
model_dict,
):
"""init"""
"""Initialize the inference client."""
self.url = url
self.titiler_client = titiler_client
self.sceneid = sceneid
self.scale = scale # 1=256, 2=512, 3=...
self.model_dict = model_dict

# Configure logger
self.logger = logging.getLogger("InferenceClient")
handler = logging.StreamHandler(sys.stdout) # Write logs to stdout
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)

# Handle auxiliary datasets and ensure they are properly managed
self.aux_datasets = handle_aux_datasets(
layers, self.sceneid, tileset_envelope_bounds, image_hw_pixels
)
self.scale = scale # 1=256, 2=512, 3=...
self.model_dict = model_dict

async def fetch_and_process_image(
self, tile_bounds, rescale=(0, 255), num_channels=1
@@ -87,18 +100,29 @@
"""

hw = self.scale * 256
img_array = await self.titiler_client.get_offset_tile(
self.sceneid,
*tile_bounds,
width=hw,
height=hw,
scale=self.scale,
rescale=rescale,
)
try:
img_array = await self.titiler_client.get_offset_tile(
self.sceneid,
*tile_bounds,
width=hw,
height=hw,
scale=self.scale,
rescale=rescale,
)

img_array = reshape_as_raster(img_array)
img_array = img_array[0:num_channels, :, :]
return img_array
img_array = reshape_as_raster(img_array)
img_array = img_array[0:num_channels, :, :]
return img_array
except Exception as e:
self.logger.warning(
structured_log(
f"Error retrieving tile array for {json.dumps(tile_bounds)}",
severity="WARNING",
scene_id=self.sceneid,
exception=str(e),
)
)
return None

async def process_auxiliary_datasets(self, img_array, tile_bounds):
"""
@@ -117,6 +141,9 @@ async def process_auxiliary_datasets(self, img_array, tile_bounds):
window = rasterio.windows.from_bounds(*tile_bounds, transform=src.transform)
height, width = img_array.shape[1:]
aux_ds = src.read(window=window, out_shape=(src.count, height, width))

del src
gc.collect()
return np.concatenate([img_array, aux_ds], axis=0)

async def send_inference_request_and_handle_response(self, http_client, img_array):
@@ -134,9 +161,6 @@ async def send_inference_request_and_handle_response(self, http_client, img_arra

Raises:
- Exception: If the request fails or the service returns an unexpected status code, with details provided in the exception message.

Note:
- This function constructs the inference payload by encoding the image and specifying the geographic bounds and any additional inference parameters through `self.model_dict`.
"""

encoded = img_array_to_b64_image(img_array, to_uint8=True)
@@ -146,28 +170,32 @@ async def send_inference_request_and_handle_response(self, http_client, img_arra
max_retries = 2 # Total attempts including the first try
retry_delay = 5 # Delay in seconds between retries

for attempt in range(max_retries):
for attempt in range(1, max_retries + 1):
try:
res = await http_client.post(
self.url + "/predict", json=payload.dict(), timeout=None
)
if res.status_code == 200:
return InferenceResultStack(**res.json())
else:
print(
f"Attempt {attempt + 1}: Failed with status code {res.status_code}. Retrying..."
)
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay) # Wait before retrying
return InferenceResultStack(**res.json())
except Exception as e:
print(f"Attempt {attempt + 1}: Exception occurred: {e}")
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay) # Wait before retrying
if attempt == max_retries:
self.logger.error(
structured_log(
"Failed to get inference",
severity="ERROR",
scene_id=self.sceneid,
exception=str(e),
traceback=traceback.format_exc(),
)
)
raise

# If all attempts fail, raise an exception
raise Exception(
f"All attempts failed after {max_retries} retries. Last known error: {res.content}"
)
self.logger.warning(
structured_log(
f"Error getting inference; Attempt {attempt + 1}, retrying . . .",
severity="WARNING",
)
)
await asyncio.sleep(retry_delay) # Wait before retrying

async def get_tile_inference(self, http_client, tile_bounds, rescale=(0, 255)):
"""
@@ -186,13 +214,35 @@ async def get_tile_inference(self, http_client, tile_bounds, rescale=(0, 255)):
"""

img_array = await self.fetch_and_process_image(tile_bounds, rescale)
if not np.any(img_array):
if img_array is None:
return InferenceResultStack(stack=[])
elif not np.any(img_array):
self.logger.warning(
structured_log(
f"empty image for {str(tile_bounds)}",
severity="WARNING",
scene_id=self.sceneid,
)
)
return InferenceResultStack(stack=[])

if self.aux_datasets:
img_array = await self.process_auxiliary_datasets(img_array, tile_bounds)
return await self.send_inference_request_and_handle_response(
res = await self.send_inference_request_and_handle_response(
http_client, img_array
)
del img_array
gc.collect()

self.logger.info(
structured_log(
f"generated image for {str(tile_bounds)}",
severity="INFO",
scene_id=self.sceneid,
)
)

return res

async def run_parallel_inference(self, tileset):
"""
@@ -204,24 +254,49 @@ async def run_parallel_inference(self, tileset):
Returns:
- list: List of inference results, with exceptions filtered out.
"""
async with httpx.AsyncClient(
headers={"Authorization": f"Bearer {os.getenv('API_KEY')}"}
) as async_http_client:
tasks = [
self.get_tile_inference(
http_client=async_http_client, tile_bounds=tile_bounds
try:
async with httpx.AsyncClient(
headers={"Authorization": f"Bearer {os.getenv('API_KEY')}"}
) as async_http_client:
tasks = [
self.get_tile_inference(
http_client=async_http_client, tile_bounds=tile_bounds
)
for tile_bounds in tileset
]
except Exception as e:
self.logger.error(
structured_log(
"Failed to complete parallel inference",
severity="ERROR",
scene_id=self.sceneid,
exception=str(e),
traceback=traceback.format_exc(),
)
for tile_bounds in tileset
]
)

# If get_tile_inference tasks fail, close and delete aux_datasets
if self.aux_datasets:
del self.aux_datasets
gc.collect()

if "tasks" in locals():
inferences = await asyncio.gather(*tasks, return_exceptions=False)
# False means this process will error out if any subtask errors out
# True means this process will return a list including errors if any subtask errors out

# After processing, close and delete aux_datasets
if self.aux_datasets:
del self.aux_datasets
gc.collect()
else:
raise ValueError("Failed to gather inference")

return inferences


def get_scene_date_month(scene_id: str) -> str:
"""From a scene id, fetch the month of the scene"""
# i.e. S1A_IW_GRDH_1SDV_20200802T141646_20200802T141711_033729_03E8C7_E4F5
"""From a scene id, fetch the month of the scene."""
date_time_str = scene_id[17:32]
date_time_obj = datetime.strptime(date_time_str, "%Y%m%dT%H%M%S")
date_time_obj = date_time_obj.replace(day=1, hour=0, minute=0, second=0)
@@ -235,7 +310,7 @@ def get_ship_density(
max_dens=100,
url="http://gmtds.maplarge.com/Api/ProcessDirect?",
) -> np.ndarray:
"""fetch ship density from gmtds service"""
"""Fetch ship density from gmtds service."""
h, w = img_shape
bbox_wms = bounds[0], bounds[2], bounds[1], bounds[-1]

@@ -322,6 +397,10 @@ def get_query(bbox_wms, scene_date_month):

dens_array = ar / (max_dens / 255)
dens_array[dens_array >= 255] = 255

del tempbuf, zipfile_ob, cont, r, ar
gc.collect()

return np.squeeze(dens_array.astype("uint8"))


@@ -331,7 +410,7 @@ def get_dist_array(
raster_ds: str,
max_distance: int = 60000,
):
"""fetch distance array from pre computed distance raster dataset"""
"""Fetch distance array from pre-computed distance raster dataset."""
with COGReader(raster_ds) as image:
height, width = img_shape[0:2]
img = image.part(
@@ -340,6 +419,7 @@
width=width,
)
data = img.data_as_image()

if (data == 0).all():
data = np.ones(img_shape) * 255
else:
@@ -350,13 +430,17 @@
data, (*img_shape[0:2], 1), order=1, preserve_range=True
) # resampling interpolation must match training data prep
upsampled = np.squeeze(upsampled)

del data, img
gc.collect()

return upsampled.astype(np.uint8)


def handle_aux_datasets(
layers, scene_id, tileset_envelope_bounds, image_hw_pixels, **kwargs
):
"""handle aux datasets"""
"""Handle auxiliary datasets."""
if layers[0].short_name != "VV":
raise NotImplementedError(
f"VV Layer must come first. Instead found: {layers[0].short_name}"
@@ -389,6 +473,9 @@
[aux_dataset_channels, ar], axis=2
)

del ar
gc.collect()

aux_memfile = MemoryFile()
if aux_dataset_channels is not None:
height, width = aux_dataset_channels.shape[0:2]
@@ -408,6 +495,9 @@
) as dataset:
dataset.write(reshape_as_raster(aux_dataset_channels))

del aux_dataset_channels
gc.collect()

return aux_memfile
else:
return None
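
The clients.py changes above log through structured_log, which commit 070b736 moved into cerulean_cloud/cloud_run_orchestrator/utils.py together with a severity field; that file's diff is not shown in this view. A minimal sketch of what such a helper might look like, assuming it simply returns a JSON string carrying the message, the severity, and any extra context fields (the real utils.py may differ):

import json


def structured_log(message: str, severity: str = "INFO", **context) -> str:
    """Sketch only: build a JSON log payload so Cloud Logging parses it as jsonPayload."""
    payload = {"message": message, "severity": severity}
    # Coerce arbitrary context (scene_id, exception, traceback, ...) to strings
    # so the payload stays JSON serializable.
    payload.update({key: str(value) for key, value in context.items()})
    return json.dumps(payload)

In the calls above, the returned string is handed to self.logger.warning / self.logger.error, so each line written to stdout is a single JSON object.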