Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sketch for raw raster support in pcfuncs #150

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
70 changes: 69 additions & 1 deletion pcfuncs/funclib/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import quote

import attr
from pydantic import BaseModel
from rasterio.coords import BoundingBox
from rio_tiler.models import ImageData


class RenderOptions(BaseModel):
Expand Down Expand Up @@ -104,3 +107,68 @@ def from_query_params(cls, render_params: str) -> "RenderOptions":
result[k] = v

return RenderOptions(**result)


@attr.s
class RIOImage(ImageData): # type: ignore
"""Extend ImageData class."""

@property
def size(self) -> Tuple[int, int]:
return (self.width, self.height)

def paste(
self,
img: "RIOImage",
box: Optional[
Union[
Tuple[int, int],
Tuple[int, int, int, int],
]
],
) -> None:
if img.count != self.count:
raise Exception("Cannot merge 2 images with different band number")

if img.data.dtype != self.data.dtype:
raise Exception("Cannot merge 2 images with different datatype")

# Pastes another image into this image.
# The box argument is either a 2-tuple giving the upper left corner,
# a 4-tuple defining the left, upper, right, and lower pixel coordinate,
# or None (same as (0, 0)). See Coordinate System. If a 4-tuple is given,
# the size of the pasted image must match the size of the region.
if box is None:
box = (0, 0)

if len(box) == 2:
size = img.size
box += (box[0] + size[0], box[1] + size[1]) # type: ignore
minx, maxy, maxx, miny = box # type: ignore
elif len(box) == 4:
# TODO add more size tests
minx, maxy, maxx, miny = box # type: ignore

else:
raise Exception("Invalid box format")

self.data[:, maxy:miny, minx:maxx] = img.data
self.mask[maxy:miny, minx:maxx] = img.mask

def crop(self, bbox: Tuple[int, int, int, int]) -> "RIOImage":
"""Almost like ImageData.clip but do not deal with Geo transform."""
col_min, row_min, col_max, row_max = bbox

data = self.data[:, row_min:row_max, col_min:col_max]
mask = self.mask[row_min:row_max, col_min:col_max]

return RIOImage(
data,
mask,
assets=self.assets,
crs=self.crs,
bounds=bbox,
band_names=self.band_names,
metadata=self.metadata,
dataset_statistics=self.dataset_statistics,
)
45 changes: 43 additions & 2 deletions pcfuncs/funclib/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import mercantile
from funclib.models import RIOImage
from PIL.Image import Image as PILImage
from pyproj import CRS, Transformer

Expand Down Expand Up @@ -170,5 +171,45 @@ def mask(self, geom: Dict[str, Any]) -> "PILRaster":


class GDALRaster(Raster):
# TODO: Implement
...
def __init__(self, extent: RasterExtent, image: RIOImage) -> None:
self.image = image
super().__init__(extent)

def to_bytes(self, format: str = ExportFormats.PNG) -> io.BytesIO:
img_bytes = self.image.render(
add_mask=True,
img_format=format.upper(),
)
return io.BytesIO(img_bytes)

def crop(self, bbox: Bbox) -> "GDALRaster":
# Web mercator of user bbox
if (
not bbox.crs == self.extent.bbox.crs
and bbox.crs is not None
and self.extent.bbox.crs is not None
):
bbox = bbox.reproject(self.extent.bbox.crs)

col_min, row_min = self.extent.map_to_grid(bbox.xmin, bbox.ymax)
col_max, row_max = self.extent.map_to_grid(bbox.xmax, bbox.ymin)

box: Any = (col_min, row_min, col_max, row_max)
cropped = self.image.crop(box)
return GDALRaster(
extent=RasterExtent(
bbox,
cols=col_max - col_min,
rows=row_max - row_min,
),
image=cropped,
)

def resample(self, cols: int, rows: int) -> "GDALRaster":
return GDALRaster(
extent=RasterExtent(bbox=self.extent.bbox, cols=cols, rows=rows),
image=self.image.resize(rows, cols), # type: ignore
)

