Skip to content

Commit

Permalink
Merge pull request #5 from Forest-Recovery-Digital-Companion/fix-inco…
Browse files Browse the repository at this point in the history
…nsistent-watershed

FRML-27 Fix Inconsistent Watershed Result
  • Loading branch information
Eve-ning authored Sep 28, 2023
2 parents f017e94 + 9392bdc commit 26a509c
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 33 deletions.
17 changes: 15 additions & 2 deletions src/frdc/load/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from google.oauth2.service_account import Credentials

from frdc.conf import LOCAL_DATASET_ROOT_DIR, SECRETS_DIR, GCS_PROJECT_ID, GCS_BUCKET_NAME, Band
from frdc.utils.utils import Rect


@dataclass
Expand Down Expand Up @@ -142,10 +143,22 @@ def get_ar_bands(self, band_names=Band.FILE_NAMES) -> np.ndarray:
# Sort the bands by the order in Band.FILE_NAMES
return np.stack([bands_dict[band_name] for band_name in Band.FILE_NAMES], axis=-1)

def get_bounds_and_labels(self, file_name='bounds.csv') -> tuple[Iterable[Rect], Iterable[str]]:
    """ Gets the bounds and labels from the bounds.csv file.

    Notes:
        In the context of np.ndarray, to slice with x, y coordinates, you need to slice
        with [y0:y1, x0:x1], which is the reverse of the (x, y) order stored in bounds.csv.

    Args:
        file_name: The name of the bounds.csv file, relative to the dataset directory.

    Returns:
        A tuple of (bounds, labels), where bounds is a list of Rect(x0, y0, x1, y1), one per
        CSV row, and labels is the list of values in the 'name' column.
    """
    fp = self.dl.download_file(path=self.dataset_dir / file_name)
    df = pd.read_csv(fp)
    # Wrap each row in a Rect so callers access coordinates by name instead of by position.
    bounds = [Rect(i.x0, i.y0, i.x1, i.y1) for i in df.itertuples()]
    return bounds, df['name'].tolist()

@staticmethod
def _load_image(path: Path | str) -> np.ndarray:
Expand Down
6 changes: 4 additions & 2 deletions src/frdc/preprocess/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .extract_segments import extract_segments_from_labels, extract_segments_from_bounds
from .extract_segments import (extract_segments_from_labels, extract_segments_from_bounds,
remove_small_segments_from_labels)
from .preprocess import compute_labels

__all__ = ['compute_labels', 'extract_segments_from_labels', 'extract_segments_from_bounds']
__all__ = ['compute_labels', 'extract_segments_from_labels', 'extract_segments_from_bounds',
'remove_small_segments_from_labels']
52 changes: 43 additions & 9 deletions src/frdc/preprocess/extract_segments.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,38 @@
import logging
from typing import Iterable

import numpy as np

from frdc.utils.utils import Rect


def remove_small_segments_from_labels(
    ar_labels: np.ndarray,
    min_height: int = 10,
    min_width: int = 10
) -> np.ndarray:
    """ Removes small segments from a label image.

    A segment is considered "small" when the bounding box of its pixels is shorter than
    min_height or narrower than min_width. Small segments are relabelled to 0.
    The input array is left untouched; a modified copy is returned.

    Args:
        ar_labels: Labelled 2D image, where each integer value is a segment mask.
        min_height: Segments whose bounding box height is below this are removed.
        min_width: Segments whose bounding box width is below this are removed.

    Returns:
        A labelled image with small segments removed.
    """
    ar_labels = ar_labels.copy()
    # np.unique is evaluated once up front, so relabelling inside the loop is safe.
    for i in np.unique(ar_labels):
        if i == 0:
            # 0 is the value removed segments are relabelled to; "removing" it would be a
            # no-op that only costs a full scan and emits a misleading log line.
            continue
        mask = ar_labels == i
        coords = np.argwhere(mask)
        # Rows are the y axis and columns the x axis; +1 makes the upper bound exclusive.
        y0, x0 = coords.min(axis=0)
        y1, x1 = coords.max(axis=0) + 1
        height = y1 - y0
        width = x1 - x0
        if height < min_height or width < min_width:
            logging.info("Removing segment %s with shape %sx%s", i, height, width)
            ar_labels[mask] = 0
    return ar_labels


