Skip to content

Commit

Permalink
Move inference client onto app state rather than using a global variable. Cleaned up some incorrect typing.
Browse files Browse the repository at this point in the history
  • Loading branch information
Edward Keeble committed Nov 23, 2023
1 parent 270ae45 commit 647e4fb
Showing 1 changed file with 35 additions and 17 deletions.
52 changes: 35 additions & 17 deletions cerulean_cloud/cloud_run_orchestrator/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np
import rasterio
import supermercado
from fastapi import Depends, FastAPI
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from global_land_mask import globe
from rasterio.io import MemoryFile
Expand All @@ -46,15 +46,6 @@
from cerulean_cloud.tiling import TMS, offset_bounds_from_base_tiles
from cerulean_cloud.titiler_client import TitilerClient

# create a global client
inference_client = httpx.AsyncClient(
headers={"Authorization": f"Bearer {os.getenv('API_KEY')}"},
limits=httpx.Limits(
max_connections=int(os.getenv("MAX_INFERENCE_CONNECTIONS", default=500))
),
timeout=None,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
Expand All @@ -65,9 +56,18 @@ async def lifespan(app: FastAPI):
Yields:
None: This function does not yield any value but ensures proper resource management.
"""
print("Starting up...")
print("Creating inference client...")
app.state.inference_client = httpx.AsyncClient(
headers={"Authorization": f"Bearer {os.getenv('API_KEY')}"},
limits=httpx.Limits(
max_connections=int(os.getenv("MAX_INFERENCE_CONNECTIONS", default=100))
),
timeout=None,
)

yield
await inference_client.aclose()
print("Cleaning up inference client...")
await app.state.inference_client.aclose()
print("Shutting down...")


Expand Down Expand Up @@ -152,7 +152,9 @@ def offset_group_shape_from_base_tiles(
return height, width


def group_bounds_from_list_of_bounds(bounds: List[List[float]]) -> List[float]:
def group_bounds_from_list_of_bounds(
bounds: List[Tuple[float, float, float, float]]
) -> List[float]:
"""from a list of bounds, get the merged bounds (min max)"""
bounds_np = np.array([(b[0], b[1], b[2], b[3]) for b in bounds])
minx, miny, maxx, maxy = (
Expand Down Expand Up @@ -221,14 +223,20 @@ def ping() -> Dict:
)
async def orchestrate(
    payload: OrchestratorInput,
    request: Request,
    tiler=Depends(get_tiler),
    titiler_client=Depends(get_titiler_client),
    roda_sentinelhub_client=Depends(get_roda_sentinelhub_client),
    db_engine=Depends(get_database_engine),
) -> Dict:
    """HTTP entry point for orchestration.

    Pulls the shared httpx inference client off the application state
    (created in the lifespan handler) and delegates all work to
    ``_orchestrate``.
    """
    inference_client = request.app.state.inference_client
    result = await _orchestrate(
        payload,
        tiler,
        titiler_client,
        inference_client,
        roda_sentinelhub_client,
        db_engine,
    )
    return result


Expand Down Expand Up @@ -258,7 +266,7 @@ def create_dataset_from_inference_result(
return memfile.open()


def is_tile_over_water(tile_bounds: List[float]) -> bool:
def is_tile_over_water(tile_bounds: Tuple[float, float, float, float]) -> bool:
    """Return True if any corner of the tile bounds lies over ocean.

    Args:
        tile_bounds: (minx, miny, maxx, maxy) — presumably lon/lat degrees
            as used by global_land_mask; TODO confirm against callers.

    Returns:
        True when at least one of the four corner points is classified as
        ocean by ``globe.is_ocean``, False when all sampled corners are land.
    """
    minx, miny, maxx, maxy = tile_bounds
    # Sample all four corners. The previous call, is_ocean([miny, maxy],
    # [minx, maxx]), broadcasts elementwise and therefore only tested the
    # (S,W) and (N,E) corners — tiles with water only at the (S,E)/(N,W)
    # diagonal were misclassified as land.
    lats = [miny, miny, maxy, maxy]
    lons = [minx, maxx, minx, maxx]
    return any(globe.is_ocean(lats, lons))
Expand All @@ -276,7 +284,9 @@ def flatten_feature_list(
return flat_list


async def perform_inference(tiles, inference_func, description):
async def perform_inference(
tiles, inference_client: httpx.AsyncClient, inference_func, description
):
"""
Perform inference on a set of tiles asynchronously.
Expand Down Expand Up @@ -307,7 +317,12 @@ async def perform_inference(tiles, inference_func, description):


async def _orchestrate(
payload, tiler, titiler_client, roda_sentinelhub_client, db_engine
payload,
tiler,
titiler_client,
inference_client: httpx.AsyncClient,
roda_sentinelhub_client,
db_engine,
):
# Orchestrate inference
start_time = datetime.now()
Expand Down Expand Up @@ -439,18 +454,21 @@ async def _orchestrate(

base_tiles_inference = await perform_inference(
base_tiles,
inference_client,
cloud_run_inference.get_base_tile_inference,
f"base tiles: {start_time}",
)

offset_tiles_inference = await perform_inference(
offset_tiles_bounds,
inference_client,
cloud_run_inference.get_offset_tile_inference,
f"offset tiles: {start_time}",
)

offset_2_tiles_inference = await perform_inference(
offset_2_tiles_bounds,
inference_client,
cloud_run_inference.get_offset_tile_inference,
f"offset2 tiles: {start_time}",
)
Expand Down

0 comments on commit 647e4fb

Please sign in to comment.