Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding default_assets for MultiBaseReader and STACReader #722

Merged
merged 6 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/api/rio_tiler/io/stac.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class STACReader(
exclude_assets: Optional[Set[str]] = None,
include_asset_types: Set[str] = {'image/tiff; profile=cloud-optimized; application=geotiff', 'image/tiff; application=geotiff', 'image/x.geotiff', 'image/tiff; application=geotiff; profile=cloud-optimized', 'image/vnd.stac.geotiff; cloud-optimized=true', 'application/x-hdf', 'image/jp2', 'application/x-hdf5', 'image/tiff'},
exclude_asset_types: Optional[Set[str]] = None,
default_assets: Optional[Sequence[str]] = attr.ib(default=None),
reader: Type[rio_tiler.io.base.BaseReader] = <class 'rio_tiler.io.rasterio.Reader'>,
reader_options: Dict = NOTHING,
fetch_options: Dict = NOTHING,
Expand All @@ -96,6 +97,7 @@ class STACReader(
| exclude_assets | set of string | Exclude specific assets. | None |
| include_asset_types | set of string | Only include some assets base on their type. | None |
| exclude_asset_types | set of string | Exclude some assets base on their type. | None |
| default_assets | Sequence of string | Default assets to use if none are defined. | None |
| reader | rio_tiler.io.BaseReader | rio-tiler Reader. Defaults to `rio_tiler.io.Reader`. | `rio_tiler.io.Reader` |
| reader_options | dict | Additional option to forward to the Reader. Defaults to `{}`. | `{}` |
| fetch_options | dict | Options to pass to `rio_tiler.io.stac.fetch` function fetching the STAC Items. Defaults to `{}`. | `{}` |
Expand Down
1 change: 1 addition & 0 deletions docs/src/readers.md
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ STACReader.__mro__
- **exclude_assets** (set, optional): Set of assets to exclude from the `available` asset list
- **include_asset_types** (set, optional): asset types to consider as valid type for the reader
- **exclude_asset_types** (set, optional): asset types to consider as invalid type for the reader
- **default_assets** (sequence, optional): default assets to use for the reader if nothing else is provided.
- **reader** (BaseReader, optional): Reader to use to read assets (defaults to rio_tiler.io.rasterio.Reader)
- **reader_options** (dict, optional): Options to forward to the reader init
- **fetch_options** (dict, optional): Options to pass to the `httpx.get` or `boto3` when fetching the STAC item
Expand Down
22 changes: 17 additions & 5 deletions rio_tiler/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ class MultiBaseReader(SpatialMixin, metaclass=abc.ABCMeta):

assets: Sequence[str] = attr.ib(init=False)

default_assets: Optional[Sequence[str]] = attr.ib(init=False, default=None)

ctx: Any = attr.ib(init=False, default=contextlib.nullcontext)

def __enter__(self):
Expand Down Expand Up @@ -495,9 +497,11 @@ def tile(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

assets = assets or self.default_assets
mccarthyryanc marked this conversation as resolved.
Show resolved Hide resolved

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -576,9 +580,11 @@ def part(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

assets = assets or self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -655,9 +661,11 @@ def preview(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

assets = assets or self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -738,9 +746,11 @@ def point(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

assets = assets or self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down Expand Up @@ -813,9 +823,11 @@ def feature(
if expression:
assets = self.parse_expression(expression, asset_as_band=asset_as_band)

assets = assets or self.default_assets

if not assets:
raise MissingAssets(
"assets must be passed either via `expression` or `assets` options."
"assets must be passed via `expression` or `assets` options, or via class-level `default_assets`."
)

asset_indexes = asset_indexes or {}
Expand Down
5 changes: 4 additions & 1 deletion rio_tiler/io/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import os
import warnings
from typing import Any, Dict, Iterator, Optional, Set, Type, Union
from typing import Any, Dict, Iterator, Optional, Sequence, Set, Type, Union
from urllib.parse import urlparse

import attr
Expand Down Expand Up @@ -200,6 +200,7 @@ class STACReader(MultiBaseReader):
exclude_assets (set of string, optional): Exclude specific assets.
include_asset_types (set of string, optional): Only include some assets base on their type.
exclude_asset_types (set of string, optional): Exclude some assets base on their type.
default_assets (set of string, optional): Default assets to use if none are defined.
mccarthyryanc marked this conversation as resolved.
Show resolved Hide resolved
reader (rio_tiler.io.BaseReader, optional): rio-tiler Reader. Defaults to `rio_tiler.io.Reader`.
reader_options (dict, optional): Additional option to forward to the Reader. Defaults to `{}`.
fetch_options (dict, optional): Options to pass to `rio_tiler.io.stac.fetch` function fetching the STAC Items. Defaults to `{}`.
Expand Down Expand Up @@ -238,6 +239,8 @@ class STACReader(MultiBaseReader):
include_asset_types: Set[str] = attr.ib(default=DEFAULT_VALID_TYPE)
exclude_asset_types: Optional[Set[str]] = attr.ib(default=None)

default_assets: Optional[Sequence[str]] = attr.ib(default=None)

reader: Type[BaseReader] = attr.ib(default=Reader)
reader_options: Dict = attr.ib(factory=dict)

Expand Down
55 changes: 55 additions & 0 deletions tests/test_io_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,3 +890,58 @@ def test_expression_with_wrong_stac_stats(rio):
expression="where((wrongstat>0.5),1,0)",
asset_as_band=True,
)


@patch("rio_tiler.io.rasterio.rasterio")
def test_default_assets(rio):
"""Should raise or return tiles."""
rio.open = mock_rasterio_open

bbox = (-80.477, 32.7988, -79.737, 33.4453)

feat = {
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[
[-80.013427734375, 33.03169299978312],
[-80.3045654296875, 32.588477769459146],
[-80.05462646484375, 32.42865847084369],
[-79.45037841796875, 32.6093028087336],
[-79.47235107421875, 33.43602551072033],
[-79.89532470703125, 33.47956309444182],
[-80.1068115234375, 33.37870592138779],
[-80.30181884765625, 33.27084277265288],
[-80.0628662109375, 33.146750228776455],
[-80.013427734375, 33.03169299978312],
]
],
},
}

with STACReader(STAC_PATH, default_assets=["green"]) as stac:
img = stac.tile(71, 102, 8)
assert img.data.shape == (1, 256, 256)
assert img.mask.shape == (256, 256)
assert img.band_names == ["green_b1"]

pt = stac.point(-80.477, 33.4453)
assert len(pt.data) == 1
assert pt.band_names == ["green_b1"]

img = stac.preview()
assert img.data.shape == (1, 259, 255)
assert img.mask.shape == (259, 255)
assert img.band_names == ["green_b1"]

img = stac.part(bbox)
assert img.data.shape == (1, 73, 83)
assert img.mask.shape == (73, 83)
assert img.band_names == ["green_b1"]

img = stac.feature(feat)
assert img.data.shape == (1, 118, 96)
assert img.mask.shape == (118, 96)
assert img.band_names == ["green_b1"]
Loading