Skip to content

Commit

Permalink
Morphology on booleans (#758)
Browse files Browse the repository at this point in the history
* adjust tests

* fix implementation

* remove other occurences of empty squeeze

* assert correct dtype
  • Loading branch information
zigaLuksic authored Oct 12, 2023
1 parent 1ff6cad commit ed721ff
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 15 deletions.
2 changes: 1 addition & 1 deletion eolearn/features/extra/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def execute(self, eopatch: EOPatch) -> EOPatch:

# All connections to masked pixels are removed
if self.mask_name is not None:
mask = eopatch.mask_timeless[self.mask_name].squeeze()
mask = eopatch.mask_timeless[self.mask_name].squeeze(axis=-1)
graph_args["mask"] = mask
data = data[np.ravel(mask) != 0]

Expand Down
8 changes: 6 additions & 2 deletions eolearn/geometry/morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
self.no_data_label = no_data_label

def execute(self, eopatch: EOPatch) -> EOPatch:
feature_array = eopatch[(self.mask_type, self.mask_name)].squeeze().copy()
feature_array = eopatch[(self.mask_type, self.mask_name)].squeeze(axis=-1).copy()

all_labels = np.unique(feature_array)
erode_labels = self.erode_labels if self.erode_labels else all_labels
Expand Down Expand Up @@ -148,6 +148,10 @@ def __init__(
def map_method(self, feature: np.ndarray) -> np.ndarray:
"""Applies the morphological operation to a raster feature."""
feature = feature.copy()
is_bool = feature.dtype == bool
if is_bool:
feature = feature.astype(np.uint8)

morph_func = partial(cv2.morphologyEx, kernel=self.struct_elem, op=self.morph_operation)
if feature.ndim == 3:
for channel in range(feature.shape[2]):
Expand All @@ -158,4 +162,4 @@ def map_method(self, feature: np.ndarray) -> np.ndarray:
else:
raise ValueError(f"Invalid number of dimensions: {feature.ndim}")

return feature
return feature.astype(bool) if is_bool else feature
4 changes: 2 additions & 2 deletions tests/features/extra/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ def test_clustering(example_eopatch):
remove_small=10,
).execute(example_eopatch)

clusters = example_eopatch.data_timeless["clusters_small"].squeeze()
clusters = example_eopatch.data_timeless["clusters_small"].squeeze(axis=-1)

assert len(np.unique(clusters)) == 22, "Wrong number of clusters."
assert np.median(clusters) == 2

assert np.mean(clusters) == pytest.approx(2.19109 if sys.version_info < (3, 9) else 2.201188)

clusters = example_eopatch.data_timeless["clusters_mask"].squeeze()
clusters = example_eopatch.data_timeless["clusters_mask"].squeeze(axis=-1)

assert len(np.unique(clusters)) == 8, "Wrong number of clusters."
assert np.median(clusters) == 0
Expand Down
34 changes: 24 additions & 10 deletions tests/geometry/test_morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
def patch_fixture() -> EOPatch:
config = PatchGeneratorConfig(max_integer_value=10, raster_shape=(50, 100), depth_range=(3, 4))
patch = generate_eopatch([MASK_FEATURE, MASK_TIMELESS_FEATURE], config=config)
for feat in [MASK_FEATURE, MASK_TIMELESS_FEATURE]:
patch[feat] = patch[feat].astype(np.uint8)
patch[MASK_FEATURE] = patch[MASK_FEATURE].astype(np.uint8)
patch[MASK_TIMELESS_FEATURE] = patch[MASK_TIMELESS_FEATURE] < 1
patch[MASK_TIMELESS_FEATURE][10:20, 20:32] = 0
patch[MASK_TIMELESS_FEATURE][30:, 50:] = 1

return patch


Expand Down Expand Up @@ -64,22 +67,32 @@ def test_erosion_partial(test_eopatch):
MorphologicalOperations.DILATION,
None,
[6, 34, 172, 768, 2491, 7405, 19212, 44912],
[1, 2, 16, 104, 466, 1490, 3870, 9051],
[4882, 10118],
),
(
MorphologicalOperations.EROSION,
MorphologicalStructFactory.get_disk(4),
[54555, 15639, 3859, 770, 153, 19, 5],
[12391, 2609],
),
(MorphologicalOperations.EROSION, MorphologicalStructFactory.get_disk(11), [74957, 42, 1], [14994, 6]),
(MorphologicalOperations.OPENING, MorphologicalStructFactory.get_disk(11), [73899, 1051, 50], [14837, 163]),
(MorphologicalOperations.CLOSING, MorphologicalStructFactory.get_disk(11), [770, 74230], [425, 14575]),
(
MorphologicalOperations.OPENING,
MorphologicalStructFactory.get_rectangle(5, 6),
[48468, 24223, 2125, 169, 15],
[10146, 4425, 417, 3, 9],
MorphologicalStructFactory.get_disk(3),
[8850, 13652, 16866, 14632, 11121, 6315, 2670, 761, 133],
[11981, 3019],
),
(MorphologicalOperations.CLOSING, MorphologicalStructFactory.get_disk(11), [770, 74230], [661, 14339]),
(
MorphologicalOperations.OPENING,
MorphologicalStructFactory.get_rectangle(3, 3),
[15026, 23899, 20363, 9961, 4328, 1128, 280, 15],
[12000, 3000],
),
(
MorphologicalOperations.DILATION,
MorphologicalStructFactory.get_rectangle(5, 6),
[2, 19, 198, 3929, 70852],
[32, 743, 14225],
[803, 14197],
),
],
)
Expand All @@ -91,5 +104,6 @@ def test_morphological_filter(patch, morph_operation, struct_element, mask_count

assert patch[MASK_FEATURE].shape == (5, 50, 100, 3)
assert patch[MASK_TIMELESS_FEATURE].shape == (50, 100, 3)
assert patch[MASK_TIMELESS_FEATURE].dtype == bool
assert_array_equal(np.unique(patch[MASK_FEATURE], return_counts=True)[1], mask_counts)
assert_array_equal(np.unique(patch[MASK_TIMELESS_FEATURE], return_counts=True)[1], mask_timeless_counts)

0 comments on commit ed721ff

Please sign in to comment.