Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding default_assets for MultiBaseReader and STACReader #722

Merged
merged 6 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Unreleased

* Enable dynamic definition of Asset **reader** in `MultiBaseReader` (https://github.com/cogeotiff/rio-tiler/pull/711/, https://github.com/cogeotiff/rio-tiler/pull/728)
* Adding `default_assets` for MultiBaseReader and STACReader (author @mccarthyryanc, https://github.com/cogeotiff/rio-tiler/pull/722)
* Adding `default_bands` for MultiBandReader (https://github.com/cogeotiff/rio-tiler/pull/722)

# 6.7.0 (2024-09-05)

Expand Down
1 change: 1 addition & 0 deletions docs/src/readers.md
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ STACReader.__mro__
- **exclude_assets** (set, optional): Set of assets to exclude from the `available` asset list
- **include_asset_types** (set, optional): asset types to consider as valid type for the reader
- **exclude_asset_types** (set, optional): asset types to consider as invalid type for the reader
- **default_assets** (sequence, optional): default assets to use for the reader if nothing else is provided.
- **reader** (BaseReader, optional): Reader to use to read assets (defaults to rio_tiler.io.rasterio.Reader)
- **reader_options** (dict, optional): Options to forward to the reader init
- **fetch_options** (dict, optional): Options to pass to the `httpx.get` or `boto3` when fetching the STAC item
Expand Down
84 changes: 79 additions & 5 deletions rio_tiler/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ class MultiBaseReader(SpatialMixin, metaclass=abc.ABCMeta):

assets: Sequence[str] = attr.ib(init=False)

default_assets: Optional[Sequence[str]] = attr.ib(init=False, default=None)

ctx: Any = attr.ib(init=False, default=contextlib.nullcontext)

def __enter__(self):
Expand Down Expand Up @@ -500,9 +502,16 @@ def tile(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

if not assets and self.default_assets:
warnings.warn(
f"No assets/expression passed, defaults to {self.default_assets}",
UserWarning,
)
assets = self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -584,9 +593,16 @@ def part(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

if not assets and self.default_assets:
warnings.warn(
f"No assets/expression passed, defaults to {self.default_assets}",
UserWarning,
)
assets = self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -666,9 +682,16 @@ def preview(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

if not assets and self.default_assets:
warnings.warn(
f"No assets/expression passed, defaults to {self.default_assets}",
UserWarning,
)
assets = self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -752,9 +775,16 @@ def point(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

if not assets and self.default_assets:
warnings.warn(
f"No assets/expression passed, defaults to {self.default_assets}",
UserWarning,
)
assets = self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -830,9 +860,16 @@ def feature(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

if not assets and self.default_assets:
warnings.warn(
f"No assets/expression passed, defaults to {self.default_assets}",
UserWarning,
)
assets = self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -909,6 +946,8 @@ class MultiBandReader(SpatialMixin, metaclass=abc.ABCMeta):

bands: Sequence[str] = attr.ib(init=False)

default_bands: Optional[Sequence[str]] = attr.ib(init=False, default=None)

def __enter__(self):
"""Support using with Context Managers."""
return self
Expand Down Expand Up @@ -1076,6 +1115,13 @@ def tile(
if expression:
bands = self.parse_expression(expression)

if not bands and self.default_bands:
warnings.warn(
f"No bands/expression passed, defaults to {self.default_bands}",
UserWarning,
)
bands = self.default_bands

if not bands:
raise MissingBands(
"bands must be passed either via `expression` or `bands` options."
Expand Down Expand Up @@ -1134,6 +1180,13 @@ def part(
if expression:
bands = self.parse_expression(expression)

if not bands and self.default_bands:
warnings.warn(
f"No bands/expression passed, defaults to {self.default_bands}",
UserWarning,
)
bands = self.default_bands

if not bands:
raise MissingBands(
"bands must be passed either via `expression` or `bands` options."
Expand Down Expand Up @@ -1190,6 +1243,13 @@ def preview(
if expression:
bands = self.parse_expression(expression)

if not bands and self.default_bands:
warnings.warn(
f"No bands/expression passed, defaults to {self.default_bands}",
UserWarning,
)
bands = self.default_bands

if not bands:
raise MissingBands(
"bands must be passed either via `expression` or `bands` options."
Expand Down Expand Up @@ -1250,6 +1310,13 @@ def point(
if expression:
bands = self.parse_expression(expression)

if not bands and self.default_bands:
warnings.warn(
f"No bands/expression passed, defaults to {self.default_bands}",
UserWarning,
)
bands = self.default_bands

if not bands:
raise MissingBands(
"bands must be passed either via `expression` or `bands` options."
Expand Down Expand Up @@ -1307,6 +1374,13 @@ def feature(
if expression:
bands = self.parse_expression(expression)

if not bands and self.default_bands:
warnings.warn(
f"No bands/expression passed, defaults to {self.default_bands}",
UserWarning,
)
bands = self.default_bands

if not bands:
raise MissingBands(
"bands must be passed either via `expression` or `bands` options."
Expand Down
5 changes: 4 additions & 1 deletion rio_tiler/io/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import os
import warnings
from typing import Any, Dict, Iterator, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, Iterator, Optional, Sequence, Set, Tuple, Type, Union
from urllib.parse import urlparse

import attr
Expand Down Expand Up @@ -200,6 +200,7 @@ class STACReader(MultiBaseReader):
exclude_assets (set of string, optional): Exclude specific assets.
include_asset_types (set of string, optional): Only include some assets base on their type.
exclude_asset_types (set of string, optional): Exclude some assets base on their type.
default_assets (list of string, optional): Default assets to use if none are defined.
reader (rio_tiler.io.BaseReader, optional): rio-tiler Reader. Defaults to `rio_tiler.io.Reader`.
reader_options (dict, optional): Additional option to forward to the Reader. Defaults to `{}`.
fetch_options (dict, optional): Options to pass to `rio_tiler.io.stac.fetch` function fetching the STAC Items. Defaults to `{}`.
Expand Down Expand Up @@ -238,6 +239,8 @@ class STACReader(MultiBaseReader):
include_asset_types: Set[str] = attr.ib(default=DEFAULT_VALID_TYPE)
exclude_asset_types: Optional[Set[str]] = attr.ib(default=None)

default_assets: Optional[Sequence[str]] = attr.ib(default=None)

reader: Type[BaseReader] = attr.ib(default=Reader)
reader_options: Dict = attr.ib(factory=dict)

Expand Down
51 changes: 50 additions & 1 deletion tests/test_io_MultiBand.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import pathlib
from typing import Dict, Type
from typing import Dict, Optional, Sequence, Type

import attr
import morecantile
Expand All @@ -26,6 +26,8 @@ class BandFileReader(MultiBandReader):
reader: Type[BaseReader] = attr.ib(init=False, default=Reader)
reader_options: Dict = attr.ib(factory=dict)

default_bands: Optional[Sequence[str]] = attr.ib(default=None)

minzoom: int = attr.ib()
maxzoom: int = attr.ib()

Expand Down Expand Up @@ -189,3 +191,50 @@ def test_MultiBandReader():
assert img.metadata
assert img.metadata["band1"]
assert img.metadata["band2"]


def test_MultiBandReader_default_bands():
"""Should work as expected."""
with BandFileReader(PREFIX, default_bands=["band1"]) as src:
assert src.bands == ["band1", "band2"]

with pytest.warns(UserWarning):
img = src.tile(238, 218, 9)
assert img.data.shape == (1, 256, 256)
assert img.band_names == ["band1"]

with pytest.warns(UserWarning):
img = src.part((-11.5, 24.5, -11.0, 25.0))
assert img.band_names == ["band1"]

with pytest.warns(UserWarning):
img = src.preview()
assert img.band_names == ["band1"]

with pytest.warns(UserWarning):
pt = src.point(-11.5, 24.5)
assert len(pt.data) == 1
assert pt.band_names == ["band1"]

feat = {
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[
[-12.03826904296875, 24.87646991083154],
[-12.14263916015625, 24.831610355586918],
[-12.1563720703125, 24.709410369765177],
[-12.1673583984375, 24.484648999654034],
[-11.898193359375, 24.472150437226865],
[-11.6729736328125, 24.542126388899305],
[-11.47247314453125, 24.79920167537382],
[-12.03826904296875, 24.87646991083154],
]
],
},
}
with pytest.warns(UserWarning):
img = src.feature(feat)
assert img.band_names == ["band1"]
60 changes: 60 additions & 0 deletions tests/test_io_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,66 @@ def test_expression_with_wrong_stac_stats(rio):
)


@patch("rio_tiler.io.rasterio.rasterio")
def test_default_assets(rio):
"""Should raise or return tiles."""
rio.open = mock_rasterio_open

bbox = (-80.477, 32.7988, -79.737, 33.4453)

feat = {
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[
[-80.013427734375, 33.03169299978312],
[-80.3045654296875, 32.588477769459146],
[-80.05462646484375, 32.42865847084369],
[-79.45037841796875, 32.6093028087336],
[-79.47235107421875, 33.43602551072033],
[-79.89532470703125, 33.47956309444182],
[-80.1068115234375, 33.37870592138779],
[-80.30181884765625, 33.27084277265288],
[-80.0628662109375, 33.146750228776455],
[-80.013427734375, 33.03169299978312],
]
],
},
}

with STACReader(STAC_PATH, default_assets=["green"]) as stac:
with pytest.warns(UserWarning):
img = stac.tile(71, 102, 8)
assert img.data.shape == (1, 256, 256)
assert img.mask.shape == (256, 256)
assert img.band_names == ["green_b1"]

with pytest.warns(UserWarning):
pt = stac.point(-80.477, 33.4453)
assert len(pt.data) == 1
assert pt.band_names == ["green_b1"]

with pytest.warns(UserWarning):
img = stac.preview()
assert img.data.shape == (1, 259, 255)
assert img.mask.shape == (259, 255)
assert img.band_names == ["green_b1"]

with pytest.warns(UserWarning):
img = stac.part(bbox)
assert img.data.shape == (1, 73, 83)
assert img.mask.shape == (73, 83)
assert img.band_names == ["green_b1"]

with pytest.warns(UserWarning):
img = stac.feature(feat)
assert img.data.shape == (1, 118, 96)
assert img.mask.shape == (118, 96)
assert img.band_names == ["green_b1"]


def test_get_reader():
"""Should use the correct reader depending on the media type."""
valid_types = {
Expand Down
7 changes: 4 additions & 3 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ def test_tile_read_nodata_float():

def test_inverted_latitude_point():
"""Make sure we can read a point from a file with inverted latitude."""
with rasterio.open(COG_INVERTED) as src_dst:
pt = reader.point(src_dst, [-104.77519499, 38.95367054])
assert pt.data[0] == -9999.0
with pytest.warns(UserWarning):
with rasterio.open(COG_INVERTED) as src_dst:
pt = reader.point(src_dst, [-104.77519499, 38.95367054])
assert pt.data[0] == -9999.0
Loading