Skip to content

Commit

Permalink
allow sequence and iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentsarago committed Nov 3, 2023
1 parent 76c66c3 commit 090f2f0
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 10 deletions.
25 changes: 18 additions & 7 deletions rio_tiler/mosaic/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@

import warnings
from inspect import isclass
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union, cast, Iterable
from typing import (
Any,
Callable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)

import numpy
from rasterio.crs import CRS
Expand All @@ -23,7 +34,7 @@


def mosaic_reader( # noqa: C901
mosaic_assets: Iterable,
mosaic_assets: Union[Iterator, Sequence],
reader: Callable[..., ImageData],
*args: Any,
pixel_selection: Union[Type[MosaicMethodBase], MosaicMethodBase] = FirstMethod,
Expand All @@ -36,7 +47,7 @@ def mosaic_reader( # noqa: C901
Args:
mosaic_assets (sequence): List of assets.
mosaic_assets (Sequence or Iterator): List of assets.
reader (callable): Reader function. The function MUST take `(asset, *args, **kwargs)` as arguments, and MUST return an ImageData.
args (Any): Argument to forward to the reader function.
pixel_selection (MosaicMethod, optional): Instance of MosaicMethodBase class. Defaults to `rio_tiler.mosaic.methods.defaults.FirstMethod`.
Expand Down Expand Up @@ -76,17 +87,17 @@ def mosaic_reader( # noqa: C901
"'rio_tiler.mosaic.methods.base.MosaicMethodBase'"
)

# if not chunk_size:
# chunk_size = threads if threads > 1 else len(mosaic_assets)
chunk_size = threads
if not isinstance(mosaic_assets, Iterator) and not chunk_size:
chunk_size = threads if threads > 1 else len(mosaic_assets)

chunk_size = chunk_size or threads

assets_used: List = []
crs: Optional[CRS]
bounds: Optional[BBox]
band_names: List[str]

for chunks in _chunks(mosaic_assets, chunk_size):
print(threads, len(chunks), chunk_size)
tasks = create_tasks(reader, chunks, threads, *args, **kwargs)
for img, asset in filter_tasks(
tasks,
Expand Down
20 changes: 17 additions & 3 deletions rio_tiler/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
"""rio_tiler.utils: utility functions."""

import itertools
import warnings
from io import BytesIO
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union, Iterable
from typing import (
Any,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
)

import numpy
import rasterio
Expand All @@ -24,12 +36,14 @@
from rio_tiler.constants import WEB_MERCATOR_CRS
from rio_tiler.errors import RioTilerError
from rio_tiler.types import BBox, ColorMapType, IntervalTuple, RIOResampling
import itertools


def _chunks(my_list: Iterable, chuck_size: int) -> Generator[Sequence, None, None]:
"""Yield successive n-sized chunks from l."""
while chunk:= tuple(itertools.islice(my_list, chuck_size)):
if not isinstance(my_list, Iterator):
my_list = iter(my_list)

while chunk := tuple(itertools.islice(my_list, chuck_size)):
yield chunk


Expand Down
14 changes: 14 additions & 0 deletions tests/test_mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,20 @@ class aClass(object):
assert m.dtype == "uint8"


def test_mosaic_tiler_iter():
"""Test mosaic tiler with iterator input."""
assets_iter = iter(assets)

(t, m), assets_used = mosaic.mosaic_reader(assets_iter, _read_tile, x, y, z)
assert t.shape == (3, 256, 256)
assert m.shape == (256, 256)
assert m.all()
# Should only have value of 1
assert numpy.unique(t[0, m == 255]).tolist() == [1]
assert t.dtype == "uint16"
assert m.dtype == "uint8"


def mock_rasterio_open(asset):
"""Mock rasterio Open."""
assert asset.startswith("http://somewhere-over-the-rainbow.io")
Expand Down

0 comments on commit 090f2f0

Please sign in to comment.