Skip to content

Commit

Permalink
Adding default_assets for MultiBaseReader and STACReader (#722)
Browse files Browse the repository at this point in the history
* adding a default_assets attribute to MutliBaseReader and STACReader, along with tests

* updating docs

* updating docstring and adding warning when using default_assets

* add default_bands to MultiBandReader

* update changelog

---------

Co-authored-by: Ryan McCarthy <[email protected]>
Co-authored-by: vincentsarago <[email protected]>
  • Loading branch information
3 people authored Sep 9, 2024
1 parent 9eae282 commit 9bd2127
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Unreleased

* Enable dynamic definition of Asset **reader** in `MultiBaseReader` (https://github.com/cogeotiff/rio-tiler/pull/711/, https://github.com/cogeotiff/rio-tiler/pull/728)
* Adding `default_assets` for MultiBaseReader and STACReader (author @mccarthyryanc, https://github.com/cogeotiff/rio-tiler/pull/722)
* Adding `default_bands` for MultiBandReader (https://github.com/cogeotiff/rio-tiler/pull/722)

# 6.7.0 (2024-09-05)

Expand Down
1 change: 1 addition & 0 deletions docs/src/readers.md
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ STACReader.__mro__
- **exclude_assets** (set, optional): Set of assets to exclude from the `available` asset list
- **include_asset_types** (set, optional): asset types to consider as valid type for the reader
- **exclude_asset_types** (set, optional): asset types to consider as invalid type for the reader
- **default_assets** (sequence, optional): default assets to use for the reader if nothing else is provided.
- **reader** (BaseReader, optional): Reader to use to read assets (defaults to rio_tiler.io.rasterio.Reader)
- **reader_options** (dict, optional): Options to forward to the reader init
- **fetch_options** (dict, optional): Options to pass to the `httpx.get` or `boto3` when fetching the STAC item
Expand Down
84 changes: 79 additions & 5 deletions rio_tiler/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ class MultiBaseReader(SpatialMixin, metaclass=abc.ABCMeta):

assets: Sequence[str] = attr.ib(init=False)

default_assets: Optional[Sequence[str]] = attr.ib(init=False, default=None)

ctx: Any = attr.ib(init=False, default=contextlib.nullcontext)

def __enter__(self):
Expand Down Expand Up @@ -500,9 +502,16 @@ def tile(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

if not assets and self.default_assets:
warnings.warn(
f"No assets/expression passed, defaults to {self.default_assets}",
UserWarning,
)
assets = self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -584,9 +593,16 @@ def part(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

if not assets and self.default_assets:
warnings.warn(
f"No assets/expression passed, defaults to {self.default_assets}",
UserWarning,
)
assets = self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -666,9 +682,16 @@ def preview(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

if not assets and self.default_assets:
warnings.warn(
f"No assets/expression passed, defaults to {self.default_assets}",
UserWarning,
)
assets = self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -752,9 +775,16 @@ def point(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

if not assets and self.default_assets:
warnings.warn(
f"No assets/expression passed, defaults to {self.default_assets}",
UserWarning,
)
assets = self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -830,9 +860,16 @@ def feature(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

if not assets and self.default_assets:
warnings.warn(
f"No assets/expression passed, defaults to {self.default_assets}",
UserWarning,
)
assets = self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -909,6 +946,8 @@ class MultiBandReader(SpatialMixin, metaclass=abc.ABCMeta):

bands: Sequence[str] = attr.ib(init=False)

default_bands: Optional[Sequence[str]] = attr.ib(init=False, default=None)

def __enter__(self):
"""Support using with Context Managers."""
return self
Expand Down Expand Up @@ -1076,6 +1115,13 @@ def tile(
if expression:
bands = self.parse_expression(expression)

if not bands and self.default_bands:
warnings.warn(
f"No bands/expression passed, defaults to {self.default_bands}",
UserWarning,
)
bands = self.default_bands

if not bands:
raise MissingBands(
"bands must be passed either via `expression` or `bands` options."
Expand Down Expand Up @@ -1134,6 +1180,13 @@ def part(
if expression:
bands = self.parse_expression(expression)

if not bands and self.default_bands:
warnings.warn(
f"No bands/expression passed, defaults to {self.default_bands}",
UserWarning,
)
bands = self.default_bands

if not bands:
raise MissingBands(
"bands must be passed either via `expression` or `bands` options."
Expand Down Expand Up @@ -1190,6 +1243,13 @@ def preview(
if expression:
bands = self.parse_expression(expression)

if not bands and self.default_bands:
warnings.warn(
f"No bands/expression passed, defaults to {self.default_bands}",
UserWarning,
)
bands = self.default_bands

if not bands:
raise MissingBands(
"bands must be passed either via `expression` or `bands` options."
Expand Down Expand Up @@ -1250,6 +1310,13 @@ def point(
if expression:
bands = self.parse_expression(expression)

if not bands and self.default_bands:
warnings.warn(
f"No bands/expression passed, defaults to {self.default_bands}",
UserWarning,
)
bands = self.default_bands

if not bands:
raise MissingBands(
"bands must be passed either via `expression` or `bands` options."
Expand Down Expand Up @@ -1307,6 +1374,13 @@ def feature(
if expression:
bands = self.parse_expression(expression)

if not bands and self.default_bands:
warnings.warn(
f"No bands/expression passed, defaults to {self.default_bands}",
UserWarning,
)
bands = self.default_bands

if not bands:
raise MissingBands(
"bands must be passed either via `expression` or `bands` options."
Expand Down
5 changes: 4 additions & 1 deletion rio_tiler/io/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import os
import warnings
from typing import Any, Dict, Iterator, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, Union
from urllib.parse import urlparse

import attr
Expand Down Expand Up @@ -200,6 +200,7 @@ class STACReader(MultiBaseReader):
exclude_assets (set of string, optional): Exclude specific assets.
include_asset_types (set of string, optional): Only include some assets base on their type.
exclude_asset_types (set of string, optional): Exclude some assets base on their type.
default_assets (list of string, optional): Default assets to use if none are defined.
reader (rio_tiler.io.BaseReader, optional): rio-tiler Reader. Defaults to `rio_tiler.io.Reader`.
reader_options (dict, optional): Additional option to forward to the Reader. Defaults to `{}`.
fetch_options (dict, optional): Options to pass to `rio_tiler.io.stac.fetch` function fetching the STAC Items. Defaults to `{}`.
Expand Down Expand Up @@ -238,6 +239,8 @@ class STACReader(MultiBaseReader):
include_asset_types: Set[str] = attr.ib(default=DEFAULT_VALID_TYPE)
exclude_asset_types: Optional[Set[str]] = attr.ib(default=None)

default_assets: Optional[Sequence[str]] = attr.ib(default=None)

reader: Type[BaseReader] = attr.ib(default=Reader)
reader_options: Dict = attr.ib(factory=dict)

Expand Down
51 changes: 50 additions & 1 deletion tests/test_io_MultiBand.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import pathlib
from typing import Dict, Type
from typing import Dict, Optional, Sequence, Type

import attr
import morecantile
Expand All @@ -26,6 +26,8 @@ class BandFileReader(MultiBandReader):
reader: Type[BaseReader] = attr.ib(init=False, default=Reader)
reader_options: Dict = attr.ib(factory=dict)

default_bands: Optional[Sequence[str]] = attr.ib(default=None)

minzoom: int = attr.ib()
maxzoom: int = attr.ib()

Expand Down Expand Up @@ -189,3 +191,50 @@ def test_MultiBandReader():
assert img.metadata
assert img.metadata["band1"]
assert img.metadata["band2"]


def test_MultiBandReader_default_bands():
"""Should work as expected."""
with BandFileReader(PREFIX, default_bands=["band1"]) as src:
assert src.bands == ["band1", "band2"]

with pytest.warns(UserWarning):
img = src.tile(238, 218, 9)
assert img.data.shape == (1, 256, 256)
assert img.band_names == ["band1"]

with pytest.warns(UserWarning):
img = src.part((-11.5, 24.5, -11.0, 25.0))
assert img.band_names == ["band1"]

with pytest.warns(UserWarning):
img = src.preview()
assert img.band_names == ["band1"]

with pytest.warns(UserWarning):
pt = src.point(-11.5, 24.5)
assert len(pt.data) == 1
assert pt.band_names == ["band1"]

feat = {
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[
[-12.03826904296875, 24.87646991083154],
[-12.14263916015625, 24.831610355586918],
[-12.1563720703125, 24.709410369765177],
[-12.1673583984375, 24.484648999654034],
[-11.898193359375, 24.472150437226865],
[-11.6729736328125, 24.542126388899305],
[-11.47247314453125, 24.79920167537382],
[-12.03826904296875, 24.87646991083154],
]
],
},
}
with pytest.warns(UserWarning):
img = src.feature(feat)
assert img.band_names == ["band1"]
60 changes: 60 additions & 0 deletions tests/test_io_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,66 @@ def test_expression_with_wrong_stac_stats(rio):
)


@patch("rio_tiler.io.rasterio.rasterio")
def test_default_assets(rio):
"""Should raise or return tiles."""
rio.open = mock_rasterio_open

bbox = (-80.477, 32.7988, -79.737, 33.4453)

feat = {
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[
[-80.013427734375, 33.03169299978312],
[-80.3045654296875, 32.588477769459146],
[-80.05462646484375, 32.42865847084369],
[-79.45037841796875, 32.6093028087336],
[-79.47235107421875, 33.43602551072033],
[-79.89532470703125, 33.47956309444182],
[-80.1068115234375, 33.37870592138779],
[-80.30181884765625, 33.27084277265288],
[-80.0628662109375, 33.146750228776455],
[-80.013427734375, 33.03169299978312],
]
],
},
}

with STACReader(STAC_PATH, default_assets=["green"]) as stac:
with pytest.warns(UserWarning):
img = stac.tile(71, 102, 8)
assert img.data.shape == (1, 256, 256)
assert img.mask.shape == (256, 256)
assert img.band_names == ["green_b1"]

with pytest.warns(UserWarning):
pt = stac.point(-80.477, 33.4453)
assert len(pt.data) == 1
assert pt.band_names == ["green_b1"]

with pytest.warns(UserWarning):
img = stac.preview()
assert img.data.shape == (1, 259, 255)
assert img.mask.shape == (259, 255)
assert img.band_names == ["green_b1"]

with pytest.warns(UserWarning):
img = stac.part(bbox)
assert img.data.shape == (1, 73, 83)
assert img.mask.shape == (73, 83)
assert img.band_names == ["green_b1"]

with pytest.warns(UserWarning):
img = stac.feature(feat)
assert img.data.shape == (1, 118, 96)
assert img.mask.shape == (118, 96)
assert img.band_names == ["green_b1"]


def test_get_reader():
"""Should use the correct reader depending on the media type."""
valid_types = {
Expand Down
7 changes: 4 additions & 3 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ def test_tile_read_nodata_float():

def test_inverted_latitude_point():
"""Make sure we can read a point from a file with inverted latitude."""
with rasterio.open(COG_INVERTED) as src_dst:
pt = reader.point(src_dst, [-104.77519499, 38.95367054])
assert pt.data[0] == -9999.0
with pytest.warns(UserWarning):
with rasterio.open(COG_INVERTED) as src_dst:
pt = reader.point(src_dst, [-104.77519499, 38.95367054])
assert pt.data[0] == -9999.0

0 comments on commit 9bd2127

Please sign in to comment.