Skip to content

Commit

Permalink
move filter_segmentations to classifier.py
Browse files Browse the repository at this point in the history
  • Loading branch information
2320sharon committed Dec 19, 2024
1 parent 81c607c commit 80e3559
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 25 deletions.
25 changes: 25 additions & 0 deletions src/coastseg/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,31 @@
# Some of these functions were originally written by Mark Lundine and have been modified for this project.


def filter_segmentations(
session_path: str,
) -> str:
"""
Sort model output files into "good" and "bad" folders based on the satellite name in the filename.
Applies the land mask to the model output files in the "good" folder.
Args:
session_path (str): The path to the session directory containing the model output files.
Returns:
str: The path to the "good" folder containing the sorted model output files.
"""
segmentation_classifier = get_segmentation_classifier()
good_path = os.path.join(session_path, "good")
csv_path,good_path,bad_path = run_inference_segmentation_classifier(segmentation_classifier,
session_path,
session_path,
good_path=good_path,
threshold=0.40)
# if the good folder does not exist then this means the classifier could not find any png files at the session path and something went wrong
if not os.path.exists(good_path):
raise FileNotFoundError(f"No model output files found at {session_path}. Shoreline Filtering failed.")
return good_path

def move_matching_files(input_image_path, search_string, file_exts, target_dir):
"""
Move files matching the given search string and file extensions to the target directory.
Expand Down
26 changes: 1 addition & 25 deletions src/coastseg/extracted_shoreline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,30 +1486,6 @@ def extract_shorelines_with_dask(
return combine_satellite_data(shoreline_dict)


def filter_segmentations(
session_path: str,
) -> str:
"""
Sort model output files into "good" and "bad" folders based on the satellite name in the filename.
Applies the land mask to the model output files in the "good" folder.
Args:
session_path (str): The path to the session directory containing the model output files.
Returns:
str: The path to the "good" folder containing the sorted model output files.
"""
segmentation_classifier = classifier.get_segmentation_classifier()
good_path = os.path.join(session_path, "good")
csv_path,good_path,bad_path = classifier.run_inference_segmentation_classifier(segmentation_classifier,
session_path,
session_path,
good_path=good_path,
threshold=0.40)
# if the good folder does not exist then this means the classifier could not find any png files at the session path and something went wrong
if not os.path.exists(good_path):
raise FileNotFoundError(f"No model output files found at {session_path}. Shoreline Filtering failed.")
return good_path


def get_min_shoreline_length(satname: str, default_min_length_sl: float) -> int:
Expand Down Expand Up @@ -1972,7 +1948,7 @@ def create_extracted_shorelines_from_session(
return self

# Filter the segmentations to only include the good segmentations, then update the metadata to only include the files with the good segmentations
good_directory = filter_segmentations(session_path)
good_directory = classifier.filter_segmentations(session_path)
metadata= common.filter_metadata_with_dates(metadata,good_directory,file_type="npz")

extracted_shorelines_dict = extract_shorelines_with_dask(
Expand Down

0 comments on commit 80e3559

Please sign in to comment.