diff --git a/python/src/exactextract/__init__.py b/python/src/exactextract/__init__.py index 52bd26d1..f5ed55ee 100644 --- a/python/src/exactextract/__init__.py +++ b/python/src/exactextract/__init__.py @@ -18,5 +18,6 @@ GDALRasterSource, NumPyRasterSource, RasterioRasterSource, + XArrayRasterSource, ) from .writer import Writer, JSONWriter, GDALWriter diff --git a/python/src/exactextract/exact_extract.py b/python/src/exactextract/exact_extract.py index 691e9311..d808b5ed 100644 --- a/python/src/exactextract/exact_extract.py +++ b/python/src/exactextract/exact_extract.py @@ -6,7 +6,7 @@ JSONFeatureSource, GeoPandasFeatureSource, ) -from .raster_source import RasterSource, GDALRasterSource, RasterioRasterSource +from .raster_source import RasterSource, GDALRasterSource, RasterioRasterSource, XArrayRasterSource from .operation import Operation from .processor import FeatureSequentialProcessor, RasterSequentialProcessor from .writer import JSONWriter @@ -62,6 +62,24 @@ def prep_raster(rast, band=None, name_root=None, names=None): except ImportError: pass + try: + import rioxarray + import xarray + + if isinstance(rast, xarray.core.dataarray.DataArray): + if band: + return [XArrayRasterSource(rast, band)] + else: + if not names: + names = [f"{name_root}_{i+1}" for i in range(rast.rio.count)] + return [ + XArrayRasterSource(rast, i+1, name=names[i]) + for i in range(rast.rio.count) + ] + + except ImportError: + pass + raise Exception("Unhandled raster datatype") diff --git a/python/src/exactextract/raster_source.py b/python/src/exactextract/raster_source.py index 5b7c70ff..a2033fdc 100644 --- a/python/src/exactextract/raster_source.py +++ b/python/src/exactextract/raster_source.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import numpy as np import os import pathlib @@ -126,3 +127,73 @@ def read_window(self, x0, y0, nx, ny): from rasterio.windows import Window return self.ds.read(self.band_idx, window=Window(x0, y0, nx, ny)) + + +class XArrayRasterSource(RasterSource): + def __init__(self, ds, band_idx=1, *, name=None): + super().__init__() + + if isinstance(ds, (str, os.PathLike)): + import rioxarray + import xarray + + ds = xarray.open_dataarray(ds) + + self.ds = ds + if self.ds.rio.crs is None: + # Set a default CRS to prevent clip_box from + # complaining that we don't have one + self.ds.rio.set_crs('EPSG:4326', inplace=True) + self.band_idx = band_idx + self.band_dim = self._band_dim(self.ds) + self.bounds = self.ds.rio.bounds() + + if name: + self.set_name(name) + + + @staticmethod + def _band_dim(ds): + dims = list(ds.dims) + dims.remove(ds.rio.x_dim) + dims.remove(ds.rio.y_dim) + + if len(dims) == 0: + return None + elif len(dims) == 1: + return dims[0] + else: + raise Exception("Cannot handle >1 non-spatial dimension") + + + def res(self): + return tuple(abs(x) for x in self.ds.rio.resolution()) + + + def extent(self): + return self.bounds + + + def nodata_value(self): + return self.ds.rio.nodata + + + def read_window(self, x0, y0, nx, ny): + lats = self.ds[self.ds.rio.y_dim] + flipped = bool(len(lats) > 1 and lats[1] > lats[0]) + + if flipped: + y0 = self.ds.rio.height - y0 - ny + + selection = {} + if self.band_dim is not None: + selection[self.band_dim] = self.ds[self.band_dim][self.band_idx - 1] + selection[self.ds.rio.x_dim] = self.ds[self.ds.rio.x_dim][x0 : x0+nx] + selection[self.ds.rio.y_dim] = self.ds[self.ds.rio.y_dim][y0 : y0+ny] + + ret = self.ds.sel(**selection).to_numpy() + + if flipped: + ret = np.flipud(ret) + + return ret diff --git a/python/tests/test_exact_extract.py b/python/tests/test_exact_extract.py index ba390fcc..7f80c49a 100644 --- a/python/tests/test_exact_extract.py +++ b/python/tests/test_exact_extract.py @@ -320,6 +320,10 @@ def open_with_lib(fname, libname): elif libname == "rasterio": rasterio = pytest.importorskip("rasterio") return rasterio.open(fname) + elif libname == "xarray": + rioxarray = pytest.importorskip("rioxarray") + xarray = pytest.importorskip("xarray") + return xarray.open_dataarray(fname) elif libname == "ogr": ogr = pytest.importorskip("osgeo.ogr") return ogr.Open(fname) @@ -331,7 +335,7 @@ def open_with_lib(fname, libname): return gp.read_file(fname) -@pytest.mark.parametrize("rast_lib", ("gdal", "rasterio")) +@pytest.mark.parametrize("rast_lib", ("gdal", "rasterio", "xarray")) @pytest.mark.parametrize("vec_lib", ("ogr", "fiona", "geopandas")) @pytest.mark.parametrize( "arr,expected", diff --git a/python/tests/test_raster_source.py b/python/tests/test_raster_source.py index 65f0bf9f..3d4f415f 100644 --- a/python/tests/test_raster_source.py +++ b/python/tests/test_raster_source.py @@ -1,29 +1,39 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import numpy as np import pytest -from exactextract import GDALRasterSource, RasterioRasterSource +from exactextract import GDALRasterSource, RasterioRasterSource, XArrayRasterSource @pytest.fixture() def global_half_degree(tmp_path): from osgeo import gdal - fname = str(tmp_path / "test.tif") + fname = str(tmp_path / "test.nc") - drv = gdal.GetDriverByName("GTiff") - ds = drv.Create(fname, 720, 360) + nx = 720 + ny = 360 + + drv = gdal.GetDriverByName("NetCDF") + ds = drv.Create(fname, nx, ny, eType=gdal.GDT_Int32) gt = (-180.0, 0.5, 0.0, 90.0, 0.0, -0.5) ds.SetGeoTransform(gt) band = ds.GetRasterBand(1) band.SetNoDataValue(6) + + data = np.arange(nx * ny).reshape(ny, nx) + band.WriteArray(data) + ds = None return fname -@pytest.mark.parametrize("Source", (GDALRasterSource, RasterioRasterSource)) +@pytest.mark.parametrize( + "Source", (GDALRasterSource, RasterioRasterSource, XArrayRasterSource) +) def test_gdal_raster(global_half_degree, Source): try: src = Source(global_half_degree, 1) @@ -32,6 +42,15 @@ def test_gdal_raster(global_half_degree, Source): assert src.res() == (0.50, 0.50) assert src.extent() == pytest.approx((-180, -90, 180, 90)) - assert src.nodata_value() == 6 - assert src.read_window(0, 0, 10, 10).shape == (10, 10) + window = src.read_window(4, 5, 2, 3) + + assert window.shape == (3, 2) + np.testing.assert_array_equal( + window.astype(np.float64), + np.array([[3604, 3605], [4324, 4325], [5044, 5045]], np.float64), + ) + + if Source != XArrayRasterSource: + assert src.nodata_value() == 6 + assert window.dtype == np.int32