def mask(self, geom: Dict[str, Any]) -> "GDALRaster":
raise NotImplementedError("GDALRaster does not support masking")
80 changes: 78 additions & 2 deletions pcfuncs/funclib/tiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import aiohttp
import mercantile
from funclib.models import RenderOptions
import numpy
from funclib.models import RenderOptions, RIOImage
from funclib.raster import (
Bbox,
ExportFormats,
Expand Down Expand Up @@ -152,8 +153,83 @@ async def create(


class GDALTileSet(TileSet[GDALRaster]):
async def _get_tile(self, url: str) -> Union[RIOImage, None]:
async def _f() -> RIOImage:
async with aiohttp.ClientSession() as session:
async with self._async_limit:
# We set Accept-Encoding to make sure the response is compressed
async with session.get(
url, headers={"Accept-Encoding": "gzip"}
) as resp:
if resp.status == 200:
return RIOImage.from_bytes(await resp.read()) # type: ignore

else:
raise TilerError(
f"Error downloading tile: {url}", resp=resp
)

try:
return await with_backoff_async(
_f,
is_throttle=lambda e: isinstance(e, TilerError),
strategy=BackoffStrategy(waits=[0.2, 0.5, 0.75, 1, 2]),
)
except Exception:
logger.warning(f"Tile request failed with backoff: {url}")
return None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is an exception we return None, it will then be handled later


async def get_mosaic(self, tiles: List[Tile]) -> GDALRaster:
raise NotImplementedError()
tasks: List[asyncio.Future[Union[RIOImage, None]]] = []
for tile in tiles:
url = self.get_tile_url(tile.z, tile.x, tile.y)
print(f"Downloading {url}")
tasks.append(asyncio.ensure_future(self._get_tile(url)))

tile_images: List[Union[RIOImage, None]] = list(await asyncio.gather(*tasks))

tileset_dimensions = get_tileset_dimensions(tiles, self.tile_size)

# By default if no tiles where return we create an
# empty mosaic with 3 bands and uint8
count: int = 3
dtype: str = "uint8"
for im in tile_images:
if im:

count = im.count
dtype = im.data.dtype
break # Get Count / datatype from the first valid tile_images

mosaic = RIOImage( # type: ignore
numpy.zeros(
(count, tileset_dimensions.total_rows, tileset_dimensions.total_cols),
dtype=dtype,
)
)

x = 0
y = 0
for i, img in enumerate(tile_images):
if not img:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if there was an exception in get_tile

continue

mosaic.paste(img, (x * self.tile_size, y * self.tile_size))

# Increment the row/col position for subsequent tiles
if (i + 1) % tileset_dimensions.tile_rows == 0:
y = 0
x += 1
else:
y += 1

raster_extent = RasterExtent(
bbox=Bbox.from_tiles(tiles),
cols=tileset_dimensions.total_cols,
rows=tileset_dimensions.total_rows,
)

return GDALRaster(raster_extent, mosaic)


class PILTileSet(TileSet[PILRaster]):
Expand Down
1 change: 1 addition & 0 deletions pcfuncs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pillow==9.3.0
pyproj==3.3.1
pydantic>=1.9,<2.0.0
rasterio==1.3.*
rio-tiler==4.1.* # same as titiler 0.10.2

# Deployment needs to copy the local code into
# the app code directory, so requires a separate
Expand Down
Binary file added pcfuncs/tests/data-files/cog.tif
Binary file not shown.
137 changes: 137 additions & 0 deletions pcfuncs/tests/funclib/test_gdal_tileset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import contextlib
import pathlib
import threading
import time
from enum import Enum
from types import DynamicClassAttribute

import pytest
import uvicorn
from fastapi import FastAPI, Path, Query
from funclib.models import RenderOptions
from funclib.tiles import GDALTileSet
from mercantile import Tile
from rio_tiler.io import Reader
from rio_tiler.profiles import img_profiles
from starlette.responses import Response

HERE = pathlib.Path(__file__).parent
DATA_FILES = HERE / ".." / "data-files"

cog_file = HERE / ".." / "data-files" / "cog.tif"


class ImageDriver(str, Enum):
"""Supported output GDAL drivers."""

jpg = "JPEG"
png = "PNG"
tif = "GTiff"


class MediaType(str, Enum):
"""Responses Media types formerly known as MIME types."""

tif = "image/tiff; application=geotiff"
png = "image/png"
jpeg = "image/jpeg"


class ImageType(str, Enum):
"""Available Output image type."""

png = "png"
tif = "tif"
jpg = "jpg"

@DynamicClassAttribute
def profile(self):
"""Return rio-tiler image default profile."""
return img_profiles.get(self._name_, {})

@DynamicClassAttribute
def driver(self):
"""Return rio-tiler image default profile."""
return ImageDriver[self._name_].value

@DynamicClassAttribute
def mediatype(self):
"""Return image media type."""
return MediaType[self._name_].value


class Server(uvicorn.Server):
"""Uvicorn Server."""

def install_signal_handlers(self):
"""install handlers."""
pass

@contextlib.contextmanager
def run_in_thread(self):
"""run in thread."""
thread = threading.Thread(target=self.run)
thread.start()
try:
while not self.started:
time.sleep(1e-3)
yield
finally:
self.should_exit = True
thread.join()


@pytest.fixture(scope="session")
def application():
"""Run app in Thread."""
app = FastAPI()

@app.get("/{z}/{x}/{y}.{format}", response_class=Response)
def tiler(
z: int = Path(...),
x: int = Path(...),
y: int = Path(...),
format: ImageType = Path(...),
collection: str = Query(...),
tile_scale: int = Query(
1, gt=0, lt=4, description="Tile size scale. 1=256x256, 2=512x512..."
),
):
with Reader(collection) as src:
image = src.tile(x, y, z, tilesize=tile_scale * 256)

content = image.render(
img_format=format.driver,
**format.profile,
)
return Response(content, media_type=format.mediatype)

config = uvicorn.Config(
app, host="127.0.0.1", port=5000, log_level="info", loop="asyncio"
)
server = Server(config=config)
with server.run_in_thread():
yield "http://127.0.0.1:5000"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we create a small tiler which will run in thread



async def test_app(application):
"""Test GDAL Tileset application."""
tileset = GDALTileSet(
f"{application}/{{z}}/{{x}}/{{y}}.tif",
RenderOptions(
collection=str(cog_file),
),
)
expect = f"http://127.0.0.1:5000/0/1/2.tif?collection={cog_file}&tile_scale=2"
assert tileset.get_tile_url(0, 1, 2) == expect

# Test one Tile
url = tileset.get_tile_url(7, 44, 25)
im = await tileset._get_tile(url)
assert im.size == (512, 512)

# Test Mosaic
mosaic = await tileset.get_mosaic([Tile(44, 25, 7), Tile(45, 25, 7)])
assert mosaic.image.size == (1024, 512) # width, height
assert mosaic.image.count == 1 # same as cog_file
assert mosaic.image.data.dtype == "uint16" # same as cog_file
Loading