From f80d4e7398aa007976aef54c3a687bae3fc9843b Mon Sep 17 00:00:00 2001 From: Jona Date: Thu, 21 Nov 2024 22:27:11 -0500 Subject: [PATCH] Fixing run_flags logic --- .../cloud_function_ais_analysis/main.py | 42 +++++++++---------- .../utils/analyzer.py | 3 ++ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/cerulean_cloud/cloud_function_ais_analysis/main.py b/cerulean_cloud/cloud_function_ais_analysis/main.py index 89e62aa4..cb528a0c 100644 --- a/cerulean_cloud/cloud_function_ais_analysis/main.py +++ b/cerulean_cloud/cloud_function_ais_analysis/main.py @@ -74,12 +74,15 @@ async def handle_asa_request(request): request_json = request.get_json() if not request_json.get("dry_run"): scene_id = request_json.get("scene_id") - run_flags_empty = not request_json.get("run_flags", False) - run_flags = request_json.get( - "run_flags", ASA_MAPPING.keys() # expects list of integers - ) + run_flags = request_json.get("run_flags") # expects list of integers + if not run_flags: + run_flags = list(ASA_MAPPING.keys()) + elif any(run_flag not in ASA_MAPPING.keys() for run_flag in run_flags): + raise ValueError( + f"Invalid run_flag provided. {run_flags} not in {ASA_MAPPING.keys()}" + ) + overwrite_previous = request_json.get("overwrite_previous", False) - print(f"Running ASA ({run_flags}) on scene_id: {scene_id}") db_engine = get_engine(db_url=os.getenv("DB_URL")) async with DatabaseClient(db_engine) as db_client: async with db_client.session.begin(): @@ -89,35 +92,32 @@ async def handle_asa_request(request): for slick in slicks: print(f"Deactivating sources for slick {slick.id}") await db_client.deactivate_sources_for_slick(slick.id) - previous_asa = [ - asa_type - for slick_results in await asyncio.gather( - *(db_client.get_previous_asa(slick) for slick in slicks) - ) - for asa_type in slick_results - ] + previous_asa = { + slick: await db_client.get_previous_asa(slick) for slick in slicks + } + print(f"Running ASA ({run_flags}) on scene_id: {scene_id}") print(f"{len(slicks)} slicks in scene {scene_id}: {[s.id for s in slicks]}") if len(slicks) > 0: - if run_flags_empty: - # default behavior is to run all ASA types that haven't been run yet - # if run_flags are provided, we will run those ASA types independently of whether they've been run before - run_flags = [ - asa_type - for asa_type in run_flags - if asa_type not in previous_asa - ] analyzers = [ ASA_MAPPING[source_type](s1_scene) for source_type in run_flags ] random.shuffle(slicks) # Allows rerunning a scene to skip bugs for slick in slicks: + analyzers_to_run = [ + analyzer + for analyzer in analyzers + if analyzer.source_type not in previous_asa[slick] + ] + if len(analyzers_to_run) == 0: + continue + # Convert slick geometry to GeoDataFrame slick_geom = wkb.loads(str(slick.geometry)).buffer(0) slick_gdf = gpd.GeoDataFrame({"geometry": [slick_geom]}, crs="4326") ranked_sources = pd.DataFrame() - for analyzer in analyzers: + for analyzer in analyzers_to_run: res = analyzer.compute_coincidence_scores(slick_gdf) ranked_sources = pd.concat( [ranked_sources, res], ignore_index=True diff --git a/cerulean_cloud/cloud_function_ais_analysis/utils/analyzer.py b/cerulean_cloud/cloud_function_ais_analysis/utils/analyzer.py index 7e01780d..ed301676 100644 --- a/cerulean_cloud/cloud_function_ais_analysis/utils/analyzer.py +++ b/cerulean_cloud/cloud_function_ais_analysis/utils/analyzer.py @@ -111,6 +111,7 @@ def __init__(self, s1_scene, **kwargs): Initialize the InfrastructureAnalyzer. """ super().__init__(s1_scene, **kwargs) + self.source_type = 2 self.num_vertices = kwargs.get("num_vertices", NUM_VERTICES) self.closing_buffer = kwargs.get("closing_buffer", CLOSING_BUFFER) self.radius_of_interest = kwargs.get("radius_of_interest", INFRA_REF_DIST) @@ -448,6 +449,7 @@ def __init__(self, s1_scene, **kwargs): Initialize the AISAnalyzer. """ super().__init__(s1_scene, **kwargs) + self.source_type = 1 self.s1_scene = s1_scene # Default parameters self.hours_before = kwargs.get("hours_before", HOURS_BEFORE) @@ -927,6 +929,7 @@ def __init__(self, s1_scene, **kwargs): Initialize the DarkAnalyzer. """ super().__init__(s1_scene, **kwargs) + self.source_type = 3 # Initialize attributes specific to dark vessel analysis def compute_coincidence_scores(self, slick_gdf: gpd.GeoDataFrame):