From c6294115d137408877575ec12e960ead4fe1ac6d Mon Sep 17 00:00:00 2001 From: Ryan McCarthy Date: Fri, 16 Feb 2024 10:52:42 -0700 Subject: [PATCH] adding a CountMethod to default pixel selction and including tests --- rio_tiler/mosaic/methods/defaults.py | 29 ++++++++++++++++++++++++++++ tests/test_mosaic.py | 10 ++++++++++ 2 files changed, 39 insertions(+) diff --git a/rio_tiler/mosaic/methods/defaults.py b/rio_tiler/mosaic/methods/defaults.py index 8e163808..0c1cc87b 100644 --- a/rio_tiler/mosaic/methods/defaults.py +++ b/rio_tiler/mosaic/methods/defaults.py @@ -188,3 +188,32 @@ def feed(self, array: Optional[numpy.ma.MaskedArray]): mask = numpy.where(pidex, array.mask, self.mosaic.mask) self.mosaic = numpy.ma.where(pidex, array, self.mosaic) self.mosaic.mask = mask + + +@dataclass +class CountMethod(MosaicMethodBase): + """Stack the arrays and return the valid pixel count.""" + + stack: List[numpy.ma.MaskedArray] = field(default_factory=list, init=False) + + @property + def data(self) -> Optional[numpy.ma.MaskedArray]: + """Return valid data count of the data stack.""" + if self.stack: + data = numpy.ma.count(numpy.ma.stack(self.stack, axis=0), axis=0).astype(numpy.uint16) + + # only need the counts from one band + if len(data.shape) > 2: + data = data[0] + + # mask is always empty + mask = numpy.zeros(data.shape, dtype=bool) + array = numpy.ma.MaskedArray(data, mask) + + return array + + return None + + def feed(self, array: Optional[numpy.ma.MaskedArray]): + """Add array to the stack.""" + self.stack.append(array) diff --git a/tests/test_mosaic.py b/tests/test_mosaic.py index 5fb1a2d3..2fe417dc 100644 --- a/tests/test_mosaic.py +++ b/tests/test_mosaic.py @@ -285,6 +285,16 @@ class aClass(object): assert t.dtype == "uint16" assert m.dtype == "uint8" + # Test count pixel selection + (t, m), _ = mosaic.mosaic_reader( + assets, _read_tile, x, y, z, pixel_selection=defaults.CountMethod() + ) + assert t.shape == (1, 256, 256) + assert m.shape == (256, 256) + assert m.all() + assert t.dtype == "uint16" + assert m.dtype == "uint8" + def mock_rasterio_open(asset): """Mock rasterio Open."""