#197 fix zoo workflow
2320sharon committed Oct 23, 2023
1 parent 817a475 commit 6cba930
Showing 5 changed files with 146 additions and 33 deletions.
56 changes: 56 additions & 0 deletions src/coastseg/common.py
@@ -39,6 +39,62 @@
# Logger setup
logger = logging.getLogger(__name__)

def create_new_config(roi_ids: list, settings: dict, roi_settings: dict) -> dict:
"""
Creates a new configuration dictionary by combining the given settings and ROI settings.
Arguments:
-----------
roi_ids: list
A list of ROI IDs to include in the new configuration.
settings: dict
A dictionary containing general settings for the configuration.
roi_settings: dict
A dictionary containing ROI-specific settings for the configuration.
example:
{'example_roi_id': {'dates': []}}
Returns:
-----------
new_config: dict
A dictionary containing the combined settings and ROI settings, as well as the ROI IDs.
"""
new_config = {
'settings': {},
'roi_ids': [],
}
if isinstance(roi_ids, str):
roi_ids = [roi_ids]
if not all(roi_id in roi_settings.keys() for roi_id in roi_ids):
raise ValueError(f'roi_ids {roi_ids} not in roi_settings {roi_settings.keys()}')
new_config = {**new_config, **roi_settings}
new_config['roi_ids'].extend(roi_ids)
new_config['settings'] = settings
return new_config
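
A minimal usage sketch of create_new_config; the ROI ID and settings values here are illustrative only, not part of the commit:

roi_settings = {"zih2": {"dates": ["2018-12-01", "2019-03-01"], "sitename": "ID_zih2"}}
settings = {"output_epsg": 4326, "save_figure": True}
new_config = create_new_config(["zih2"], settings, roi_settings)
# -> {'settings': {...}, 'roi_ids': ['zih2'], 'zih2': {'dates': [...], 'sitename': 'ID_zih2'}}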

def save_new_config(path: str, roi_ids: list, destination: str) -> None:
"""Save a new config file to a path.
Args:
path (str): the path to read the original config file from
roi_ids (list): a list of roi_ids to include in the new config file
destination (str): the path to save the new config file to
"""
with open(path) as f:
config = json.load(f)

if isinstance(roi_ids, str):
roi_ids = [roi_ids]

roi_settings = {}
for roi_id in roi_ids:
if roi_id in config.keys():
roi_settings[roi_id] = config[roi_id]

new_config = create_new_config(roi_ids, config['settings'], roi_settings)
with open(destination, "w") as f:
json.dump(new_config, f)
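
For example, a hypothetical call that trims a downloaded config down to a single ROI and writes it into a session folder (paths and the ROI ID are illustrative):

save_new_config(
    path="data/ID_zih2/config.json",            # hypothetical source config
    roi_ids="zih2",                             # a single id is accepted and wrapped in a list
    destination="sessions/sample_session/config.json",
)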

def filter_images_by_roi(roi_settings: list[dict]):
"""
83 changes: 61 additions & 22 deletions src/coastseg/extracted_shoreline.py
@@ -5,6 +5,7 @@
import json
import json
import logging
import re
import os
from glob import glob
from time import perf_counter
@@ -245,7 +246,7 @@ def compute_transects_from_roi(
return cross_distance


def combine_satellite_data(satellite_data: dict):
def combine_satellite_data(satellite_data: dict) -> dict:
"""
Function to merge the satellite_data dictionary, which has one key per satellite mission
into a dictionary containing all the shorelines and dates ordered chronologically.
@@ -263,7 +264,14 @@
"""
# Initialize merged_satellite_data dict
merged_satellite_data = {}
merged_satellite_data = {
"dates": [],
"geoaccuracy": [],
"shorelines": [],
"idx": [],
"satname": [],
}

# Iterate through satellite_data keys (satellite names)
for satname in satellite_data:
# Iterate through each key in the nested dictionary
@@ -287,14 +295,17 @@
# Add satellite name entries for each date
if "dates" in sat_data.keys():
merged_satellite_data["satname"] += [satname] * len(sat_data["dates"])
# Sort chronologically
idx_sorted = sorted(
range(len(merged_satellite_data["dates"])),
key=lambda i: merged_satellite_data["dates"][i],
)
# Sort dates chronologically
if "dates" in merged_satellite_data.keys():
idx_sorted = sorted(
range(len(merged_satellite_data["dates"])),
key=lambda i: merged_satellite_data["dates"][i],
)

for key in merged_satellite_data.keys():
merged_satellite_data[key] = [merged_satellite_data[key][i] for i in idx_sorted]
for key in merged_satellite_data.keys():
merged_satellite_data[key] = [
merged_satellite_data[key][i] for i in idx_sorted
]

return merged_satellite_data
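
A sketch of the merge with two hypothetical single-image missions, assuming each per-satellite dictionary carries the same "dates", "geoaccuracy", "shorelines", and "idx" lists the merged dictionary is initialized with:

from datetime import datetime
import numpy as np

satellite_data = {
    "L8": {"dates": [datetime(2020, 5, 1)], "geoaccuracy": [7.1],
           "shorelines": [np.zeros((0, 2))], "idx": [0]},
    "S2": {"dates": [datetime(2020, 1, 1)], "geoaccuracy": [5.0],
           "shorelines": [np.zeros((0, 2))], "idx": [0]},
}
merged = combine_satellite_data(satellite_data)
# merged["dates"]   -> [datetime(2020, 1, 1), datetime(2020, 5, 1)]
# merged["satname"] -> ["S2", "L8"]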

@@ -1254,23 +1265,14 @@ def extract_shorelines_with_dask(
filepath_jpg = os.path.join(filepath_data, sitename, "jpg_files", "detection")
os.makedirs(filepath_jpg, exist_ok=True)

# for each satellite, sort the model outputs into good & bad
good_folder = os.path.join(session_path, "good")
bad_folder = os.path.join(session_path, "bad")
satellites = get_satellites_in_directory(session_path)
for satname in satellites:
# get all the model_outputs that have the satellite in the filename
# files = glob(f"{session_path}{os.sep}*{satname}*.npz")
files = file_utilities.find_files_recursively(
session_path, f"*{satname}*.npz", raise_error=False
)
if len(files) != 0:
filter_model_outputs(satname, files, good_folder, bad_folder)
# get the directory containing the good model outputs
good_folder = get_sorted_model_outputs_directory(session_path)

# get the list of files that were sorted as 'good'
filtered_files = get_filtered_files_dict(good_folder, "npz", sitename)
# keep only the metadata for the files that were sorted as 'good'
metadata = edit_metadata(metadata, filtered_files)
logger.info(f"edit_metadata metadata: {metadata}")

result_dict = {}
for satname in metadata:
@@ -1287,12 +1289,48 @@
)
result_dict.update(satellite_dict)

# combine the extracted shorelines for each satellite
# combine the extracted shorelines for each satellite
logger.info(f"Combining extracted shorelines for each satellite : {result_dict}")
extracted_shorelines_data = combine_satellite_data(result_dict)

return extracted_shorelines_data


def get_sorted_model_outputs_directory(
session_path: str,
) -> str:
"""
Sort each satellite's model output (.npz) files into "good" and "bad" folders and return the path to the "good" folder.
Args:
session_path (str): The path to the session directory containing the model output files.
Returns:
str: The path to the "good" folder containing the sorted model output files.
"""
# for each satellite, sort the model outputs into good & bad
good_folder = os.path.join(session_path, "good")
bad_folder = os.path.join(session_path, "bad")
satellites = get_satellites_in_directory(session_path)
for satname in satellites:
if os.path.exists(good_folder) and os.listdir(good_folder) != []:
return good_folder
else:
# get all the model_outputs that have the satellite in the filename
try:
files = file_utilities.find_files_recursively(
session_path, f".*{re.escape(satname)}.*\\.npz$", raise_error=False
)
except Exception as e:
logger.error(f"Error finding files for satellite {satname}: {e}")
continue
logger.info(f"{session_path} contained {satname} files: {files} ")
if len(files) != 0:
filter_model_outputs(satname, files, good_folder, bad_folder)
return good_folder
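
An illustrative call (the session path is hypothetical); the returned folder is the one the filtered-files step in extract_shorelines_with_dask reads from:

session_path = "sessions/sample_session/ID_zih2"
good_folder = get_sorted_model_outputs_directory(session_path)
# good_folder == os.path.join(session_path, "good"); if that folder already
# exists and is non-empty, it is returned without re-sorting.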


def get_min_shoreline_length(satname: str, default_min_length_sl: float) -> int:
"""
Given a satellite name and a default minimum shoreline length, returns the minimum shoreline length
@@ -1611,6 +1649,7 @@ def create_extracted_shorelines_from_session(
except FileNotFoundError as e:
logger.warning(f"No RGB files existed so no metadata.")
self.dictionary = {}
return self
else:
logger.info(f"metadata: {metadata}")

13 changes: 7 additions & 6 deletions src/coastseg/file_utilities.py
@@ -627,28 +627,29 @@ def find_directory_recursively(path: str = ".", name: str = "RGB") -> str:


def find_files_recursively(
path: str = ".", name: str = "*RGB*", raise_error: bool = False
path: str = ".", search_pattern: str = "*RGB*", raise_error: bool = False
) -> List[str]:
"""
Recursively search for files with the given name in the given path or its subdirectories.
Recursively search for files with the given search pattern in the given path or its subdirectories.
Args:
path (str): The starting directory to search in. Defaults to current directory.
name (str): The name of the files to search for. Defaults to "RGB".
search_pattern (str): The search pattern to match against file names. Defaults to "*RGB*".
raise_error (bool): Whether to raise an error if no files are found. Defaults to False.
Returns:
list: A list of paths to all files that match the given name.
list: A list of paths to all files that match the given search pattern.
"""
file_locations = []
regex = re.compile(search_pattern, re.IGNORECASE)
for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames:
if filename == name:
if regex.match(filename):
file_location = os.path.join(dirpath, filename)
file_locations.append(file_location)

if not file_locations and raise_error:
raise Exception(f"No files matching {name} could be found at {path}")
raise Exception(f"No files matching {search_pattern} could be found at {path}")

return file_locations
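
Because the comparison is now a case-insensitive regular-expression match on each filename rather than an exact-name check, callers supply regex patterns such as the one used in extracted_shoreline.py above (the retained glob-style default "*RGB*" is not itself a valid regex, so in practice a pattern is passed explicitly). A hypothetical call:

files = find_files_recursively("sessions/sample_session", r".*L8.*\.npz$", raise_error=False)
# matches e.g. "2020-05-01-10-30-00_L8_sitename_model_output.npz" in any subdirectory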

9 changes: 6 additions & 3 deletions src/coastseg/filters.py
@@ -1,12 +1,14 @@
import datetime
import logging
from statistics import mode
import numpy as np
import xarray as xr
from glob import glob
import os, shutil
from sklearn.cluster import KMeans
from statistics import mode

# Logger setup
logger = logging.getLogger(__name__)


def copy_files(files: list, dest_folder: str) -> None:
"""
@@ -153,7 +155,8 @@ def handle_files_and_directories(
"""
os.makedirs(dest_folder_bad, exist_ok=True)
os.makedirs(dest_folder_good, exist_ok=True)

logger.info(f"Copying {len(files_bad)} files to {dest_folder_bad}")
logger.info(f"Copying {len(files_good)} files to {dest_folder_good}")
copy_files(files_bad, dest_folder_bad)
copy_files(files_good, dest_folder_good)

18 changes: 16 additions & 2 deletions src/coastseg/zoo_model.py
@@ -7,6 +7,7 @@
import platform
import json
import logging
import shutil
from typing import List, Set, Tuple

from coastsat import SDS_tools
@@ -855,8 +856,19 @@ def postprocess_data(
raise FileNotFoundError(
f"Config files config.json or config_gdf.geojson do not exist in roi directory {roi_directory}\n This means that the download did not complete successfully."
)
# copy configs from data/roi_id location to session location
common.copy_configs(roi_directory, session_path)
# modify the config.json to only have the ROI ID that was used and save to session directory
roi_id = file_utilities.extract_roi_id(roi_directory)
common.save_new_config(
os.path.join(roi_directory, "config.json"),
roi_id,
os.path.join(session_path, "config.json"),
)
# Copy over the config_gdf.geojson file
config_gdf_path = os.path.join(roi_directory, "config_gdf.geojson")
if os.path.exists(config_gdf_path):
shutil.copy(
config_gdf_path, os.path.join(session_path, "config_gdf.geojson")
)
model_settings_path = os.path.join(session_path, "model_settings.json")
file_utilities.write_to_json(model_settings_path, preprocessed_data)

@@ -950,12 +962,14 @@ def run_model(
roi_directory = file_utilities.find_parent_directory(
src_directory, "ID_", "data"
)

print(f"Preprocessing the data at {roi_directory}")
model_dict = self.preprocess_data(roi_directory, model_dict, img_type)
logger.info(f"model_dict: {model_dict}")

self.compute_segmentation(model_dict, percent_no_data)
self.postprocess_data(model_dict, session, roi_directory)
session.add_roi_ids([file_utilities.extract_roi_id(roi_directory)])
print(f"\n Model results saved to {session.path}")

def get_model_directory(self, model_id: str):
