diff --git a/CHANGES.md b/CHANGES.md index 2a32de34..6a19101f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,8 @@ # Unreleased * Enable dynamic definition of Asset **reader** in `MultiBaseReader` (https://github.com/cogeotiff/rio-tiler/pull/711/, https://github.com/cogeotiff/rio-tiler/pull/728) +* Adding `default_assets` for MultiBaseReader and STACReader (author @mccarthyryanc, https://github.com/cogeotiff/rio-tiler/pull/722) +* Adding `default_bands` for MultiBandReader (https://github.com/cogeotiff/rio-tiler/pull/722) # 6.7.0 (2024-09-05) diff --git a/docs/src/readers.md b/docs/src/readers.md index db04e5be..30ac8612 100644 --- a/docs/src/readers.md +++ b/docs/src/readers.md @@ -398,6 +398,7 @@ STACReader.__mro__ - **exclude_assets** (set, optional): Set of assets to exclude from the `available` asset list - **include_asset_types** (set, optional): asset types to consider as valid type for the reader - **exclude_asset_types** (set, optional): asset types to consider as invalid type for the reader +- **default_assets** (sequence, optional): default assets to use for the reader if nothing else is provided. - **reader** (BaseReader, optional): Reader to use to read assets (defaults to rio_tiler.io.rasterio.Reader) - **reader_options** (dict, optional): Options to forward to the reader init - **fetch_options** (dict, optional): Options to pass to the `httpx.get` or `boto3` when fetching the STAC item diff --git a/rio_tiler/io/base.py b/rio_tiler/io/base.py index e33838a6..5f38da82 100644 --- a/rio_tiler/io/base.py +++ b/rio_tiler/io/base.py @@ -268,6 +268,8 @@ class MultiBaseReader(SpatialMixin, metaclass=abc.ABCMeta): assets: Sequence[str] = attr.ib(init=False) + default_assets: Optional[Sequence[str]] = attr.ib(init=False, default=None) + ctx: Any = attr.ib(init=False, default=contextlib.nullcontext) def __enter__(self): @@ -500,9 +502,16 @@ def tile( if expression: assets = self.parse_expression(expression, asset_as_band=asset_as_band) + if not assets and self.default_assets: + warnings.warn( + f"No assets/expression passed, defaults to {self.default_assets}", + UserWarning, + ) + assets = self.default_assets + if not assets: raise MissingAssets( - "assets must be passed either via `expression` or `assets` options." + "assets must be passed via `expression` or `assets` options, or via class-level `default_assets`." ) asset_indexes = asset_indexes or {} @@ -584,9 +593,16 @@ def part( if expression: assets = self.parse_expression(expression, asset_as_band=asset_as_band) + if not assets and self.default_assets: + warnings.warn( + f"No assets/expression passed, defaults to {self.default_assets}", + UserWarning, + ) + assets = self.default_assets + if not assets: raise MissingAssets( - "assets must be passed either via `expression` or `assets` options." + "assets must be passed via `expression` or `assets` options, or via class-level `default_assets`." ) asset_indexes = asset_indexes or {} @@ -666,9 +682,16 @@ def preview( if expression: assets = self.parse_expression(expression, asset_as_band=asset_as_band) + if not assets and self.default_assets: + warnings.warn( + f"No assets/expression passed, defaults to {self.default_assets}", + UserWarning, + ) + assets = self.default_assets + if not assets: raise MissingAssets( - "assets must be passed either via `expression` or `assets` options." + "assets must be passed via `expression` or `assets` options, or via class-level `default_assets`." ) asset_indexes = asset_indexes or {} @@ -752,9 +775,16 @@ def point( if expression: assets = self.parse_expression(expression, asset_as_band=asset_as_band) + if not assets and self.default_assets: + warnings.warn( + f"No assets/expression passed, defaults to {self.default_assets}", + UserWarning, + ) + assets = self.default_assets + if not assets: raise MissingAssets( - "assets must be passed either via `expression` or `assets` options." + "assets must be passed via `expression` or `assets` options, or via class-level `default_assets`." ) asset_indexes = asset_indexes or {} @@ -830,9 +860,16 @@ def feature( if expression: assets = self.parse_expression(expression, asset_as_band=asset_as_band) + if not assets and self.default_assets: + warnings.warn( + f"No assets/expression passed, defaults to {self.default_assets}", + UserWarning, + ) + assets = self.default_assets + if not assets: raise MissingAssets( - "assets must be passed either via `expression` or `assets` options." + "assets must be passed via `expression` or `assets` options, or via class-level `default_assets`." ) asset_indexes = asset_indexes or {} @@ -909,6 +946,8 @@ class MultiBandReader(SpatialMixin, metaclass=abc.ABCMeta): bands: Sequence[str] = attr.ib(init=False) + default_bands: Optional[Sequence[str]] = attr.ib(init=False, default=None) + def __enter__(self): """Support using with Context Managers.""" return self @@ -1076,6 +1115,13 @@ def tile( if expression: bands = self.parse_expression(expression) + if not bands and self.default_bands: + warnings.warn( + f"No bands/expression passed, defaults to {self.default_bands}", + UserWarning, + ) + bands = self.default_bands + if not bands: raise MissingBands( "bands must be passed either via `expression` or `bands` options." @@ -1134,6 +1180,13 @@ def part( if expression: bands = self.parse_expression(expression) + if not bands and self.default_bands: + warnings.warn( + f"No bands/expression passed, defaults to {self.default_bands}", + UserWarning, + ) + bands = self.default_bands + if not bands: raise MissingBands( "bands must be passed either via `expression` or `bands` options." @@ -1190,6 +1243,13 @@ def preview( if expression: bands = self.parse_expression(expression) + if not bands and self.default_bands: + warnings.warn( + f"No bands/expression passed, defaults to {self.default_bands}", + UserWarning, + ) + bands = self.default_bands + if not bands: raise MissingBands( "bands must be passed either via `expression` or `bands` options." @@ -1250,6 +1310,13 @@ def point( if expression: bands = self.parse_expression(expression) + if not bands and self.default_bands: + warnings.warn( + f"No bands/expression passed, defaults to {self.default_bands}", + UserWarning, + ) + bands = self.default_bands + if not bands: raise MissingBands( "bands must be passed either via `expression` or `bands` options." @@ -1307,6 +1374,13 @@ def feature( if expression: bands = self.parse_expression(expression) + if not bands and self.default_bands: + warnings.warn( + f"No bands/expression passed, defaults to {self.default_bands}", + UserWarning, + ) + bands = self.default_bands + if not bands: raise MissingBands( "bands must be passed either via `expression` or `bands` options." diff --git a/rio_tiler/io/stac.py b/rio_tiler/io/stac.py index 4ad53f2f..4698bab9 100644 --- a/rio_tiler/io/stac.py +++ b/rio_tiler/io/stac.py @@ -3,7 +3,7 @@ import json import os import warnings -from typing import Any, Dict, Iterator, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, Union from urllib.parse import urlparse import attr @@ -200,6 +200,7 @@ class STACReader(MultiBaseReader): exclude_assets (set of string, optional): Exclude specific assets. include_asset_types (set of string, optional): Only include some assets base on their type. exclude_asset_types (set of string, optional): Exclude some assets base on their type. + default_assets (list of string, optional): Default assets to use if none are defined. reader (rio_tiler.io.BaseReader, optional): rio-tiler Reader. Defaults to `rio_tiler.io.Reader`. reader_options (dict, optional): Additional option to forward to the Reader. Defaults to `{}`. fetch_options (dict, optional): Options to pass to `rio_tiler.io.stac.fetch` function fetching the STAC Items. Defaults to `{}`. @@ -238,6 +239,8 @@ class STACReader(MultiBaseReader): include_asset_types: Set[str] = attr.ib(default=DEFAULT_VALID_TYPE) exclude_asset_types: Optional[Set[str]] = attr.ib(default=None) + default_assets: Optional[Sequence[str]] = attr.ib(default=None) + reader: Type[BaseReader] = attr.ib(default=Reader) reader_options: Dict = attr.ib(factory=dict) diff --git a/tests/test_io_MultiBand.py b/tests/test_io_MultiBand.py index ab910bdd..f72c473e 100644 --- a/tests/test_io_MultiBand.py +++ b/tests/test_io_MultiBand.py @@ -2,7 +2,7 @@ import os import pathlib -from typing import Dict, Type +from typing import Dict, Optional, Sequence, Type import attr import morecantile @@ -26,6 +26,8 @@ class BandFileReader(MultiBandReader): reader: Type[BaseReader] = attr.ib(init=False, default=Reader) reader_options: Dict = attr.ib(factory=dict) + default_bands: Optional[Sequence[str]] = attr.ib(default=None) + minzoom: int = attr.ib() maxzoom: int = attr.ib() @@ -189,3 +191,50 @@ def test_MultiBandReader(): assert img.metadata assert img.metadata["band1"] assert img.metadata["band2"] + + +def test_MultiBandReader_default_bands(): + """Should work as expected.""" + with BandFileReader(PREFIX, default_bands=["band1"]) as src: + assert src.bands == ["band1", "band2"] + + with pytest.warns(UserWarning): + img = src.tile(238, 218, 9) + assert img.data.shape == (1, 256, 256) + assert img.band_names == ["band1"] + + with pytest.warns(UserWarning): + img = src.part((-11.5, 24.5, -11.0, 25.0)) + assert img.band_names == ["band1"] + + with pytest.warns(UserWarning): + img = src.preview() + assert img.band_names == ["band1"] + + with pytest.warns(UserWarning): + pt = src.point(-11.5, 24.5) + assert len(pt.data) == 1 + assert pt.band_names == ["band1"] + + feat = { + "type": "Feature", + "properties": {}, + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [-12.03826904296875, 24.87646991083154], + [-12.14263916015625, 24.831610355586918], + [-12.1563720703125, 24.709410369765177], + [-12.1673583984375, 24.484648999654034], + [-11.898193359375, 24.472150437226865], + [-11.6729736328125, 24.542126388899305], + [-11.47247314453125, 24.79920167537382], + [-12.03826904296875, 24.87646991083154], + ] + ], + }, + } + with pytest.warns(UserWarning): + img = src.feature(feat) + assert img.band_names == ["band1"] diff --git a/tests/test_io_stac.py b/tests/test_io_stac.py index 9ba84dee..06bf8eff 100644 --- a/tests/test_io_stac.py +++ b/tests/test_io_stac.py @@ -894,6 +894,66 @@ def test_expression_with_wrong_stac_stats(rio): ) +@patch("rio_tiler.io.rasterio.rasterio") +def test_default_assets(rio): + """Should raise or return tiles.""" + rio.open = mock_rasterio_open + + bbox = (-80.477, 32.7988, -79.737, 33.4453) + + feat = { + "type": "Feature", + "properties": {}, + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [-80.013427734375, 33.03169299978312], + [-80.3045654296875, 32.588477769459146], + [-80.05462646484375, 32.42865847084369], + [-79.45037841796875, 32.6093028087336], + [-79.47235107421875, 33.43602551072033], + [-79.89532470703125, 33.47956309444182], + [-80.1068115234375, 33.37870592138779], + [-80.30181884765625, 33.27084277265288], + [-80.0628662109375, 33.146750228776455], + [-80.013427734375, 33.03169299978312], + ] + ], + }, + } + + with STACReader(STAC_PATH, default_assets=["green"]) as stac: + with pytest.warns(UserWarning): + img = stac.tile(71, 102, 8) + assert img.data.shape == (1, 256, 256) + assert img.mask.shape == (256, 256) + assert img.band_names == ["green_b1"] + + with pytest.warns(UserWarning): + pt = stac.point(-80.477, 33.4453) + assert len(pt.data) == 1 + assert pt.band_names == ["green_b1"] + + with pytest.warns(UserWarning): + img = stac.preview() + assert img.data.shape == (1, 259, 255) + assert img.mask.shape == (259, 255) + assert img.band_names == ["green_b1"] + + with pytest.warns(UserWarning): + img = stac.part(bbox) + assert img.data.shape == (1, 73, 83) + assert img.mask.shape == (73, 83) + assert img.band_names == ["green_b1"] + + with pytest.warns(UserWarning): + img = stac.feature(feat) + assert img.data.shape == (1, 118, 96) + assert img.mask.shape == (118, 96) + assert img.band_names == ["green_b1"] + + def test_get_reader(): """Should use the correct reader depending on the media type.""" valid_types = { diff --git a/tests/test_reader.py b/tests/test_reader.py index ece1186a..2253eac4 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -855,6 +855,7 @@ def test_tile_read_nodata_float(): def test_inverted_latitude_point(): """Make sure we can read a point from a file with inverted latitude.""" - with rasterio.open(COG_INVERTED) as src_dst: - pt = reader.point(src_dst, [-104.77519499, 38.95367054]) - assert pt.data[0] == -9999.0 + with pytest.warns(UserWarning): + with rasterio.open(COG_INVERTED) as src_dst: + pt = reader.point(src_dst, [-104.77519499, 38.95367054]) + assert pt.data[0] == -9999.0