diff --git a/eolearn/features/extra/clustering.py b/eolearn/features/extra/clustering.py index 35172c3c..443ff3e5 100644 --- a/eolearn/features/extra/clustering.py +++ b/eolearn/features/extra/clustering.py @@ -88,7 +88,7 @@ def execute(self, eopatch: EOPatch) -> EOPatch: # All connections to masked pixels are removed if self.mask_name is not None: - mask = eopatch.mask_timeless[self.mask_name].squeeze() + mask = eopatch.mask_timeless[self.mask_name].squeeze(axis=-1) graph_args["mask"] = mask data = data[np.ravel(mask) != 0] diff --git a/eolearn/geometry/morphology.py b/eolearn/geometry/morphology.py index dbf039fe..b4ee86b9 100644 --- a/eolearn/geometry/morphology.py +++ b/eolearn/geometry/morphology.py @@ -48,7 +48,7 @@ def __init__( self.no_data_label = no_data_label def execute(self, eopatch: EOPatch) -> EOPatch: - feature_array = eopatch[(self.mask_type, self.mask_name)].squeeze().copy() + feature_array = eopatch[(self.mask_type, self.mask_name)].squeeze(axis=-1).copy() all_labels = np.unique(feature_array) erode_labels = self.erode_labels if self.erode_labels else all_labels @@ -148,6 +148,10 @@ def __init__( def map_method(self, feature: np.ndarray) -> np.ndarray: """Applies the morphological operation to a raster feature.""" feature = feature.copy() + is_bool = feature.dtype == bool + if is_bool: + feature = feature.astype(np.uint8) + morph_func = partial(cv2.morphologyEx, kernel=self.struct_elem, op=self.morph_operation) if feature.ndim == 3: for channel in range(feature.shape[2]): @@ -158,4 +162,4 @@ def map_method(self, feature: np.ndarray) -> np.ndarray: else: raise ValueError(f"Invalid number of dimensions: {feature.ndim}") - return feature + return feature.astype(bool) if is_bool else feature diff --git a/tests/features/extra/test_clustering.py b/tests/features/extra/test_clustering.py index b143b080..80b3707b 100644 --- a/tests/features/extra/test_clustering.py +++ b/tests/features/extra/test_clustering.py @@ -43,14 +43,14 @@ def test_clustering(example_eopatch): remove_small=10, ).execute(example_eopatch) - clusters = example_eopatch.data_timeless["clusters_small"].squeeze() + clusters = example_eopatch.data_timeless["clusters_small"].squeeze(axis=-1) assert len(np.unique(clusters)) == 22, "Wrong number of clusters." assert np.median(clusters) == 2 assert np.mean(clusters) == pytest.approx(2.19109 if sys.version_info < (3, 9) else 2.201188) - clusters = example_eopatch.data_timeless["clusters_mask"].squeeze() + clusters = example_eopatch.data_timeless["clusters_mask"].squeeze(axis=-1) assert len(np.unique(clusters)) == 8, "Wrong number of clusters." assert np.median(clusters) == 0 diff --git a/tests/geometry/test_morphology.py b/tests/geometry/test_morphology.py index 646e80ec..f4d69d7a 100644 --- a/tests/geometry/test_morphology.py +++ b/tests/geometry/test_morphology.py @@ -24,8 +24,11 @@ def patch_fixture() -> EOPatch: config = PatchGeneratorConfig(max_integer_value=10, raster_shape=(50, 100), depth_range=(3, 4)) patch = generate_eopatch([MASK_FEATURE, MASK_TIMELESS_FEATURE], config=config) - for feat in [MASK_FEATURE, MASK_TIMELESS_FEATURE]: - patch[feat] = patch[feat].astype(np.uint8) + patch[MASK_FEATURE] = patch[MASK_FEATURE].astype(np.uint8) + patch[MASK_TIMELESS_FEATURE] = patch[MASK_TIMELESS_FEATURE] < 1 + patch[MASK_TIMELESS_FEATURE][10:20, 20:32] = 0 + patch[MASK_TIMELESS_FEATURE][30:, 50:] = 1 + return patch @@ -64,22 +67,32 @@ def test_erosion_partial(test_eopatch): MorphologicalOperations.DILATION, None, [6, 34, 172, 768, 2491, 7405, 19212, 44912], - [1, 2, 16, 104, 466, 1490, 3870, 9051], + [4882, 10118], + ), + ( + MorphologicalOperations.EROSION, + MorphologicalStructFactory.get_disk(4), + [54555, 15639, 3859, 770, 153, 19, 5], + [12391, 2609], ), - (MorphologicalOperations.EROSION, MorphologicalStructFactory.get_disk(11), [74957, 42, 1], [14994, 6]), - (MorphologicalOperations.OPENING, MorphologicalStructFactory.get_disk(11), [73899, 1051, 50], [14837, 163]), - (MorphologicalOperations.CLOSING, MorphologicalStructFactory.get_disk(11), [770, 74230], [425, 14575]), ( MorphologicalOperations.OPENING, - MorphologicalStructFactory.get_rectangle(5, 6), - [48468, 24223, 2125, 169, 15], - [10146, 4425, 417, 3, 9], + MorphologicalStructFactory.get_disk(3), + [8850, 13652, 16866, 14632, 11121, 6315, 2670, 761, 133], + [11981, 3019], + ), + (MorphologicalOperations.CLOSING, MorphologicalStructFactory.get_disk(11), [770, 74230], [661, 14339]), + ( + MorphologicalOperations.OPENING, + MorphologicalStructFactory.get_rectangle(3, 3), + [15026, 23899, 20363, 9961, 4328, 1128, 280, 15], + [12000, 3000], ), ( MorphologicalOperations.DILATION, MorphologicalStructFactory.get_rectangle(5, 6), [2, 19, 198, 3929, 70852], - [32, 743, 14225], + [803, 14197], ), ], ) @@ -91,5 +104,6 @@ def test_morphological_filter(patch, morph_operation, struct_element, mask_count assert patch[MASK_FEATURE].shape == (5, 50, 100, 3) assert patch[MASK_TIMELESS_FEATURE].shape == (50, 100, 3) + assert patch[MASK_TIMELESS_FEATURE].dtype == bool assert_array_equal(np.unique(patch[MASK_FEATURE], return_counts=True)[1], mask_counts) assert_array_equal(np.unique(patch[MASK_TIMELESS_FEATURE], return_counts=True)[1], mask_timeless_counts)