Skip to content

Commit

Permalink
#259 #197 review kmeans filtering, update no data filtering, move pre…
Browse files Browse the repository at this point in the history
…dicition segmentation png to good/bad & npz to good/bad
  • Loading branch information
2320sharon committed Jun 4, 2024
1 parent 7ffdf9d commit e7d8e89
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
1 change: 0 additions & 1 deletion src/coastseg/extracted_shoreline.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def check_percent_no_data_allowed(
bool: True if the percentage of no data pixels is less than or equal to the allowed percentage, False otherwise.
"""
if percent_no_data_allowed is not None:
percent_no_data_allowed = percent_no_data_allowed / 100
num_total_pixels = cloud_mask.shape[0] * cloud_mask.shape[1]
percentage_no_data = np.sum(im_nodata) / num_total_pixels
if percentage_no_data > percent_no_data_allowed:
Expand Down
22 changes: 18 additions & 4 deletions src/coastseg/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os, shutil
from sklearn.cluster import KMeans
from statistics import mode
import pathlib

# Logger setup
logger = logging.getLogger(__name__)
Expand All @@ -26,7 +27,16 @@ def copy_files(files: list, dest_folder: str) -> None:
dest_path = os.path.join(dest_folder, os.path.basename(f))
if os.path.exists(os.path.abspath(dest_path)):
continue
shutil.copy(f, dest_folder)
# shutil.copy(f, dest_folder)
shutil.move(f, dest_path)
# move the matching png files to the respective folders
parent_dir = os.path.abspath(pathlib.Path(f).parent)
png_file = os.path.basename(f.replace("_res.npz","_predseg.png"))
png_path = os.path.join(parent_dir, png_file)
if os.path.exists(png_path):
dest_path = os.path.join(dest_folder, os.path.basename(png_path))
shutil.move(png_path, dest_path)



def load_data(f: str) -> np.array:
Expand Down Expand Up @@ -110,7 +120,10 @@ def measure_rmse(da: xr.DataArray, times: list, timeav: xr.DataArray) -> tuple:

def get_kmeans_clusters(input_rmse: np.array, rmse: list) -> tuple:
"""
Perform KMeans clustering on RMSE values.
Perform KMeans clustering on RMSE values.
Returns the average rmse score for each cluster as well as the labels for each cluster.
score[0] is the average rmse for the cluster 0.
score [1] is the average rmse for the cluster 1.
Args:
input_rmse (np.array): Array of RMSE values.
Expand All @@ -121,12 +134,12 @@ def get_kmeans_clusters(input_rmse: np.array, rmse: list) -> tuple:
"""
kmeans = KMeans(n_clusters=2, random_state=0, n_init="auto").fit(input_rmse)
labels = kmeans.labels_
# Calculate mean RMSE for each cluster
# the lower the RMSE the better the prediction
scores = [
np.mean(np.array(rmse)[labels == 0]),
np.mean(np.array(rmse)[labels == 1]),
]
if scores[0]>scores[1]:
labels = 1-labels
return labels, scores


Expand Down Expand Up @@ -166,6 +179,7 @@ def handle_files_and_directories(
logger.info(f"Copying {len(files_good)} files to {dest_folder_good}")
copy_files(files_bad, dest_folder_bad)
copy_files(files_good, dest_folder_good)
# get the matching png files and copy them to the respective folders


def return_valid_files(files: list) -> list:
Expand Down
4 changes: 2 additions & 2 deletions src/coastseg/zoo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def run_model_and_extract_shorelines(self,
use_GPU = settings.get('use_GPU', "0")
use_otsu = settings.get('otsu', False)
use_tta = settings.get('tta', False)
percent_no_data = settings.get('percent_no_data', 50.0)
percent_no_data = settings.get('percent_no_data', 0.5)

# make a progress bar to show the progress of the model and shoreline extraction
prog_bar = tqdm.auto.tqdm(range(2),
Expand Down Expand Up @@ -1143,7 +1143,7 @@ def get_files_for_seg(
# filter out files with no data pixels greater than percent_no_data
len_before = len(model_ready_files)
model_ready_files = filter_no_data_pixels(model_ready_files, percent_no_data)
print(f"From {len_before} files {len_before - len(model_ready_files)} files were filtered out due to no data pixels percentage being greater than {percent_no_data}%.")
print(f"From {len_before} files {len_before - len(model_ready_files)} files were filtered out due to no data pixels percentage being greater than {percent_no_data*100}%.")

return model_ready_files

Expand Down

0 comments on commit e7d8e89

Please sign in to comment.