diff --git a/3_zoo_workflow.py b/3_zoo_workflow.py index 0f93ee9..0e19c7f 100644 --- a/3_zoo_workflow.py +++ b/3_zoo_workflow.py @@ -27,6 +27,7 @@ "model_type": "global_segformer_RGB_4class_14036903", # model name from the zoo "otsu": False, # Otsu Thresholding "tta": False, # Test Time Augmentation + "apply_segmentation_filter": True, # apply segmentation filter to the model outputs to sort them into good or bad } # Available models can run input "RGB" # or "MNDWI" or "NDWI" img_type = "RGB" # make sure the model name is compatible with the image type @@ -39,7 +40,7 @@ model_session_name = "sample_session_demo1" # b. ENTER THE DIRECTORY WHERE THE INPUT IMAGES ARE STORED # - Example of the directory where the input images are stored ( this should be the /data folder in the CoastSeg directory) -sample_directory = "C:\development\doodleverse\coastseg\CoastSeg\data\ID_wra5_datetime03-04-24__03_43_01\jpg_files\preprocessed\RGB" +sample_directory = r"C:\development\doodleverse\coastseg\CoastSeg\data\ID_wra5_datetime03-04-24__03_43_01\jpg_files\preprocessed\RGB" # 2. Save the settings to the model instance diff --git a/src/coastseg/classifier.py b/src/coastseg/classifier.py index f38d7ea..9d2da73 100644 --- a/src/coastseg/classifier.py +++ b/src/coastseg/classifier.py @@ -14,6 +14,7 @@ def filter_segmentations( session_path: str, + threshold: float = 0.40, ) -> str: """ Sort model output files into "good" and "bad" folders based on the satellite name in the filename. @@ -31,7 +32,7 @@ def filter_segmentations( session_path, session_path, good_path=good_path, - threshold=0.40) + threshold=threshold) # if the good folder does not exist then this means the classifier could not find any png files at the session path and something went wrong if not os.path.exists(good_path): raise FileNotFoundError(f"No model output files found at {session_path}. Shoreline Filtering failed.") diff --git a/src/coastseg/extracted_shoreline.py b/src/coastseg/extracted_shoreline.py index 34c4176..8ed38fb 100644 --- a/src/coastseg/extracted_shoreline.py +++ b/src/coastseg/extracted_shoreline.py @@ -1843,6 +1843,7 @@ def create_extracted_shorelines_from_session( new_session_path: str = None, output_directory: str = None, shoreline_extraction_area : gpd.GeoDataFrame = None, + apply_segmentation_filter: bool = True, **kwargs: dict, ) -> "Extracted_Shoreline": """ @@ -1948,8 +1949,11 @@ def create_extracted_shorelines_from_session( return self # Filter the segmentations to only include the good segmentations, then update the metadata to only include the files with the good segmentations - good_directory = classifier.filter_segmentations(session_path) - metadata= common.filter_metadata_with_dates(metadata,good_directory,file_type="npz") + good_directory = session_path + if apply_segmentation_filter: + good_directory = classifier.filter_segmentations(session_path) + # Filter the metadata to only include the files with segmentations that are in the good_directory + metadata= common.filter_metadata_with_dates(metadata,good_directory,file_type="npz") extracted_shorelines_dict = extract_shorelines_with_dask( session_path, diff --git a/src/coastseg/filters.py b/src/coastseg/filters.py index f58e942..422b4a9 100644 --- a/src/coastseg/filters.py +++ b/src/coastseg/filters.py @@ -252,44 +252,5 @@ def apply_land_mask( directory_path: str) -> None: return directory_path - - -def filter_model_outputs( - satname: str, files: list, dest_folder_good: str, dest_folder_bad: str -) -> None: - """ - Filter model outputs based on KMeans clustering of RMSE values and organize into 'good' and 'bad'. - - Args: - label (str): Label used for categorizing. - files (list): List of file paths. - dest_folder_good (str): Destination folder for 'good' files. - dest_folder_bad (str): Destination folder for 'bad' files. - """ - - count_shapes = count_files_with_same_shape(files) - # get the most common shape - # modal_shape = mode(get_image_shapes(files)) - for shape, count in count_shapes.items(): - print(f"Shape: {shape} Count: {count}") - valid_files = get_files_with_shape(files,shape) - if len(valid_files) <3: - # if there are not enough valid files to perform the analysis, move all files to the good folder - handle_files_and_directories( - [], valid_files, dest_folder_bad, dest_folder_good - ) - else: - times, time_var = get_time_vectors(valid_files) - da = xr.concat([load_xarray_data(f) for f in valid_files], dim=time_var) - timeav = da.mean(dim="time") - - rmse, input_rmse = measure_rmse(da, times, timeav) - labels, scores = get_kmeans_clusters(input_rmse, rmse) - files_bad, files_good = get_good_bad_files(valid_files, labels, scores) - handle_files_and_directories( - files_bad, files_good, dest_folder_bad, dest_folder_good - ) - # apply_land_mask(dest_folder_good) - diff --git a/src/coastseg/zoo_model.py b/src/coastseg/zoo_model.py index edb4f69..0df1df9 100644 --- a/src/coastseg/zoo_model.py +++ b/src/coastseg/zoo_model.py @@ -639,6 +639,7 @@ def set_settings(self, **kwargs): "apply_cloud_mask": True, # whether to apply cloud mask to images or not "drop_intersection_pts": False, # whether to drop intersection points not on the transect "coastseg_version": __version__, # version of coastseg used to generate the data + "apply_segmentation_filter": True, # whether to apply to sort the segmentations as good or bad } if kwargs: self.settings.update({key: value for key, value in kwargs.items()}) @@ -890,6 +891,7 @@ def extract_shorelines_with_unet( shoreline_extraction_area_gdf = shoreline_extraction_area_gdf, ) + # extract shorelines extracted_shorelines = extracted_shoreline.Extracted_Shoreline() extracted_shorelines = ( @@ -901,6 +903,7 @@ def extract_shorelines_with_unet( session_path, new_session_path, shoreline_extraction_area=shoreline_extraction_area_gdf, + apply_segmentation_filter=settings.get("apply_segmentation_filter", True), **kwargs, ) )