Skip to content

Commit

Permalink
Merge branch 'classifier_model' into main_copy_2
Browse files Browse the repository at this point in the history
  • Loading branch information
2320sharon committed Dec 20, 2024
2 parents fbc108e + c936690 commit c896a0f
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 43 deletions.
3 changes: 2 additions & 1 deletion 3_zoo_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"model_type": "global_segformer_RGB_4class_14036903", # model name from the zoo
"otsu": False, # Otsu Thresholding
"tta": False, # Test Time Augmentation
"apply_segmentation_filter": True, # apply segmentation filter to the model outputs to sort them into good or bad
}
# Available models can run input "RGB" # or "MNDWI" or "NDWI"
img_type = "RGB" # make sure the model name is compatible with the image type
Expand All @@ -39,7 +40,7 @@
model_session_name = "sample_session_demo1"
# b. ENTER THE DIRECTORY WHERE THE INPUT IMAGES ARE STORED
# - Example of the directory where the input images are stored ( this should be the /data folder in the CoastSeg directory)
sample_directory = "C:\development\doodleverse\coastseg\CoastSeg\data\ID_wra5_datetime03-04-24__03_43_01\jpg_files\preprocessed\RGB"
sample_directory = r"C:\development\doodleverse\coastseg\CoastSeg\data\ID_wra5_datetime03-04-24__03_43_01\jpg_files\preprocessed\RGB"


# 2. Save the settings to the model instance
Expand Down
3 changes: 2 additions & 1 deletion src/coastseg/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

def filter_segmentations(
session_path: str,
threshold: float = 0.40,
) -> str:
"""
Sort model output files into "good" and "bad" folders based on the satellite name in the filename.
Expand All @@ -31,7 +32,7 @@ def filter_segmentations(
session_path,
session_path,
good_path=good_path,
threshold=0.40)
threshold=threshold)
# if the good folder does not exist then this means the classifier could not find any png files at the session path and something went wrong
if not os.path.exists(good_path):
raise FileNotFoundError(f"No model output files found at {session_path}. Shoreline Filtering failed.")
Expand Down
8 changes: 6 additions & 2 deletions src/coastseg/extracted_shoreline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,6 +1843,7 @@ def create_extracted_shorelines_from_session(
new_session_path: str = None,
output_directory: str = None,
shoreline_extraction_area : gpd.GeoDataFrame = None,
apply_segmentation_filter: bool = True,
**kwargs: dict,
) -> "Extracted_Shoreline":
"""
Expand Down Expand Up @@ -1948,8 +1949,11 @@ def create_extracted_shorelines_from_session(
return self

# Filter the segmentations to only include the good segmentations, then update the metadata to only include the files with the good segmentations
good_directory = classifier.filter_segmentations(session_path)
metadata= common.filter_metadata_with_dates(metadata,good_directory,file_type="npz")
good_directory = session_path
if apply_segmentation_filter:
good_directory = classifier.filter_segmentations(session_path)
# Filter the metadata to only include the files with segmentations that are in the good_directory
metadata= common.filter_metadata_with_dates(metadata,good_directory,file_type="npz")

extracted_shorelines_dict = extract_shorelines_with_dask(
session_path,
Expand Down
39 changes: 0 additions & 39 deletions src/coastseg/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,44 +252,5 @@ def apply_land_mask( directory_path: str) -> None:

return directory_path



def filter_model_outputs(
satname: str, files: list, dest_folder_good: str, dest_folder_bad: str
) -> None:
"""
Filter model outputs based on KMeans clustering of RMSE values and organize into 'good' and 'bad'.
Args:
label (str): Label used for categorizing.
files (list): List of file paths.
dest_folder_good (str): Destination folder for 'good' files.
dest_folder_bad (str): Destination folder for 'bad' files.
"""

count_shapes = count_files_with_same_shape(files)
# get the most common shape
# modal_shape = mode(get_image_shapes(files))
for shape, count in count_shapes.items():
print(f"Shape: {shape} Count: {count}")
valid_files = get_files_with_shape(files,shape)
if len(valid_files) <3:
# if there are not enough valid files to perform the analysis, move all files to the good folder
handle_files_and_directories(
[], valid_files, dest_folder_bad, dest_folder_good
)
else:
times, time_var = get_time_vectors(valid_files)
da = xr.concat([load_xarray_data(f) for f in valid_files], dim=time_var)
timeav = da.mean(dim="time")

rmse, input_rmse = measure_rmse(da, times, timeav)
labels, scores = get_kmeans_clusters(input_rmse, rmse)
files_bad, files_good = get_good_bad_files(valid_files, labels, scores)
handle_files_and_directories(
files_bad, files_good, dest_folder_bad, dest_folder_good
)
# apply_land_mask(dest_folder_good)



3 changes: 3 additions & 0 deletions src/coastseg/zoo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ def set_settings(self, **kwargs):
"apply_cloud_mask": True, # whether to apply cloud mask to images or not
"drop_intersection_pts": False, # whether to drop intersection points not on the transect
"coastseg_version": __version__, # version of coastseg used to generate the data
"apply_segmentation_filter": True, # whether to apply to sort the segmentations as good or bad
}
if kwargs:
self.settings.update({key: value for key, value in kwargs.items()})
Expand Down Expand Up @@ -890,6 +891,7 @@ def extract_shorelines_with_unet(
shoreline_extraction_area_gdf = shoreline_extraction_area_gdf,
)


# extract shorelines
extracted_shorelines = extracted_shoreline.Extracted_Shoreline()
extracted_shorelines = (
Expand All @@ -901,6 +903,7 @@ def extract_shorelines_with_unet(
session_path,
new_session_path,
shoreline_extraction_area=shoreline_extraction_area_gdf,
apply_segmentation_filter=settings.get("apply_segmentation_filter", True),
**kwargs,
)
)
Expand Down

0 comments on commit c896a0f

Please sign in to comment.