Skip to content

Commit

Permalink
Get rid of BaseSnowMaskTask (#736)
Browse files Browse the repository at this point in the history
  • Loading branch information
zigaLuksic authored Aug 31, 2023
1 parent 0dc99b2 commit 6e714ca
Showing 1 changed file with 32 additions and 39 deletions.
71 changes: 32 additions & 39 deletions eolearn/mask/snow_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

import itertools
import logging
from abc import ABCMeta
from typing import Any

import cv2
import numpy as np
Expand All @@ -24,38 +22,7 @@
LOGGER = logging.getLogger(__name__)


class BaseSnowMaskTask(EOTask, metaclass=ABCMeta):
"""Base class for snow detection and masking"""

def __init__(
self,
data_feature: Feature,
band_indices: list[int],
dilation_size: int = 0,
undefined_value: int = 0,
mask_name: str = "SNOW_MASK",
):
"""
:param data_feature: EOPatch feature represented by a tuple in the form of `(FeatureType, 'feature_name')`
:param band_indices: A list containing the indices at which the required bands can be found in the data_feature.
:param dilation_size: Size of the disk in pixels for performing dilation. Value 0 means do not perform
this post-processing step.
"""
self.bands_feature = self.parse_feature(data_feature, allowed_feature_types={FeatureType.DATA})
self.band_indices = band_indices
self.disk_size = 2 * dilation_size + 1
self.undefined_value = undefined_value
self.mask_feature = (FeatureType.MASK, mask_name)

def _apply_dilation(self, snow_masks: np.ndarray) -> np.ndarray:
"""Apply binary dilation for each mask in the series"""
if self.disk_size > 0:
disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (self.disk_size, self.disk_size))
snow_masks = np.array([cv2.dilate(mask.astype(np.uint8), disk) for mask in snow_masks])
return snow_masks.astype(bool)


class SnowMaskTask(BaseSnowMaskTask):
class SnowMaskTask(EOTask):
"""The task calculates the snow mask using the given thresholds.
The default values were optimised based on the Sentinel-2 L1C processing level. Values might not be optimal for L2A
Expand All @@ -70,7 +37,9 @@ def __init__(
band_indices: list[int],
ndsi_threshold: float = 0.4,
brightness_threshold: float = 0.3,
**kwargs: Any,
dilation_size: int = 0,
undefined_value: int = 0,
mask_name: str = "SNOW_MASK",
):
"""
:param data_feature: EOPatch feature represented by a tuple in the form of `(FeatureType, 'feature_name')`
Expand All @@ -82,9 +51,20 @@ def __init__(
:param ndsi_threshold: Minimum value of the NDSI required to classify the pixel as snow
:param brightness_threshold: Minimum value of the red band for a pixel to be classified as bright
"""
super().__init__(data_feature, band_indices, **kwargs)
self.bands_feature = self.parse_feature(data_feature, allowed_feature_types={FeatureType.DATA})
self.band_indices = band_indices
self.ndsi_threshold = ndsi_threshold
self.brightness_threshold = brightness_threshold
self.disk_size = 2 * dilation_size + 1
self.undefined_value = undefined_value
self.mask_feature = (FeatureType.MASK, mask_name)

def _apply_dilation(self, snow_masks: np.ndarray) -> np.ndarray:
"""Apply binary dilation for each mask in the series"""
if self.disk_size > 0:
disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (self.disk_size, self.disk_size))
snow_masks = np.array([cv2.dilate(mask.astype(np.uint8), disk) for mask in snow_masks])
return snow_masks.astype(bool)

def execute(self, eopatch: EOPatch) -> EOPatch:
bands = eopatch[self.bands_feature][..., self.band_indices]
Expand All @@ -110,7 +90,7 @@ def execute(self, eopatch: EOPatch) -> EOPatch:
return eopatch


class TheiaSnowMaskTask(BaseSnowMaskTask):
class TheiaSnowMaskTask(EOTask):
"""Task to add a snow mask to an EOPatch. The input data is either Sentinel-2 L1C or L2A level
Original implementation and documentation available at https://gitlab.orfeo-toolbox.org/remote_modules/let-it-snow
Expand All @@ -136,7 +116,9 @@ def __init__(
red_params: tuple[float, float, float, float, float] = (12, 0.3, 0.1, 0.2, 0.040),
ndsi_params: tuple[float, float, float] = (0.4, 0.15, 0.001),
b10_index: int | None = None,
**kwargs: Any,
dilation_size: int = 0,
undefined_value: int = 0,
mask_name: str = "SNOW_MASK",
):
"""
:param data_feature: EOPatch feature represented by a tuple in the form of `(FeatureType, 'feature_name')`
Expand Down Expand Up @@ -166,13 +148,17 @@ def __init__(
is the minimum snow fraction in the image to activate the pass 2 snow test. With reference to the
ATBD, the tuple is (n_1, n_2, f_s)
"""
super().__init__(data_feature, band_indices, **kwargs)
self.bands_feature = self.parse_feature(data_feature, allowed_feature_types={FeatureType.DATA})
self.band_indices = band_indices
self.dem_feature = self.parse_feature(dem_feature)
self.clm_feature = self.parse_feature(cloud_mask_feature)
self.dem_params = dem_params
self.red_params = red_params
self.ndsi_params = ndsi_params
self.b10_index = b10_index
self.disk_size = 2 * dilation_size + 1
self.undefined_value = undefined_value
self.mask_feature = (FeatureType.MASK, mask_name)

def _resample_red(self, input_array: np.ndarray) -> np.ndarray:
"""Method to resample the values of the red band
Expand Down Expand Up @@ -257,6 +243,13 @@ def _apply_second_pass(

return snow_mask_pass2

def _apply_dilation(self, snow_masks: np.ndarray) -> np.ndarray:
"""Apply binary dilation for each mask in the series"""
if self.disk_size > 0:
disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (self.disk_size, self.disk_size))
snow_masks = np.array([cv2.dilate(mask.astype(np.uint8), disk) for mask in snow_masks])
return snow_masks.astype(bool)

def execute(self, eopatch: EOPatch) -> EOPatch:
"""Run multi-pass snow detection"""
bands = eopatch[self.bands_feature][..., self.band_indices]
Expand Down

0 comments on commit 6e714ca

Please sign in to comment.