diff --git a/cerulean_cloud/cloud_function_ais_analysis/main.py b/cerulean_cloud/cloud_function_ais_analysis/main.py index 3ac6d31e..89e62aa4 100644 --- a/cerulean_cloud/cloud_function_ais_analysis/main.py +++ b/cerulean_cloud/cloud_function_ais_analysis/main.py @@ -74,24 +74,42 @@ async def handle_asa_request(request): request_json = request.get_json() if not request_json.get("dry_run"): scene_id = request_json.get("scene_id") - run_flags = request_json.get("run_flags", ["ais", "infra", "dark"]) + run_flags_empty = not request_json.get("run_flags", False) + run_flags = request_json.get( + "run_flags", ASA_MAPPING.keys() # expects list of integers + ) overwrite_previous = request_json.get("overwrite_previous", False) print(f"Running ASA ({run_flags}) on scene_id: {scene_id}") db_engine = get_engine(db_url=os.getenv("DB_URL")) async with DatabaseClient(db_engine) as db_client: async with db_client.session.begin(): s1_scene = await db_client.get_scene_from_id(scene_id) - slicks = await db_client.get_slicks_from_scene_id( - scene_id, with_sources=overwrite_previous - ) + slicks = await db_client.get_slicks_from_scene_id(scene_id) if overwrite_previous: for slick in slicks: print(f"Deactivating sources for slick {slick.id}") await db_client.deactivate_sources_for_slick(slick.id) + previous_asa = [ + asa_type + for slick_results in await asyncio.gather( + *(db_client.get_previous_asa(slick) for slick in slicks) + ) + for asa_type in slick_results + ] print(f"{len(slicks)} slicks in scene {scene_id}: {[s.id for s in slicks]}") if len(slicks) > 0: - analyzers = [ASA_MAPPING[source](s1_scene) for source in run_flags] + if run_flags_empty: + # default behavior is to run all ASA types that haven't been run yet + # if run_flags are provided, we will run those ASA types independently of whether they've been run before + run_flags = [ + asa_type + for asa_type in run_flags + if asa_type not in previous_asa + ] + analyzers = [ + ASA_MAPPING[source_type](s1_scene) for source_type in run_flags + ] random.shuffle(slicks) # Allows rerunning a scene to skip bugs for slick in slicks: # Convert slick geometry to GeoDataFrame diff --git a/cerulean_cloud/cloud_function_ais_analysis/utils/analyzer.py b/cerulean_cloud/cloud_function_ais_analysis/utils/analyzer.py index 317ae87d..7e01780d 100644 --- a/cerulean_cloud/cloud_function_ais_analysis/utils/analyzer.py +++ b/cerulean_cloud/cloud_function_ais_analysis/utils/analyzer.py @@ -938,7 +938,8 @@ def compute_coincidence_scores(self, slick_gdf: gpd.GeoDataFrame): ASA_MAPPING = { - "ais": AISAnalyzer, - "infra": InfrastructureAnalyzer, - "dark": DarkAnalyzer, + 1: AISAnalyzer, + 2: InfrastructureAnalyzer, + # 3: DarkAnalyzer, + # 4: NaturalAnalyzer, } diff --git a/cerulean_cloud/database_client.py b/cerulean_cloud/database_client.py index cfa697cf..179c55a9 100644 --- a/cerulean_cloud/database_client.py +++ b/cerulean_cloud/database_client.py @@ -357,4 +357,18 @@ async def deactivate_sources_for_slick(self, slick_id): .values(active=False) ) + async def get_previous_asa(self, slick): + """Return a list of ASA types that have been run for a slick.""" + return ( + ( + await self.session.execute( + select(db.Source.type) + .join(db.SlickToSource.source1) + .where(db.SlickToSource.slick == slick) + ) + ) + .scalars() + .all() + ) + # EditTheDatabase