From 80e355985ef7dc3f665a92baf305f73e715541af Mon Sep 17 00:00:00 2001 From: Sharon Fitzpatrick Date: Thu, 19 Dec 2024 09:39:29 -0800 Subject: [PATCH] move filter_segmentations to classifier.py --- src/coastseg/classifier.py | 25 +++++++++++++++++++++++++ src/coastseg/extracted_shoreline.py | 26 +------------------------- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/coastseg/classifier.py b/src/coastseg/classifier.py index b0e45b91..f38d7ea4 100644 --- a/src/coastseg/classifier.py +++ b/src/coastseg/classifier.py @@ -12,6 +12,31 @@ # Some of these functions were originally written by Mark Lundine and have been modified for this project. +def filter_segmentations( + session_path: str, +) -> str: + """ + Sort model output files into "good" and "bad" folders based on the satellite name in the filename. + Applies the land mask to the model output files in the "good" folder. + + Args: + session_path (str): The path to the session directory containing the model output files. + + Returns: + str: The path to the "good" folder containing the sorted model output files. + """ + segmentation_classifier = get_segmentation_classifier() + good_path = os.path.join(session_path, "good") + csv_path,good_path,bad_path = run_inference_segmentation_classifier(segmentation_classifier, + session_path, + session_path, + good_path=good_path, + threshold=0.40) + # if the good folder does not exist then this means the classifier could not find any png files at the session path and something went wrong + if not os.path.exists(good_path): + raise FileNotFoundError(f"No model output files found at {session_path}. Shoreline Filtering failed.") + return good_path + def move_matching_files(input_image_path, search_string, file_exts, target_dir): """ Move files matching the given search string and file extensions to the target directory. diff --git a/src/coastseg/extracted_shoreline.py b/src/coastseg/extracted_shoreline.py index a7acf244..34c41760 100644 --- a/src/coastseg/extracted_shoreline.py +++ b/src/coastseg/extracted_shoreline.py @@ -1486,30 +1486,6 @@ def extract_shorelines_with_dask( return combine_satellite_data(shoreline_dict) -def filter_segmentations( - session_path: str, -) -> str: - """ - Sort model output files into "good" and "bad" folders based on the satellite name in the filename. - Applies the land mask to the model output files in the "good" folder. - - Args: - session_path (str): The path to the session directory containing the model output files. - - Returns: - str: The path to the "good" folder containing the sorted model output files. - """ - segmentation_classifier = classifier.get_segmentation_classifier() - good_path = os.path.join(session_path, "good") - csv_path,good_path,bad_path = classifier.run_inference_segmentation_classifier(segmentation_classifier, - session_path, - session_path, - good_path=good_path, - threshold=0.40) - # if the good folder does not exist then this means the classifier could not find any png files at the session path and something went wrong - if not os.path.exists(good_path): - raise FileNotFoundError(f"No model output files found at {session_path}. Shoreline Filtering failed.") - return good_path def get_min_shoreline_length(satname: str, default_min_length_sl: float) -> int: @@ -1972,7 +1948,7 @@ def create_extracted_shorelines_from_session( return self # Filter the segmentations to only include the good segmentations, then update the metadata to only include the files with the good segmentations - good_directory = filter_segmentations(session_path) + good_directory = classifier.filter_segmentations(session_path) metadata= common.filter_metadata_with_dates(metadata,good_directory,file_type="npz") extracted_shorelines_dict = extract_shorelines_with_dask(