Merge branch 'main' into asa_analysis
# Conflicts:
#	notebooks/ASA_test_environment.py
jonaraphael committed Jan 2, 2025
2 parents b853370 + 6fb09e8 commit 56d8e48
Showing 17 changed files with 247 additions and 52,053 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/test_and_deploy.yml
@@ -111,6 +111,7 @@ jobs:
name: Deploy [TEST]
runs-on: ubuntu-20.04
environment: test
concurrency: test
needs: [tests]
if: github.event_name == 'workflow_dispatch'
steps:
@@ -133,6 +134,7 @@ jobs:
name: Deploy [STAGING]
runs-on: ubuntu-20.04
environment: staging
concurrency: staging
needs: [tests]
if: github.ref == 'refs/heads/main'
steps:
@@ -155,6 +157,7 @@ jobs:
name: Deploy [PRODUCTION]
runs-on: ubuntu-20.04
environment: prod20240903
concurrency: prod20240903
needs: [tests]
if: startsWith(github.event.ref, 'refs/tags')
steps:
1 change: 1 addition & 0 deletions alembic/versions/3c4693517ef6_add_tables.py
@@ -355,6 +355,7 @@ def upgrade() -> None:
sa.Column("slick", sa.BigInteger, sa.ForeignKey("slick.id"), nullable=False),
sa.Column("source", sa.BigInteger, sa.ForeignKey("source.id"), nullable=False),
sa.Column("active", sa.Boolean, nullable=False),
sa.Column("git_hash", sa.Text),
sa.Column("coincidence_score", sa.Float),
sa.Column("collated_score", sa.Float),
sa.Column("rank", sa.BigInteger),
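
Note that the git_hash column is added to the original create_table call rather than in a new revision; for a database already built from the earlier revision, an equivalent incremental migration would look roughly like the sketch below (illustrative only, not part of this commit).

# Hypothetical follow-up revision, assuming standard Alembic conventions.
import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # Record which deployment produced each slick/source association.
    op.add_column("slick_to_source", sa.Column("git_hash", sa.Text))


def downgrade() -> None:
    op.drop_column("slick_to_source", "git_hash")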
6 changes: 6 additions & 0 deletions alembic/versions/7cd715196b8d_add_index.py
@@ -41,6 +41,10 @@ def upgrade() -> None:

op.create_index("idx_source_name", "source", ["st_name", "type"])

op.create_index(
"idx_slick_to_source_collated_score", "slick_to_source", ["collated_score"]
)

op.create_index("idx_slick_to_aoi_slick", "slick_to_aoi", ["slick"])
op.create_index("idx_slick_to_aoi_aoi", "slick_to_aoi", ["aoi"])

@@ -76,6 +80,8 @@ def downgrade() -> None:

op.drop_index("idx_source_name", "source")

op.drop_index("idx_slick_to_source_collated_score", "slick_to_source")

op.drop_index("idx_filter_hash", "filter")

op.drop_index("idx_slick_hitl", "slick")
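
The new index backs the ranking pattern used by the ASA cloud function; a hypothetical query it would accelerate is sketched below (model names taken from database_schema.py, the query itself is illustrative and not part of this commit).

# Illustrative only: fetch the top-ranked active sources for one slick,
# ordered by the newly indexed collated_score column.
from sqlalchemy import select

import cerulean_cloud.database_schema as db

slick_id = 12345  # placeholder id

top_sources = (
    select(db.SlickToSource)
    .where(db.SlickToSource.slick == slick_id, db.SlickToSource.active)
    .order_by(db.SlickToSource.collated_score.desc())
    .limit(5)
)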

This file was deleted.

42 changes: 32 additions & 10 deletions cerulean_cloud/cloud_function_ais_analysis/main.py
@@ -74,32 +74,53 @@ async def handle_asa_request(request):
request_json = request.get_json()
if not request_json.get("dry_run"):
scene_id = request_json.get("scene_id")
run_flags = request_json.get("run_flags", ["ais", "infra", "dark"])
run_flags = request_json.get("run_flags") # expects list of integers
if not run_flags:
run_flags = list(ASA_MAPPING.keys())
elif any(run_flag not in ASA_MAPPING.keys() for run_flag in run_flags):
raise ValueError(
f"Invalid run_flag provided. {run_flags} not in {ASA_MAPPING.keys()}"
)

overwrite_previous = request_json.get("overwrite_previous", False)
print(f"Running ASA ({run_flags}) on scene_id: {scene_id}")
db_engine = get_engine(db_url=os.getenv("DB_URL"))
async with DatabaseClient(db_engine) as db_client:
async with db_client.session.begin():
s1_scene = await db_client.get_scene_from_id(scene_id)
slicks = await db_client.get_slicks_from_scene_id(
scene_id, with_sources=overwrite_previous
)
if overwrite_previous:
for slick in slicks:
slicks = await db_client.get_slicks_from_scene_id(scene_id)
previous_asa = {}
for slick in slicks:
if overwrite_previous:
print(f"Deactivating sources for slick {slick.id}")
await db_client.deactivate_sources_for_slick(slick.id)
previous_asa[slick.id] = []
else:
previous_asa[slick.id] = await db_client.get_previous_asa(
slick.id
)

print(f"Running ASA ({run_flags}) on scene_id: {scene_id}")
print(f"{len(slicks)} slicks in scene {scene_id}: {[s.id for s in slicks]}")
if len(slicks) > 0:
analyzers = [ASA_MAPPING[source](s1_scene) for source in run_flags]
analyzers = [
ASA_MAPPING[source_type](s1_scene) for source_type in run_flags
]
random.shuffle(slicks) # Allows rerunning a scene to skip bugs
for slick in slicks:
analyzers_to_run = [
analyzer
for analyzer in analyzers
if analyzer.source_type not in previous_asa[slick.id]
]
if len(analyzers_to_run) == 0:
continue

# Convert slick geometry to GeoDataFrame
slick_geom = wkb.loads(str(slick.geometry)).buffer(0)
slick_gdf = gpd.GeoDataFrame({"geometry": [slick_geom]}, crs="4326")
ranked_sources = pd.DataFrame()

for analyzer in analyzers:
for analyzer in analyzers_to_run:
res = analyzer.compute_coincidence_scores(slick_gdf)
ranked_sources = pd.concat(
[ranked_sources, res], ignore_index=True
@@ -110,7 +131,7 @@ async def handle_asa_request(request):
)
if len(ranked_sources) > 0:
ranked_sources = ranked_sources.sort_values(
"coincidence_score", ascending=False
"collated_score", ascending=False
).reset_index(drop=True)
async with db_client.session.begin():
for idx, source_row in ranked_sources.iloc[:5].iterrows():
@@ -131,6 +152,7 @@ async def handle_asa_request(request):
source=source.id,
slick=slick.id,
active=True,
git_hash=os.getenv("GIT_HASH"),
coincidence_score=source_row["coincidence_score"],
collated_score=source_row["collated_score"],
rank=idx + 1,
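
With the string flags replaced by integer keys into ASA_MAPPING, a request body for this function might look like the sketch below (the URL and scene id are placeholders and any auth is omitted; leaving out run_flags runs every analyzer in ASA_MAPPING).

# Hypothetical invocation sketch; not part of this diff.
import requests

payload = {
    "scene_id": "S1A_IW_GRDH_1SDV_20250101T000000_20250101T000025_057000_070000_ABCD",
    "run_flags": [1, 2],          # 1 = AIS, 2 = infrastructure (see ASA_MAPPING)
    "overwrite_previous": False,  # keep active sources; skip analyzers already run
}
requests.post("https://REGION-PROJECT.cloudfunctions.net/cf-ais", json=payload)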
41 changes: 16 additions & 25 deletions cerulean_cloud/cloud_function_ais_analysis/utils/analyzer.py
@@ -111,6 +111,7 @@ def __init__(self, s1_scene, **kwargs):
Initialize the InfrastructureAnalyzer.
"""
super().__init__(s1_scene, **kwargs)
self.source_type = 2
self.num_vertices = kwargs.get("num_vertices", NUM_VERTICES)
self.closing_buffer = kwargs.get("closing_buffer", CLOSING_BUFFER)
self.radius_of_interest = kwargs.get("radius_of_interest", INFRA_REF_DIST)
@@ -123,10 +124,10 @@ def __init__(self, s1_scene, **kwargs):

if self.infra_gdf is None:
self.infra_api_token = os.getenv("INFRA_API_TOKEN")
self.infra_gdf = self.load_infrastructure_data()
self.infra_gdf = self.load_infrastructure_data_api()
self.coincidence_scores = np.zeros(len(self.infra_gdf))

def load_infrastructure_data(self, only_oil=True):
def load_infrastructure_data_csv(self, only_oil=True):
"""
Loads infrastructure data from a CSV file.
"""
@@ -448,6 +449,7 @@ def __init__(self, s1_scene, **kwargs):
Initialize the AISAnalyzer.
"""
super().__init__(s1_scene, **kwargs)
self.source_type = 1
self.s1_scene = s1_scene
# Default parameters
self.hours_before = kwargs.get("hours_before", HOURS_BEFORE)
@@ -642,7 +644,6 @@ def buffer_trajectories(self):
def slick_to_curves(
self,
buf_size: int = 2000,
interp_dist: int = 200,
smoothing_factor: float = 1e9,
):
"""
@@ -651,7 +652,6 @@ def slick_to_curves(
Inputs:
buf_size: buffer size for cleaning up slick detections
interp_dist: interpolation distance for centerline
smoothing_factor: smoothing factor for smoothing centerline
Returns:
GeoDataFrame of slick curves
@@ -671,24 +671,13 @@
slick_curves = list()
for _, row in slick_clean.iterrows():
# create centerline -> MultiLineString
try:
cl = centerline.geometry.Centerline(
row.geometry, interpolation_distance=interp_dist
)
except (
Exception
) as e: # noqa # unclear what exception was originally thrown here.
# sometimes the voronoi polygonization fails
# in this case, just fit a a simple line from the start to the end
exterior_coords = row.geometry.exterior.coords
start_point = exterior_coords[0]
end_point = exterior_coords[-1]
curve = shapely.geometry.LineString([start_point, end_point])
slick_curves.append(curve)
print(
f"XXX ~WARNING~ Blanket try/except caught error but continued on anyway: {e}"
)
continue
polygon_perimeter = row.geometry.length # Perimeter of the polygon
interp_dist = min(
100, polygon_perimeter / 1000
) # Use a minimum of 1000 points for voronoi calculation
cl = centerline.geometry.Centerline(
row.geometry, interpolation_distance=interp_dist
)

# grab coordinates from centerline
x = list()
@@ -927,6 +916,7 @@ def __init__(self, s1_scene, **kwargs):
Initialize the DarkAnalyzer.
"""
super().__init__(s1_scene, **kwargs)
self.source_type = 3
# Initialize attributes specific to dark vessel analysis

def compute_coincidence_scores(self, slick_gdf: gpd.GeoDataFrame):
@@ -938,7 +928,8 @@ def compute_coincidence_scores(self, slick_gdf: gpd.GeoDataFrame):


ASA_MAPPING = {
"ais": AISAnalyzer,
"infra": InfrastructureAnalyzer,
"dark": DarkAnalyzer,
1: AISAnalyzer,
2: InfrastructureAnalyzer,
# 3: DarkAnalyzer,
# 4: NaturalAnalyzer,
}
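
The integer keys are meant to line up with the source_type attribute now set in each analyzer's __init__ (1 = AIS, 2 = infrastructure, with 3 and 4 commented out pending DarkAnalyzer/NaturalAnalyzer); a small illustrative check, not part of the diff:

# Illustrative only: integer run flags index ASA_MAPPING directly; each selected
# class sets the matching integer as source_type on its instances.
run_flags = [1, 2]
analyzer_classes = [ASA_MAPPING[flag] for flag in run_flags]
assert analyzer_classes == [AISAnalyzer, InfrastructureAnalyzer]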
25 changes: 16 additions & 9 deletions cerulean_cloud/cloud_function_scene_relevancy/main.py
@@ -141,16 +141,23 @@ def handle_notification(request_json, ocean_poly):
for r in request_json.get("Records"):
sns = r["Sns"]
msg = json.loads(sns["Message"])
scene_poly = sh.polygon.Polygon(msg["footprint"]["coordinates"][0][0])

is_highdef = msg["id"][10] == "H"
is_vv = (
msg["id"][15] == "V"
) # we don't want to process any polarization other than vv XXX This is hardcoded in the server, where we look for a vv.grd file
is_oceanic = scene_poly.intersects(ocean_poly)
print(is_highdef, is_vv, is_oceanic)
if is_highdef and is_vv and is_oceanic:
filtered_scenes.append(msg["id"])
if not (msg["id"][4:6] == "IW"):
# Check Beam Mode
# XXX This is workaround a bug in Titiler.get_bounds (404 not found) that fails if the beam mode is not IW
continue
if not (msg["id"][10] == "H"):
# Check High Definition
continue
if not (msg["id"][15] == "V"):
# Check Polarization
# XXX This is hardcoded in the server, where we look for a vv.grd file
continue
scene_poly = sh.polygon.Polygon(msg["footprint"]["coordinates"][0][0])
if not (scene_poly.intersects(ocean_poly)):
# Check Oceanic
continue
filtered_scenes.append(msg["id"])
return filtered_scenes


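
The index checks above rely on the fixed layout of Sentinel-1 product identifiers; a worked example with an illustrative scene id (real ids arrive in the SNS message as msg["id"]):

# Illustrative scene id only.
scene_id = "S1A_IW_GRDH_1SDV_20250101T000000_20250101T000025_057000_070000_ABCD"

assert scene_id[4:6] == "IW"  # beam mode
assert scene_id[10] == "H"    # high-resolution GRDH product
assert scene_id[15] == "V"    # polarisation group containing VV (DV or SV)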
5 changes: 3 additions & 2 deletions cerulean_cloud/cloud_run_offset_tiles/requirements.txt
@@ -3,8 +3,9 @@ starlette-cramjam==0.1.0
uvicorn[standard]
rasterio==1.3.0
setuptools==59.5.0
torch==1.11.0
torchvision==0.12.0
-f https://download.pytorch.org/whl/cpu/torch_stable.html
torch==1.11.0+cpu
torchvision==0.12.0+cpu
numpy<2.0.0
typing-inspect
fastapi-utils
5 changes: 3 additions & 2 deletions cerulean_cloud/cloud_run_orchestrator/requirements.txt
@@ -24,8 +24,9 @@ pygeos
networkx
google-cloud-tasks
protobuf
torch==1.11.0
torchvision==0.12.0
-f https://download.pytorch.org/whl/cpu/torch_stable.html
torch==1.11.0+cpu
torchvision==0.12.0+cpu
scipy
scikit-image
pydantic<2.0
21 changes: 20 additions & 1 deletion cerulean_cloud/database_client.py
@@ -335,7 +335,7 @@ async def deactivate_stale_slicks_from_scene_id(self, scene_id):
.where(
and_(
db.Sentinel1Grd.scene_id == scene_id,
db.Slick.active == True, # noqa
db.Slick.active,
)
)
)
@@ -357,4 +357,23 @@ async def deactivate_sources_for_slick(self, slick_id):
.values(active=False)
)

async def get_previous_asa(self, slick_id):
"""Return a list of ASA types that have been run for a slick."""
return (
(
await self.session.execute(
select(db.Source.type)
.join(db.SlickToSource.source1)
.where(
and_(
db.SlickToSource.slick == slick_id,
db.SlickToSource.active,
)
)
)
)
.scalars()
.all()
)

# EditTheDatabase
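
A usage sketch of the new helper, mirroring the call pattern added in cloud_function_ais_analysis/main.py (the wrapper function name and signature below are assumed, not part of this commit):

# Sketch: return the ASA source types that already have an active association
# for a slick, so those analyzers can be skipped when the scene is rerun.
from cerulean_cloud.database_client import DatabaseClient


async def previously_run_types(db_engine, slick_id: int) -> list[int]:
    async with DatabaseClient(db_engine) as db_client:
        async with db_client.session.begin():
            return await db_client.get_previous_asa(slick_id)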
1 change: 1 addition & 0 deletions cerulean_cloud/database_schema.py
@@ -521,6 +521,7 @@ class SlickToSource(Base): # noqa
slick = Column(ForeignKey("slick.id"), nullable=False)
source = Column(ForeignKey("source.id"), nullable=False)
active = Column(Boolean, nullable=False)
git_hash = Column(Text)
coincidence_score = Column(Float(53))
collated_score = Column(Float(53))
rank = Column(BigInteger)
15 changes: 8 additions & 7 deletions notebooks/ASA_test_environment.py
@@ -224,7 +224,7 @@ def plot_coincidence(
plt.show()


analyzers: dict[str, SourceAnalyzer] = {}
analyzers: dict[int, SourceAnalyzer] = {}

# %%
slick_ids = [
@@ -243,8 +243,8 @@
s1_scene = get_s1_scene(slick_gdf.s1_scene_id.iloc[0])

source_types = []
source_types += ["infra"]
source_types += ["ais"]
source_types += [1] # ais
source_types += [2] # infra
if not ( # If the last analyzer is for the same scene, reuse it
analyzers
and next(iter(analyzers.items()))[1].s1_scene.scene_id == s1_scene.scene_id
@@ -271,12 +271,13 @@
]
)

if "infra" in analyzers:
plot_coincidence(analyzers["infra"], slick_id)
if 2 in analyzers.keys():
plot_coincidence(analyzers[2], slick_id)

print(ranked_sources[["type", "st_name", "collated_score"]].head())
print(
ranked_sources[["type", "ext_id", "coincidence_score", "collated_score"]].head()
)

print(ranked_sources.head())
# print(accumulated_sources)
# %%
fake_infra_gdf = generate_infrastructure_points(slick_gdf, 50000)
8 changes: 5 additions & 3 deletions stack/cloud_function_ais_analysis.py
@@ -1,8 +1,7 @@
"""cloud function to find slick culprits from AIS tracks"""

import time

import database
import git
import pulumi
from pulumi_gcp import cloudfunctions, cloudtasks, projects, serviceaccount, storage
from utils import construct_name, pulumi_create_zip
@@ -35,10 +34,13 @@
),
)

repo = git.Repo(search_parent_directories=True)
git_sha = repo.head.object.hexsha

function_name = construct_name("cf-ais")
config_values = {
"DB_URL": database.sql_instance_url_with_asyncpg,
"GIT_HASH": git_sha,
}

# The Cloud Function source code itself needs to be zipped up into an
@@ -54,7 +56,7 @@
# source code. ("main.py" and "requirements.txt".)
source_archive_object = storage.BucketObject(
construct_name("source-cf-ais"),
name=f"handler.py-{time.time():f}",
name="handler.py",
bucket=bucket.name,
source=archive,
)
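
Together with the GIT_HASH entry added to config_values, this commit gives each stored association a pointer back to the code that produced it; the flow across the touched files, summarized below (description only, not new code):

# Traceability flow introduced by this commit (summary comments only):
#
#   stack/cloud_function_ais_analysis.py (deploy time)
#       git_sha = git.Repo(search_parent_directories=True).head.object.hexsha
#       config_values["GIT_HASH"] = git_sha
#
#   cerulean_cloud/cloud_function_ais_analysis/main.py (run time)
#       git_hash=os.getenv("GIT_HASH") is written to every new slick_to_source row
#
#   database (alembic migration + database_schema.py)
#       slick_to_source.git_hash column stores the value for later auditing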