Skip to content

Commit

Permalink
Automatic Good/Bad Image Filter #171 V1
Browse files Browse the repository at this point in the history
  • Loading branch information
2320sharon committed Aug 17, 2023
1 parent c16fae4 commit 0a53da3
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 49 deletions.
84 changes: 73 additions & 11 deletions src/coastseg/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

# Internal dependencies imports
from coastseg import exceptions
from coastseg.validation import find_satellite_in_filename

# widget icons from https://fontawesome.com/icons/angle-down?s=solid&f=classic

Expand Down Expand Up @@ -158,6 +159,61 @@ def filter_metadata(metadata: dict, sitename: str, filepath_data: str) -> dict[s
)
# filter out files that were removed from RGB directory
filtered_files = get_filtered_files_dict(RGB_directory, "jpg", sitename)
metadata = edit_metadata(metadata, filtered_files)
return metadata


def edit_metadata(
metadata: dict[str, Union[str, List[str]]], filtered_files: dict[str, set]
) -> dict:
"""Filters the metadata so that it contains the data for the filenames in filtered_files
Args:
metadata (dict): A dictionary containing the metadata for each satellite
Each satellite has the following key fields "filenames","epsg","dates","acc_georef"
Example:
metadata = {
'L8':{
"filenames": ["2019-02-16-18-22-17_L8_sitename_ms.tif","2012-02-16-18-22-17_L8_sitename_ms.tif"],
"epsg":[4326,4326],
"dates":[datetime.datetime(2022, 1, 26, 15, 33, 50, tzinfo=<UTC>),datetime.datetime(2012, 1, 26, 15, 33, 50, tzinfo=<UTC>)],
"acc_georef":[9.185,9.125],
}
'L9':{
"filenames": ["2019-02-16-18-22-17_L9_sitename_ms.tif"],
"epsg":[4326],
"dates":[datetime.datetime(2022, 1, 26, 15, 33, 50, tzinfo=<UTC>)],
"acc_georef":[9.185],
}
}
filtered_files (dict): A dictionary containing a set of the tif filenames available for each satellite
Example:
filtered_files = {
"L5": {},
"L7": {},
"L8": {"2019-02-16-18-22-17_L8_sitename_ms.tif"},
"L9": {"2019-02-16-18-22-17_L9_sitename_ms.tif"},
"S2": {},
}
Returns:
dict: a filtered dictionary containing only the data for the filenames in filtered_files
Example:
metadata = {
'L8':{
"filenames": ["2019-02-16-18-22-17_L8_sitename_ms.tif"],
"epsg":[4326],
"dates":[datetime.datetime(2022, 1, 26, 15, 33, 50, tzinfo=<UTC>)],
"acc_georef":[9.185],
}
'L9':{
"filenames": ["2019-02-16-18-22-17_L9_sitename_ms.tif"],
"epsg":[4326],
"dates":[datetime.datetime(2022, 1, 26, 15, 33, 50, tzinfo=<UTC>)],
"acc_georef":[9.185],
}
}
"""
for satname in filtered_files:
if satname in metadata:
idx_keep = list(
Expand Down Expand Up @@ -207,29 +263,35 @@ def get_filtered_files_dict(directory: str, file_type: str, sitename: str) -> di

satellites = {"L5": set(), "L7": set(), "L8": set(), "L9": set(), "S2": set()}
for filepath in filepaths:
old_filename = os.path.basename(filepath)
parts = old_filename.split("_")
filename = os.path.basename(filepath)
parts = filename.split("_")

if len(parts) < 2:
logging.warning(
f"Skipping file with unexpected name format: {old_filename}"
)
logging.warning(f"Skipping file with unexpected name format: {filename}")
continue

date = parts[0]
satname_parts = parts[-1].split(".")

if len(satname_parts) < 2:
satname = find_satellite_in_filename(filename)
if satname is None:
logging.warning(
f"Skipping file with unexpected name format: {old_filename}"
f"Skipping file with unexpected name format which was missing a satname: {filename}"
)
continue

satname = satname_parts[0]
# satname_parts = parts[-1].split(".")

# if len(satname_parts) < 2:
# logging.warning(
# f"Skipping file with unexpected name format: {old_filename}"
# )
# continue

# satname = satname_parts[0]

new_filename = f"{date}_{satname}_{sitename}_ms.tif"
tif_filename = f"{date}_{satname}_{sitename}_ms.tif"
if satname in satellites:
satellites[satname].add(new_filename)
satellites[satname].add(tif_filename)

return satellites

Expand Down
83 changes: 45 additions & 38 deletions src/coastseg/extracted_shoreline.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,58 @@
# Standard library imports
import colorsys
import copy
import fnmatch
import json
import logging
import os
import json
import copy
from glob import glob
from typing import Optional, Union, List, Dict


# Internal dependencies imports
from coastseg import exceptions
from coastseg import common
from time import perf_counter
from typing import Dict, List, Optional, Union

# External dependencies imports
import dask
from dask.diagnostics import ProgressBar
import geopandas as gpd
import numpy as np
from ipyleaflet import GeoJSON
from matplotlib.pyplot import get_cmap
from matplotlib.colors import rgb2hex
from tqdm.auto import tqdm

import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.measure as measure
from coastsat import SDS_shoreline
from coastsat import SDS_preprocess
import skimage.morphology as morphology
from coastsat import SDS_preprocess, SDS_shoreline, SDS_tools
from coastsat.SDS_download import get_metadata
from coastsat.SDS_transects import compute_intersection_QC
from coastsat.SDS_shoreline import extract_shorelines
from coastsat.SDS_tools import (
get_filenames,
get_filepath,
output_to_gdf,
remove_duplicates,
remove_inaccurate_georef,
output_to_gdf,
get_filepath,
get_filenames,
)
import pandas as pd
import skimage.morphology as morphology
from coastsat.SDS_transects import compute_intersection_QC
from ipyleaflet import GeoJSON
from matplotlib import gridspec
from matplotlib.colors import rgb2hex
from matplotlib.pyplot import get_cmap
from skimage import measure, morphology
from tqdm.auto import tqdm

pd.set_option("mode.chained_assignment", None)
# Internal dependencies imports
from coastseg import common, exceptions
from coastseg.validation import get_satellites_in_directory
from coastseg.filters import filter_model_outputs
from coastseg.common import get_filtered_files_dict, edit_metadata

# imports for show detection
from coastsat import SDS_tools
from matplotlib import gridspec
import matplotlib.patches as mpatches
import matplotlib.lines as mlines

# Set pandas option
pd.set_option("mode.chained_assignment", None)

# Logger setup
logger = logging.getLogger(__name__)

# Module level variables
__all__ = ["Extracted_Shoreline"]

from time import perf_counter


def time_func(func):
def wrapper(*args, **kwargs):
Expand All @@ -68,9 +66,6 @@ def wrapper(*args, **kwargs):
return wrapper


from skimage import measure, morphology


def read_from_dict(d: dict, keys_of_interest: list | set | tuple):
"""
Function to extract the value from the first matching key in a dictionary.
Expand All @@ -92,9 +87,6 @@ def read_from_dict(d: dict, keys_of_interest: list | set | tuple):
raise KeyError(f"{keys_of_interest} were not in {d}")


import re


def remove_small_objects_and_binarize(merged_labels, min_size):
# Ensure the image is binary
binary_image = merged_labels > 0
Expand Down Expand Up @@ -406,7 +398,9 @@ def process_satellite_image(
cloud_mask.shape, georef, image_epsg, pixel_size, settings
)
# read the model outputs from the npz file for this image
npz_file = find_matching_npz(filename, session_path)
npz_file = find_matching_npz(filename, os.path.join(session_path, "good"))
if npz_file is None:
npz_file = find_matching_npz(filename, session_path)
logger.info(f"npz_file: {npz_file}")
if npz_file is None:
logger.warning(f"npz file not found for {filename}")
Expand Down Expand Up @@ -1069,6 +1063,19 @@ def extract_shorelines_with_dask(
# if len(filenames) == 0:
# logger.warning(f"Satellite {satname} had no imagery")
# return output
good_folder = os.path.join(session_path, "good")
bad_folder = os.path.join(session_path, "bad")
satellites = get_satellites_in_directory(session_path)
# for each satellite sort the model outputs into good & bad
for satname in satellites:
# get all the model_outputs that have the satellite in the filename
files = glob(f"{session_path}{os.sep}*{satname}*.npz")
if len(files) != 0:
filter_model_outputs(satname, files, good_folder, bad_folder)

filtered_files = get_filtered_files_dict(good_folder, "npz", sitename)
metadata = edit_metadata(metadata, filtered_files)

result_dict = {}
for satname in metadata:
satellite_dict = process_satellite(
Expand Down
127 changes: 127 additions & 0 deletions src/coastseg/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import datetime
from statistics import mode
import numpy as np
import xarray as xr
from glob import glob
import os, shutil
from sklearn.cluster import KMeans
from statistics import mode

def copy_files(files: list, dest_folder: str) -> None:
    """
    Copy each file in ``files`` into ``dest_folder`` using ``shutil.copy``.

    Args:
        files (list): Paths of the files to copy.
        dest_folder (str): Folder that receives the copies.

    Returns:
        None
    """
    for source_path in files:
        shutil.copy(source_path, dest_folder)

def load_data(f: str) -> np.array:
    """
    Read the ``grey_label`` array from an npz file and cast it to uint8.

    Args:
        f (str): Path to an npz file that contains a ``grey_label`` entry.

    Returns:
        np.array: The ``grey_label`` array converted to ``uint8``.
    """
    with np.load(f) as npz_contents:
        # astype() makes a copy, so the result outlives the closed archive
        return npz_contents["grey_label"].astype("uint8")

def get_good_bad_files(files: list, labels: np.array, scores: list) -> tuple:
    """
    Split files into 'good' and 'bad' categories based on provided labels and scores.

    The cluster with the highest score is 'bad' and the cluster with the lowest
    score is 'good'. If the scores cannot separate two clusters (fewer than two
    non-NaN scores, or the highest and lowest score belong to the same cluster)
    nothing is marked bad and every file is returned as good.

    Args:
        files (list): List of file paths.
        labels (np.array): Array of cluster labels, one per entry of ``files``.
        scores (list): One score per cluster label; a NaN score means the
            cluster had no members.

    Returns:
        tuple: A tuple containing two arrays:
            - files_bad (np.array): 'bad' file paths (highest score label).
            - files_good (np.array): 'good' file paths (lowest score label).
    """
    file_paths = np.array(files)
    cluster_scores = np.asarray(scores, dtype=float)

    # np.argmax and np.argmin both return the index of a NaN entry, so an empty
    # cluster (NaN mean) would make both masks empty and drop every file.
    if np.count_nonzero(~np.isnan(cluster_scores)) < 2:
        return file_paths[:0], file_paths

    bad_label = np.nanargmax(cluster_scores)
    good_label = np.nanargmin(cluster_scores)
    if bad_label == good_label:
        # tied scores: the same cluster must not be both good and bad
        return file_paths[:0], file_paths

    files_bad = file_paths[labels == bad_label]
    files_good = file_paths[labels == good_label]
    return files_bad, files_good

def get_time_vectors(files: list) -> tuple:
    """
    Build the time labels for a list of file paths.

    The label of each file is the part of its file name (the text after the
    last ``os.sep``) that precedes the first underscore.

    Args:
        files (list): List of file paths whose names start with a time label.

    Returns:
        tuple: A tuple containing two elements:
            - times (list): The label extracted from each path, in input order.
            - time_variable (xr.Variable): xarray variable named "time" holding the same labels.
    """
    times = []
    for path in files:
        filename = path.split(os.sep)[-1]
        times.append(filename.split("_")[0])
    time_variable = xr.Variable("time", times)
    return times, time_variable

def get_image_shapes(files: list) -> list:
    """
    Return the shape of the array that ``load_data`` reads from each file.

    Args:
        files (list): List of npz file paths.

    Returns:
        list: One shape tuple per file, in the same order as ``files``.
    """
    return [load_data(f).shape for f in files]

def measure_rmse(da: xr.DataArray, times: list, timeav: xr.DataArray) -> tuple:
    """
    Compute the root-mean-square difference between each time slice and ``timeav``.

    Args:
        da (xr.DataArray): Data with a "time" coordinate that can be selected by the entries of ``times``.
        times (list): Time labels to evaluate, one RMSE value is produced per label.
        timeav (xr.DataArray): Reference array subtracted from every time slice.

    Returns:
        tuple: A tuple containing two elements:
            - rmse (list): One float per entry of ``times``.
            - input_rmse (np.array): The same values reshaped into a single column (shape ``(-1, 1)``).
    """
    rmse = []
    for timestamp in times:
        squared_error = (da.sel(time=timestamp) - timeav) ** 2
        rmse.append(float(np.sqrt(np.mean(squared_error)).to_numpy()))
    input_rmse = np.array(rmse).reshape(-1, 1)
    return rmse, input_rmse

def get_kmeans_clusters(input_rmse: np.array, rmse: list) -> tuple:
    """
    Cluster the RMSE values into two groups with KMeans and score each group.

    Args:
        input_rmse (np.array): RMSE values in the 2D layout passed to ``KMeans.fit``.
        rmse (list): The same RMSE values as a flat list, in matching order.

    Returns:
        tuple: A tuple containing two elements:
            - labels: The cluster label (0 or 1) assigned to each RMSE value.
            - scores (list): Mean RMSE of cluster 0 followed by mean RMSE of cluster 1.
    """
    model = KMeans(n_clusters=2, random_state=0, n_init="auto")
    labels = model.fit(input_rmse).labels_
    rmse_values = np.array(rmse)
    scores = [np.mean(rmse_values[labels == cluster_id]) for cluster_id in (0, 1)]
    return labels, scores

def load_xarray_data(f: str) -> xr.DataArray:
    """
    Load the ``grey_label`` array of an npz file as a 2D xarray DataArray.

    Args:
        f (str): Path to an npz file that contains a 2D ``grey_label`` entry.

    Returns:
        xr.DataArray: The uint8 array with dims ("y", "x") and integer index coordinates.
    """
    with np.load(f) as npz_contents:
        grey = npz_contents["grey_label"].astype("uint8")
    n_rows, n_cols = grey.shape
    pixel_coords = {"y": np.arange(n_rows), "x": np.arange(n_cols)}
    return xr.DataArray(grey, coords=pixel_coords, dims=["y", "x"])

def handle_files_and_directories(
    files_bad: list, files_good: list, dest_folder_bad: str, dest_folder_good: str
) -> None:
    """
    Create the bad and good destination folders and copy each file list into its folder.

    Args:
        files_bad (list): Paths of the files to copy into ``dest_folder_bad``.
        files_good (list): Paths of the files to copy into ``dest_folder_good``.
        dest_folder_bad (str): Folder for the bad files; created if missing.
        dest_folder_good (str): Folder for the good files; created if missing.

    Returns:
        None
    """
    destinations = ((files_bad, dest_folder_bad), (files_good, dest_folder_good))
    # create both folders first, then copy, bad before good in each step
    for _, folder in destinations:
        os.makedirs(folder, exist_ok=True)
    for batch, folder in destinations:
        copy_files(batch, folder)

def return_valid_files(files: list) -> list:
    """
    Keep only the files whose image shape equals the most common shape.

    Args:
        files (list): List of npz file paths.

    Returns:
        list: The entries of ``files`` whose shape (as reported by
            ``get_image_shapes``) matches the modal shape, in input order.
            An empty input returns an empty list.
    """
    if not files:
        # statistics.mode raises StatisticsError on empty data
        return []
    # read every file once and reuse the shapes for both the mode and the filter
    shapes = get_image_shapes(files)
    modal_shape = mode(shapes)
    return [f for f, shape in zip(files, shapes) if shape == modal_shape]

def filter_model_outputs(
    label: str, files: list, dest_folder_good: str, dest_folder_bad: str
) -> None:
    """
    Sort model output files into good and bad folders.

    Files whose shape differs from the most common shape are ignored. The
    remaining files are stacked along a "time" dimension, the RMSE of each
    file against the time-averaged array is computed, and the RMSE values are
    split into two KMeans clusters. The cluster with the higher mean RMSE is
    copied to ``dest_folder_bad`` and the other to ``dest_folder_good``.

    When fewer than two valid files are available there is nothing to cluster,
    so all valid files are copied to ``dest_folder_good``.

    Args:
        label (str): Name used in the printed summary.
        files (list): Paths of the npz model outputs to sort.
        dest_folder_good (str): Folder that receives the good files.
        dest_folder_bad (str): Folder that receives the bad files.

    Returns:
        None
    """
    valid_files = return_valid_files(files)

    if len(valid_files) < 2:
        # two-cluster KMeans cannot be fitted on fewer than two samples
        files_bad, files_good = [], valid_files
    else:
        times, time_var = get_time_vectors(valid_files)
        da = xr.concat([load_xarray_data(f) for f in valid_files], dim=time_var)
        timeav = da.mean(dim="time")

        rmse, input_rmse = measure_rmse(da, times, timeav)
        labels, scores = get_kmeans_clusters(input_rmse, rmse)
        files_bad, files_good = get_good_bad_files(valid_files, labels, scores)

    handle_files_and_directories(
        files_bad, files_good, dest_folder_bad, dest_folder_good
    )

    print(f"{len(files_good)} good {label} labels")
    print(f"{len(files_bad)} bad {label} labels")



Loading

0 comments on commit 0a53da3

Please sign in to comment.