def extract_segments_from_labels(
ar: np.ndarray,
Expand All @@ -20,23 +51,25 @@ def extract_segments_from_labels(
"""
ar_segments = []
for segment_ix in range(np.max(ar_labels) + 1):
ar_segment_mask = np.array(ar_labels == segment_ix)
for segment_ix in np.unique(ar_labels):
if cropped:
coords = np.argwhere(ar_segment_mask)
coords = np.argwhere(ar_labels == segment_ix)
x0, y0 = coords.min(axis=0)
x1, y1 = coords.max(axis=0) + 1
ar_segments.append(ar[x0:x1, y0:y1])
ar_segment_cropped_mask = ar_labels[x0:x1, y0:y1] == segment_ix
ar_segment_cropped = ar[x0:x1, y0:y1]
ar_segment_cropped = np.where(ar_segment_cropped_mask[..., None], ar_segment_cropped, np.nan)
ar_segments.append(ar_segment_cropped)
else:
ar_segment = ar.copy()
ar_segment = np.where(ar_segment_mask[..., None], ar_segment, np.nan)
ar_segment_mask = np.array(ar_labels == segment_ix)
ar_segment = np.where(ar_segment_mask[..., None], ar, np.nan)
ar_segments.append(ar_segment)
return ar_segments


def extract_segments_from_bounds(
ar: np.ndarray,
bounds: Iterable[Iterable[int]],
bounds: Iterable[Rect],
cropped: bool = True
) -> list[np.ndarray]:
""" Extracts segments as a list from bounds
Expand All @@ -51,9 +84,10 @@ def extract_segments_from_bounds(
"""
ar_segments = []
for x0, y0, x1, y1 in bounds:
for b in bounds:
x0, y0, x1, y1 = b.x0, b.y0, b.x1, b.y1
if cropped:
ar_segments.append(ar[x0:x1, y0:y1])
ar_segments.append(ar[y0:y1, x0:x1])
else:
ar_segment_mask = np.zeros(ar.shape[:2], dtype=bool)
ar_segment_mask[y0:y1, x0:x1] = True
Expand Down
35 changes: 18 additions & 17 deletions src/frdc/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

def compute_labels(
ar: np.ndarray,
nir_threshold_value=0.5,
min_crown_size=100,
min_crown_hole=100,
connectivity=1,
peaks_footprint=20,
watershed_compactness=0.1
nir_threshold_value=90 / 256,
min_crown_size=1000,
min_crown_hole=1000,
connectivity=2,
peaks_footprint=200,
watershed_compactness=0
) -> np.ndarray:
""" Automatically segments crowns from an NDArray with a series of image processing operations.
Expand All @@ -33,8 +33,8 @@ def compute_labels(
Background is of shape (H, W, C), where C is the number of bands, C is sorted by Band.FILE_NAMES.
Crowns is a list of np.ndarray crowns, each crown is of shape (H, W, C).
"""
# ar = scale_0_1_per_band(ar)
ar = scale_static_per_band(ar)
ar = scale_0_1_per_band(ar)
# ar = scale_static_per_band(ar)
ar_mask = threshold_binary_mask(ar, Band.NIR, nir_threshold_value)
ar_mask = remove_small_objects(ar_mask, min_size=min_crown_size, connectivity=connectivity)
ar_mask = remove_small_holes(ar_mask, area_threshold=min_crown_hole, connectivity=connectivity)
Expand Down Expand Up @@ -101,26 +101,27 @@ def binary_watershed(ar_mask: np.ndarray, peaks_footprint: int, watershed_compac
# Image Depth: The distance from the background
# Image Basins: The local maxima of the image depth. i.e. points that are the deepest in the image.

# We can get the image depth by taking the negative euclidean distance transform of the binary mask.
# This means that lower values are further away from the background.
ar_watershed_depth = -distance_transform_edt(ar_mask)
# The ar distance is the distance from the background.
ar_distance = distance_transform_edt(ar_mask)

# We find the basins by locating the local maxima of the distance transform, i.e. the points furthest from the background.
ar_watershed_basin_coords = peak_local_max(
-ar_watershed_depth,
ar_distance,
footprint=np.ones((peaks_footprint, peaks_footprint)),
min_distance=1,
# min_distance=1,
exclude_border=0,
p_norm=2
# p_norm=2,
labels=ar_mask
)
ar_watershed_basins = np.zeros(ar_watershed_depth.shape, dtype=bool)
ar_watershed_basins = np.zeros(ar_distance.shape, dtype=bool)
ar_watershed_basins[tuple(ar_watershed_basin_coords.T)] = True
ar_watershed_basins, _ = ndimage.label(ar_watershed_basins)

# TODO: Low watershed compactness values produce mini-blobs, which can be indicative of redundant
# crowns. We should investigate this further.
return watershed(image=-ar_watershed_depth,
return watershed(image=-ar_distance, # We use the negative so that "peaks" become "troughs"
markers=ar_watershed_basins,
mask=ar_mask,
# watershed_line=True, # Enable this to see the watershed lines
compactness=watershed_compactness)
# compactness=watershed_compactness
)
3 changes: 3 additions & 0 deletions src/frdc/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from collections import namedtuple

# Rectangle described by two corner coordinates, stored in (x0, y0, x1, y1) order.
# Being a namedtuple, it can be read by field name or unpacked like a plain 4-tuple.
Rect = namedtuple(typename='Rect', field_names=('x0', 'y0', 'x1', 'y1'))
48 changes: 45 additions & 3 deletions tests/unit_tests/preprocess/test_extract_segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,50 @@
- If we crop, any(segment.shape != ar.shape for segment in segments)
- If we don't crop, all(segment.shape == ar.shape for segment in segments)
"""
import numpy as np

from frdc.preprocess import extract_segments_from_bounds, extract_segments_from_labels, compute_labels
from frdc.preprocess import (extract_segments_from_bounds, extract_segments_from_labels, compute_labels,
remove_small_segments_from_labels)


def test_remove_small_segments_from_labels():
    """ Checks that only segments below the size thresholds are removed.

    The label image under test:
        1 1 2 2 2
        1 1 2 2 2
        3 3 4 4 4
        3 3 4 4 4
        3 3 4 4 4

    For example, removing anything with width or height below 3 should give:
        0 0 0 0 0
        0 0 0 0 0
        0 0 4 4 4
        0 0 4 4 4
        0 0 4 4 4

    in which case the unique labels are {0, 4}.
    """
    labels = np.zeros((5, 5), dtype=np.uint8)
    labels[:2, :2] = 1
    labels[:2, 2:] = 2
    labels[2:, :2] = 3
    labels[2:, 2:] = 4

    # (min_height, min_width, expected unique labels).
    # 0 shows up whenever something was removed, as removed labels become background 0.
    cases = (
        (2, 2, {1, 2, 3, 4}),
        (3, 2, {0, 3, 4}),
        (2, 3, {0, 2, 4}),
        (3, 3, {0, 4}),
        (4, 4, {0}),
    )
    for min_height, min_width, expected_labels in cases:
        ar_filtered = remove_small_segments_from_labels(
            labels, min_height=min_height, min_width=min_width
        )
        assert set(np.unique(ar_filtered)) == expected_labels


def test_extract_segments_from_bounds_cropped(ds):
Expand All @@ -29,12 +71,12 @@ def test_extract_segments_from_bounds_no_crop(ds):


def test_extract_segments_from_labels_cropped(ds):
    """ With cropping, at least one extracted segment must differ in shape from the full image. """
    # A single compute_labels call: the earlier default-argument call was overwritten
    # immediately and only re-ran the whole labelling pipeline for nothing.
    ar_labels = compute_labels(ds.get_ar_bands(), peaks_footprint=10)
    ar = ds.get_ar_bands()
    segments = extract_segments_from_labels(ar, ar_labels, cropped=True)
    assert any(segment.shape != ar.shape for segment in segments)


def test_extract_segments_from_labels_no_crop(ds):
    """ Without cropping, every extracted segment must keep the shape of the full image. """
    # A single compute_labels call: the earlier default-argument call was overwritten
    # immediately and only re-ran the whole labelling pipeline for nothing.
    ar_labels = compute_labels(ds.get_ar_bands(), peaks_footprint=10)
    ar = ds.get_ar_bands()
    segments = extract_segments_from_labels(ar, ar_labels, cropped=False)
    assert all(segment.shape == ar.shape for segment in segments)

0 comments on commit 26a509c

Please sign in to comment.