Skip to content

Commit

Permalink
add kwargs to xarrayReader methods for compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentsarago committed Nov 5, 2024
1 parent 9a39c80 commit 6ba6b48
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 7.2.0 (2024-11-05)

* Ensure compatibility between XarrayReader and other Readers by adding `**kwargs` on class methods

# 7.1.0 (2024-10-29)

* Add `preview()` and `statistics()` methods to XarrayReader (https://github.com/cogeotiff/rio-tiler/pull/755)
Expand Down
8 changes: 7 additions & 1 deletion rio_tiler/io/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import warnings
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import attr
import numpy
Expand Down Expand Up @@ -156,6 +156,7 @@ def statistics(
percentiles: Optional[List[int]] = None,
hist_options: Optional[Dict] = None,
nodata: Optional[NoData] = None,
**kwargs: Any,
) -> Dict[str, BandStatistics]:
"""Return statistics from a dataset."""
hist_options = hist_options or {}
Expand Down Expand Up @@ -188,6 +189,7 @@ def tile(
reproject_method: WarpResampling = "nearest",
auto_expand: bool = True,
nodata: Optional[NoData] = None,
**kwargs: Any,
) -> ImageData:
"""Read a Web Map tile from a dataset.
Expand Down Expand Up @@ -264,6 +266,7 @@ def part(
height: Optional[int] = None,
width: Optional[int] = None,
resampling_method: RIOResampling = "nearest",
**kwargs: Any,
) -> ImageData:
"""Read part of a dataset.
Expand Down Expand Up @@ -362,6 +365,7 @@ def preview(
dst_crs: Optional[CRS] = None,
reproject_method: WarpResampling = "nearest",
resampling_method: RIOResampling = "nearest",
**kwargs: Any,
) -> ImageData:
"""Return a preview of a dataset.
Expand Down Expand Up @@ -446,6 +450,7 @@ def point(
lat: float,
coord_crs: CRS = WGS84_CRS,
nodata: Optional[NoData] = None,
**kwargs: Any,
) -> PointData:
"""Read a pixel value from a dataset.
Expand Down Expand Up @@ -499,6 +504,7 @@ def feature(
height: Optional[int] = None,
width: Optional[int] = None,
resampling_method: RIOResampling = "nearest",
**kwargs: Any,
) -> ImageData:
"""Read part of a dataset defined by a geojson feature.
Expand Down
Binary file added tests/fixtures/dataset_2d.nc
Binary file not shown.
Binary file added tests/fixtures/dataset_2d.tif
Binary file not shown.
106 changes: 106 additions & 0 deletions tests/fixtures/stac_netcdf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
{
"type": "Feature",
"stac_version": "1.0.0",
"id": "my_stac",
"properties": {
"proj:epsg": 4326,
"proj:geometry": {
"type": "Polygon",
"coordinates": [
[
[
-170.085,
79.91999999999659
],
[
169.91499999997504,
79.91999999999659
],
[
169.91499999997504,
-80.08
],
[
-170.085,
-80.08
],
[
-170.085,
79.91999999999659
]
]
]
},
"proj:bbox": [
-170.085,
79.91999999999659,
169.91499999997504,
-80.08
],
"proj:shape": [
1000,
2000
],
"proj:transform": [
0.16999999999998752,
0,
-170.085,
0,
0.1599999999999966,
-80.08,
0,
0,
1
],
"datetime": "2024-11-05T09:03:47.523834Z"
},
"geometry": {
"type": "Polygon",
"coordinates": [
[
[
-170.085,
79.91999999999659
],
[
169.91499999997504,
79.91999999999659
],
[
169.91499999997504,
-80.08
],
[
-170.085,
-80.08
],
[
-170.085,
79.91999999999659
]
]
]
},
"links": [],
"assets": {
"geotiff": {
"href": "dataset_2d.tif",
"type": "image/tiff; application=geotiff",
"roles": ["data"]
},
"netcdf": {
"href": "dataset_2d.nc",
"type": "application/x-netcdf",
"roles": ["data"]
}
},
"bbox": [
-170.085,
-80.08,
169.91499999997504,
79.91999999999659
],
"stac_extensions": [
"https://stac-extensions.github.io/projection/v1.1.0/schema.json"
]
}
58 changes: 51 additions & 7 deletions tests/test_io_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
import json
import os
import sys
from typing import Dict, Set, Tuple, Type
from typing import Dict, List, Set, Tuple, Type
from unittest.mock import patch

import attr
import morecantile
import numpy
import pytest
import rasterio
import xarray
from morecantile import TileMatrixSet
from rasterio._env import get_gdal_config
from rasterio.crs import CRS

from rio_tiler.constants import WEB_MERCATOR_TMS
from rio_tiler.errors import (
AssetAsBandError,
ExpressionMixingWarning,
Expand All @@ -36,6 +39,7 @@
STAC_WRONGSTATS_PATH = os.path.join(PREFIX, "stac_wrong_stats.json")
STAC_ALTERNATE_PATH = os.path.join(PREFIX, "stac_alternate.json")
STAC_GRIB_PATH = os.path.join(PREFIX, "stac_grib.json")
STAC_NETCDF_PATH = os.path.join(PREFIX, "stac_netcdf.json")

with open(STAC_PATH) as f:
item = json.loads(f.read())
Expand Down Expand Up @@ -987,8 +991,39 @@ def test_default_assets(rio):
assert img.band_names == ["green_b1"]


def test_get_reader():
def test_netcdf_reader():
"""Should use the correct reader depending on the media type."""

@attr.s
class NetCDFReader(XarrayReader):
"""Reader: Open NetCDF file and access DataArray."""

src_path: str = attr.ib()
variable: str = attr.ib()

tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS)

ds: xarray.Dataset = attr.ib(init=False)
input: xarray.DataArray = attr.ib(init=False)

_dims: List = attr.ib(init=False, factory=list)

def __attrs_post_init__(self):
"""Set bounds and CRS."""
self.ds = xarray.open_dataset(self.src_path, decode_coords="all")
da = self.ds[self.variable]

# Make sure we have a valid CRS
crs = da.rio.crs or "epsg:4326"
da = da.rio.write_crs(crs)

if "time" in da.dims:
da = da.isel(time=0)

self.input = da

super().__attrs_post_init__()

valid_types = {
"image/tiff; application=geotiff",
"application/x-netcdf",
Expand All @@ -1004,20 +1039,29 @@ def _get_reader(self, asset_info: AssetInfo) -> Tuple[Type[BaseReader], Dict]:
if asset_type and asset_type in [
"application/x-netcdf",
]:
return XarrayReader, {}
return NetCDFReader, {}

return Reader, {}

with CustomSTACReader(STAC_RASTER_PATH) as stac:
assert stac.assets == ["red", "green", "blue", "netcdf"]
with CustomSTACReader(STAC_NETCDF_PATH) as stac:
assert stac.assets == ["geotiff", "netcdf"]
info = stac._get_asset_info("netcdf")
assert info["media_type"] == "application/x-netcdf"
assert stac._get_reader(info) == (XarrayReader, {})
assert stac._get_reader(info) == (NetCDFReader, {})

info = stac._get_asset_info("red")
info = stac._get_asset_info("geotiff")
assert info["media_type"] == "image/tiff; application=geotiff"
assert stac._get_reader(info) == (Reader, {})

with CustomSTACReader(
STAC_NETCDF_PATH, reader_options={"variable": "dataset"}
) as stac:
info = stac.info(assets=["netcdf"])
assert info["netcdf"].crs

img = stac.preview(assets=["netcdf"])
assert img.band_names == ["netcdf_value"]


@patch("rio_tiler.io.stac.STAC_ALTERNATE_KEY", "s3")
def test_alternate_assets():
Expand Down

0 comments on commit 6ba6b48

Please sign in to comment.