Merge branch 'main' into bug/asa_qhull
jonaraphael committed Dec 6, 2024
2 parents 4294475 + edae49d commit 1aae9a4
Showing 9 changed files with 82 additions and 52,013 deletions.
1 change: 1 addition & 0 deletions alembic/versions/3c4693517ef6_add_tables.py
@@ -355,6 +355,7 @@ def upgrade() -> None:
sa.Column("slick", sa.BigInteger, sa.ForeignKey("slick.id"), nullable=False),
sa.Column("source", sa.BigInteger, sa.ForeignKey("source.id"), nullable=False),
sa.Column("active", sa.Boolean, nullable=False),
sa.Column("git_hash", sa.Text),
sa.Column("coincidence_score", sa.Float),
sa.Column("collated_score", sa.Float),
sa.Column("rank", sa.BigInteger),
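Because this commit edits the original create-table migration in place, a database already stamped past 3c4693517ef6 would typically pick up the new column via a follow-up revision instead. A hedged sketch of what that incremental migration might look like (the revision ids are placeholders, not from this repo):

```python
# Hypothetical follow-up Alembic revision; ids below are placeholders.
import sqlalchemy as sa
from alembic import op

revision = "ffffffffffff"
down_revision = "3c4693517ef6"


def upgrade() -> None:
    # Record which code revision produced each slick_to_source row.
    op.add_column("slick_to_source", sa.Column("git_hash", sa.Text))


def downgrade() -> None:
    op.drop_column("slick_to_source", "git_hash")
```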
6 changes: 6 additions & 0 deletions alembic/versions/7cd715196b8d_add_index.py
@@ -41,6 +41,10 @@ def upgrade() -> None:

op.create_index("idx_source_name", "source", ["st_name", "type"])

op.create_index(
"idx_slick_to_source_collated_score", "slick_to_source", ["collated_score"]
)

op.create_index("idx_slick_to_aoi_slick", "slick_to_aoi", ["slick"])
op.create_index("idx_slick_to_aoi_aoi", "slick_to_aoi", ["aoi"])

@@ -76,6 +80,8 @@ def downgrade() -> None:

op.drop_index("idx_source_name", "source")

op.drop_index("idx_slick_to_source_collated_score", "slick_to_source")

op.drop_index("idx_filter_hash", "filter")

op.drop_index("idx_slick_hitl", "slick")
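For context, a sketch of the kind of read a single-column index on collated_score could back — presumably top-N ranking queries over slick_to_source. This assumes the ORM models in cerulean_cloud/database_schema.py; the limit value is arbitrary:

```python
# Hedged sketch: a query shape the new collated_score index can serve.
from sqlalchemy import desc, select

import cerulean_cloud.database_schema as db

stmt = (
    select(db.SlickToSource)
    .order_by(desc(db.SlickToSource.collated_score))  # index-backed sort
    .limit(10)
)
```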

This file was deleted.

42 changes: 32 additions & 10 deletions cerulean_cloud/cloud_function_ais_analysis/main.py
@@ -74,32 +74,53 @@ async def handle_asa_request(request):
request_json = request.get_json()
if not request_json.get("dry_run"):
scene_id = request_json.get("scene_id")
run_flags = request_json.get("run_flags", ["ais", "infra", "dark"])
run_flags = request_json.get("run_flags") # expects list of integers
if not run_flags:
run_flags = list(ASA_MAPPING.keys())
elif any(run_flag not in ASA_MAPPING.keys() for run_flag in run_flags):
raise ValueError(
f"Invalid run_flag provided. {run_flags} not in {ASA_MAPPING.keys()}"
)

overwrite_previous = request_json.get("overwrite_previous", False)
print(f"Running ASA ({run_flags}) on scene_id: {scene_id}")
db_engine = get_engine(db_url=os.getenv("DB_URL"))
async with DatabaseClient(db_engine) as db_client:
async with db_client.session.begin():
s1_scene = await db_client.get_scene_from_id(scene_id)
slicks = await db_client.get_slicks_from_scene_id(
scene_id, with_sources=overwrite_previous
)
if overwrite_previous:
for slick in slicks:
slicks = await db_client.get_slicks_from_scene_id(scene_id)
previous_asa = {}
for slick in slicks:
if overwrite_previous:
print(f"Deactivating sources for slick {slick.id}")
await db_client.deactivate_sources_for_slick(slick.id)
previous_asa[slick.id] = []
else:
previous_asa[slick.id] = await db_client.get_previous_asa(
slick.id
)

print(f"Running ASA ({run_flags}) on scene_id: {scene_id}")
print(f"{len(slicks)} slicks in scene {scene_id}: {[s.id for s in slicks]}")
if len(slicks) > 0:
analyzers = [ASA_MAPPING[source](s1_scene) for source in run_flags]
analyzers = [
ASA_MAPPING[source_type](s1_scene) for source_type in run_flags
]
random.shuffle(slicks) # Allows rerunning a scene to skip bugs
for slick in slicks:
analyzers_to_run = [
analyzer
for analyzer in analyzers
if analyzer.source_type not in previous_asa[slick.id]
]
if len(analyzers_to_run) == 0:
continue

# Convert slick geometry to GeoDataFrame
slick_geom = wkb.loads(str(slick.geometry)).buffer(0)
slick_gdf = gpd.GeoDataFrame({"geometry": [slick_geom]}, crs="4326")
ranked_sources = pd.DataFrame()

for analyzer in analyzers:
for analyzer in analyzers_to_run:
res = analyzer.compute_coincidence_scores(slick_gdf)
ranked_sources = pd.concat(
[ranked_sources, res], ignore_index=True
@@ -110,7 +131,7 @@ async def handle_asa_request(request):
)
if len(ranked_sources) > 0:
ranked_sources = ranked_sources.sort_values(
"coincidence_score", ascending=False
"collated_score", ascending=False
).reset_index(drop=True)
async with db_client.session.begin():
for idx, source_row in ranked_sources.iloc[:5].iterrows():
@@ -131,6 +152,7 @@ async def handle_asa_request(request):
source=source.id,
slick=slick.id,
active=True,
git_hash=os.getenv("GIT_HASH"),
coincidence_score=source_row["coincidence_score"],
collated_score=source_row["collated_score"],
rank=idx + 1,
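A hedged sketch of the request body the updated handler now expects: run_flags switches from strings to the integer keys of ASA_MAPPING, and omitting it runs every registered analyzer. The scene id below is a placeholder, not a real scene:

```python
import json

# Hypothetical request body for the updated handler.
payload = {
    "scene_id": "S1A_PLACEHOLDER_SCENE_ID",  # placeholder scene id
    "run_flags": [1, 2],  # 1 = AISAnalyzer, 2 = InfrastructureAnalyzer
    "overwrite_previous": False,  # keep active rows; skip analyzers already run
}
print(json.dumps(payload))
```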
14 changes: 9 additions & 5 deletions cerulean_cloud/cloud_function_ais_analysis/utils/analyzer.py
@@ -111,6 +111,7 @@ def __init__(self, s1_scene, **kwargs):
Initialize the InfrastructureAnalyzer.
"""
super().__init__(s1_scene, **kwargs)
self.source_type = 2
self.num_vertices = kwargs.get("num_vertices", NUM_VERTICES)
self.closing_buffer = kwargs.get("closing_buffer", CLOSING_BUFFER)
self.radius_of_interest = kwargs.get("radius_of_interest", INFRA_REF_DIST)
@@ -123,10 +124,10 @@ def __init__(self, s1_scene, **kwargs):

if self.infra_gdf is None:
self.infra_api_token = os.getenv("INFRA_API_TOKEN")
self.infra_gdf = self.load_infrastructure_data()
self.infra_gdf = self.load_infrastructure_data_api()
self.coincidence_scores = np.zeros(len(self.infra_gdf))

def load_infrastructure_data(self, only_oil=True):
def load_infrastructure_data_csv(self, only_oil=True):
"""
Loads infrastructure data from a CSV file.
"""
@@ -451,6 +452,7 @@ def __init__(self, s1_scene, **kwargs):
Initialize the AISAnalyzer.
"""
super().__init__(s1_scene, **kwargs)
self.source_type = 1
self.s1_scene = s1_scene
# Default parameters
self.hours_before = kwargs.get("hours_before", HOURS_BEFORE)
@@ -918,6 +920,7 @@ def __init__(self, s1_scene, **kwargs):
Initialize the DarkAnalyzer.
"""
super().__init__(s1_scene, **kwargs)
self.source_type = 3
# Initialize attributes specific to dark vessel analysis

def compute_coincidence_scores(self, slick_gdf: gpd.GeoDataFrame):
@@ -929,7 +932,8 @@ def compute_coincidence_scores(self, slick_gdf: gpd.GeoDataFrame):


ASA_MAPPING = {
"ais": AISAnalyzer,
"infra": InfrastructureAnalyzer,
"dark": DarkAnalyzer,
1: AISAnalyzer,
2: InfrastructureAnalyzer,
# 3: DarkAnalyzer,
# 4: NaturalAnalyzer,
}
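The integer keys line up with the source_type each analyzer now sets on itself in __init__; that attribute is what main.py compares against previous_asa to decide whether an analyzer still needs to run. A minimal sketch of the invariant, assuming an s1_scene object is already in scope:

```python
# Minimal sketch: every ASA_MAPPING key should round-trip through the
# instance's source_type attribute.
for source_type, analyzer_cls in ASA_MAPPING.items():
    analyzer = analyzer_cls(s1_scene)  # s1_scene assumed available
    assert analyzer.source_type == source_type
```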
21 changes: 20 additions & 1 deletion cerulean_cloud/database_client.py
@@ -335,7 +335,7 @@ async def deactivate_stale_slicks_from_scene_id(self, scene_id):
.where(
and_(
db.Sentinel1Grd.scene_id == scene_id,
db.Slick.active == True, # noqa
db.Slick.active,
)
)
)
@@ -357,4 +357,23 @@ async def deactivate_sources_for_slick(self, slick_id):
.values(active=False)
)

async def get_previous_asa(self, slick_id):
"""Return a list of ASA types that have been run for a slick."""
return (
(
await self.session.execute(
select(db.Source.type)
.join(db.SlickToSource.source1)
.where(
and_(
db.SlickToSource.slick == slick_id,
db.SlickToSource.active,
)
)
)
)
.scalars()
.all()
)

# EditTheDatabase
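As used upstream in main.py, the returned list of integer source types feeds the skip logic. A condensed usage sketch, assuming it runs inside the async handler with db_client, slick, and analyzers in scope:

```python
# Hedged usage sketch (inside an async context): only run analyzers whose
# source_type has no active slick_to_source row for this slick.
previous = await db_client.get_previous_asa(slick.id)  # e.g. [1] if AIS already ran
analyzers_to_run = [a for a in analyzers if a.source_type not in previous]
```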
1 change: 1 addition & 0 deletions cerulean_cloud/database_schema.py
@@ -521,6 +521,7 @@ class SlickToSource(Base):  # noqa
slick = Column(ForeignKey("slick.id"), nullable=False)
source = Column(ForeignKey("source.id"), nullable=False)
active = Column(Boolean, nullable=False)
git_hash = Column(Text)
coincidence_score = Column(Float(53))
collated_score = Column(Float(53))
rank = Column(BigInteger)
23 changes: 9 additions & 14 deletions notebooks/ASA_test_environment.py
@@ -224,7 +224,7 @@ def plot_coincidence(
plt.show()


analyzers: dict[str, SourceAnalyzer] = {}
analyzers: dict[int, SourceAnalyzer] = {}

# %%
slick_ids = [
@@ -243,19 +243,13 @@ def plot_coincidence(
s1_scene = get_s1_scene(slick_gdf.s1_scene_id.iloc[0])

source_types = []
source_types += ["infra"]
# source_types += ["ais"]
source_types += [1] # ais
source_types += [2] # infra
if not ( # If the last analyzer is for the same scene, reuse it
analyzers
and next(iter(analyzers.items()))[1].s1_scene.scene_id == s1_scene.scene_id
):
analyzers = {
s_type: ASA_MAPPING[s_type](
s1_scene,
gfw_infra_filepath="/Users/jonathanraphael/git/cerulean-cloud/cerulean_cloud/cloud_function_ais_analysis/SAR Fixed Infrastructure 202407 DENOISED UNIQUE.csv",
)
for s_type in source_types
}
analyzers = {s_type: ASA_MAPPING[s_type](s1_scene) for s_type in source_types}

ranked_sources = pd.DataFrame(columns=["type", "st_name", "collated_score"])
for s_type, analyzer in analyzers.items():
@@ -275,12 +269,13 @@ def plot_coincidence(
]
)

if "infra" in analyzers:
plot_coincidence(analyzers["infra"], slick_id)
if 2 in analyzers.keys():
plot_coincidence(analyzers[2], slick_id)

print(ranked_sources[["type", "st_name", "collated_score"]].head())
print(
ranked_sources[["type", "ext_id", "coincidence_score", "collated_score"]].head()
)

print(ranked_sources.head())
# print(accumulated_sources)
# %%
fake_infra_gdf = generate_infrastructure_points(slick_gdf, 50000)
4 changes: 4 additions & 0 deletions stack/cloud_function_ais_analysis.py
@@ -3,6 +3,7 @@
import time

import database
import git
import pulumi
from pulumi_gcp import cloudfunctions, cloudtasks, projects, serviceaccount, storage
from utils import construct_name, pulumi_create_zip
@@ -35,10 +36,13 @@
),
)

repo = git.Repo(search_parent_directories=True)
git_sha = repo.head.object.hexsha

function_name = construct_name("cf-ais")
config_values = {
"DB_URL": database.sql_instance_url_with_asyncpg,
"GIT_HASH": git_sha,
}

# The Cloud Function source code itself needs to be zipped up into an
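The deploy-time capture relies on GitPython; a standalone sketch of the call, assuming the Pulumi stack runs from inside a git checkout:

```python
import git

repo = git.Repo(search_parent_directories=True)  # walk up until a .git dir is found
git_sha = repo.head.object.hexsha  # full 40-character sha of the current HEAD commit
print(git_sha)
```

Baking the sha into the function's environment at deploy time means the os.getenv("GIT_HASH") call in main.py reflects the code revision that was actually deployed, rather than whatever happens to be checked out where the function runs.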
