diff --git a/api/app/main.py b/api/app/main.py index ff6b09f7..9154b09c 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -9,7 +9,7 @@ from app.auth.auth import verify_token from app.config.config import get_settings -from app.routers.h3 import h3_grid_router +from app.routers.grid import grid_router from app.routers.zonal_stats import ZonalTilerFactory @@ -29,7 +29,7 @@ def path_params(raster_filename: Annotated[str, Query(description="Raster filena tiler_routes = ZonalTilerFactory(path_dependency=path_params) app.include_router(tiler_routes.router, tags=["Raster"], dependencies=[Depends(verify_token)]) -app.include_router(h3_grid_router, prefix="/grid", tags=["Grid"], dependencies=[Depends(verify_token)]) +app.include_router(grid_router, prefix="/grid", tags=["Grid"], dependencies=[Depends(verify_token)]) add_exception_handlers(app, DEFAULT_STATUS_CODES) diff --git a/api/app/models/grid.py b/api/app/models/grid.py index 1953cb41..47b74a58 100644 --- a/api/app/models/grid.py +++ b/api/app/models/grid.py @@ -1,10 +1,12 @@ # ruff: noqa: D101 from enum import Enum -from typing import Literal +from typing import Annotated, Literal +from fastapi import Query from pydantic import BaseModel, ConfigDict, Field -from pydantic.color import Color +from pydantic_extra_types.color import Color +from sqlalchemy.sql import column, desc, nullslast, select, table class LegendTypes(str, Enum): @@ -60,3 +62,71 @@ class H3GridInfo(BaseModel): class MultiDatasetMeta(BaseModel): datasets: list[DatasetMeta] = Field(description="Variables represented in this dataset") h3_grid_info: list[H3GridInfo] = Field(description="H3 related information") + + +# =============================================== +# TABLE FILTERING +# =============================================== + + +class NumericalOperators(str, Enum): + eq = "eq" + gt = "gt" + lt = "lt" + gte = "gte" + lte = "lte" + not_eq = "not_eq" + + +class CategoricalOperators(str, Enum): + in_ = "in" + not_in = "not_in" + + +class CategoricalFilter(BaseModel): + filter_type: Literal["categorical"] + column_name: str = Field(description="Name of the column to which the filter will apply") + operation: CategoricalOperators = Field() + value: list[int] = Field(description="Value to compare with") + + +class NumericalFilter(BaseModel): + filter_type: Literal["numerical"] + column_name: str = Field(description="Name of the column to which the filter will apply") + operation: NumericalOperators = Field(description="Operation to use in compare") + value: float = Field(description="Value to compare with") + + +class TableFilters(BaseModel): + filters: list[Annotated[CategoricalFilter | NumericalFilter, Field(discriminator="filter_type")]] + limit: int = Field(10, lt=1000, description="Number of records") + order_by: Annotated[list[str], Field(Query(..., description="Prepend '-' to column name to make it descending"))] + + def to_sql_query(self, table_name: str) -> str: + """Compile model to sql query""" + op_to_python_dunder = { + "eq": "__eq__", + "gt": "__gt__", + "lt": "__lt__", + "gte": "__ge__", + "lte": "__le__", + "not_eq": "__ne__", + "in": "in_", + } + filters_to_apply = [] + for _filter in self.filters: + if _filter is None: + continue + col = column(_filter.column_name) + param = getattr(col, op_to_python_dunder.get(_filter.operation, _filter.operation))(_filter.value) + filters_to_apply.append(param) + query = ( + select("*") + .select_from(table(table_name)) + .where(*filters_to_apply) + .limit(self.limit) + .order_by( + *[nullslast(desc(column(col[1:]))) if col.startswith("-") else column(col) for col in self.order_by] + ) + ) + return str(query.compile(compile_kwargs={"literal_binds": True})) diff --git a/api/app/models/exact_extract.py b/api/app/models/zonal_stats.py similarity index 100% rename from api/app/models/exact_extract.py rename to api/app/models/zonal_stats.py diff --git a/api/app/routers/grid.py b/api/app/routers/grid.py new file mode 100644 index 00000000..5954d298 --- /dev/null +++ b/api/app/routers/grid.py @@ -0,0 +1,84 @@ +import logging +import os +from pathlib import Path +from typing import Annotated + +import h3 +import polars as pl +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.responses import FileResponse, ORJSONResponse +from h3 import H3CellError +from pydantic import ValidationError + +from app.config.config import get_settings +from app.models.grid import MultiDatasetMeta, TableFilters + +log = logging.getLogger("uvicorn.error") + +grid_router = APIRouter() + + +@grid_router.get( + "/tile/{tile_index}", + responses={200: {"description": "Get a grid tile"}, 404: {"description": "Not found"}}, + response_model=None, +) +async def grid_tile(tile_index: str) -> FileResponse: + """Request a tile of h3 cells + + :raises HTTPException 404: Item not found + :raises HTTPException 422: H3 index is not valid + """ + try: + z = h3.api.basic_str.h3_get_resolution(tile_index) + except H3CellError: + raise HTTPException(status_code=422, detail="Tile index is not a valid H3 cell") from None + + tile_file = os.path.join(get_settings().grid_tiles_path, f"{z}/{tile_index}.arrow") + if not os.path.exists(tile_file): + raise HTTPException(status_code=404, detail=f"Tile {tile_file} not found") + return FileResponse(tile_file, media_type="application/octet-stream") + + +@grid_router.get( + "/meta", +) +async def grid_dataset_metadata() -> MultiDatasetMeta: + """Dataset metadata""" + file = os.path.join(get_settings().grid_tiles_path, "meta.json") + with open(file) as f: + raw = f.read() + try: + meta = MultiDatasetMeta.model_validate_json(raw) + except ValidationError as e: + # validation error is our fault because meta file is internal. We don't want to show internal error details + # so raise controlled 500 + log.exception(e) + raise HTTPException(status_code=500, detail="Metadata file is malformed. Please contact developer.") from None + return meta + + +@grid_router.post("/table") +def read_table( + level: Annotated[int, Query(..., description="Tile level at which the query will be computed")], + filters: TableFilters = Depends(), +) -> ORJSONResponse: + """Query tile dataset and return table data""" + files_path = Path(get_settings().grid_tiles_path) / str(level) + if not files_path.exists(): + raise HTTPException(404, detail=f"Level {level} does not exist") from None + lf = pl.scan_ipc(files_path.glob("*.arrow")) + query = filters.to_sql_query("frame") + log.debug(query) + try: + res = pl.SQLContext(frame=lf).execute(query).collect() + except pl.exceptions.ColumnNotFoundError as e: + # bad column in order by clause + log.exception(e) + raise HTTPException(status_code=404, detail=f"Column '{e}' not found in dataset") from None + + except pl.exceptions.ComputeError as e: + # possibly raise if wrong type in compare. I'm not aware of other sources of ComputeError + log.exception(e) + raise HTTPException(status_code=422, detail=str(e)) from None + return ORJSONResponse(res.to_dict(as_series=False)) diff --git a/api/app/routers/h3.py b/api/app/routers/h3.py deleted file mode 100644 index 70f4e760..00000000 --- a/api/app/routers/h3.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -import os - -import h3 -from fastapi import APIRouter, HTTPException -from fastapi.responses import FileResponse -from h3 import H3CellError -from pydantic import ValidationError - -from app.config.config import get_settings -from app.models.grid import MultiDatasetMeta - -log = logging.getLogger(__name__) - -h3_grid_router = APIRouter() - - -@h3_grid_router.get( - "/tile/{tile_index}", - responses={200: {"description": "Get a grid tile"}, 404: {"description": "Not found"}}, - response_model=None, -) -async def grid_tile(tile_index: str) -> FileResponse: - """Request a tile of h3 cells - - :raises HTTPException 404: Item not found - :raises HTTPException 422: H3 index is not valid - """ - try: - z = h3.api.basic_str.h3_get_resolution(tile_index) - except H3CellError: - raise HTTPException(status_code=422, detail="Tile index is not a valid H3 cell") from None - - tile_file = os.path.join(get_settings().grid_tiles_path, f"{z}/{tile_index}.arrow") - if not os.path.exists(tile_file): - raise HTTPException(status_code=404, detail=f"Tile {tile_file} not found") - return FileResponse(tile_file, media_type="application/octet-stream") - - -@h3_grid_router.get( - "/meta", -) -async def grid_dataset_metadata() -> MultiDatasetMeta: - """Dataset metadata""" - file = os.path.join(get_settings().grid_tiles_path, "meta.json") - with open(file) as f: - raw = f.read() - try: - meta = MultiDatasetMeta.model_validate_json(raw) - except ValidationError as e: - # validation error is our fault, and we don't want to show internal error details - # so re re-raising 500 with aseptic message and keep the details in our logs. - log.exception(e) - raise HTTPException(status_code=500, detail="Metadata file is malformed. Please contact developer.") from None - return meta diff --git a/api/app/routers/zonal_stats.py b/api/app/routers/zonal_stats.py index 0f9911c5..1e7f080f 100644 --- a/api/app/routers/zonal_stats.py +++ b/api/app/routers/zonal_stats.py @@ -1,5 +1,5 @@ """Minimal COG tiler.""" -import os + from typing import Annotated, List, Union import rasterio @@ -8,8 +8,7 @@ from geojson_pydantic import Feature, FeatureCollection from titiler.core.factory import TilerFactory -from app.config.config import get_settings -from app.models.exact_extract import StatsFeatures, StatsOps +from app.models.zonal_stats import StatsFeatures, StatsOps class ZonalTilerFactory(TilerFactory): @@ -65,8 +64,6 @@ def exact_zonal_stats( features = [geojson.model_dump()] with rasterio.Env(**env): - tiff_path = get_settings().tiff_path - src_path = os.path.join(tiff_path, src_path) with rasterio.open(src_path, **reader_params) as src_dst: statistics = [op.value for op in statistics] # extract the values from the Enum stats = exact_extract(src_dst, features, ops=statistics) diff --git a/api/pyproject.toml b/api/pyproject.toml index 9f929ad5..034ba548 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,7 +1,7 @@ [tool.ruff] line-length = 120 extend-include = ["*.ipynb"] -src = ["app"] +src = ["."] [tool.ruff.lint] select = [ @@ -24,3 +24,6 @@ extend-immutable-calls = ["fastapi.Depends", "fastapi.params.Depends", "fastapi. [tool.ruff.lint.per-file-ignores] "**/{tests}/*" = ["D103"] # Missing docstring in public function + +[tool.mypy] +disable_error_code = ["import-untyped", "attr-defined"] diff --git a/api/requirements.in b/api/requirements.in index 2e3cf566..a2bf54c0 100644 --- a/api/requirements.in +++ b/api/requirements.in @@ -5,3 +5,6 @@ pydantic_settings titiler.core uvicorn h3 +pydantic-extra-types +polars +sqlalchemy diff --git a/api/requirements.txt b/api/requirements.txt index 14cb14a6..50c7a96b 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -38,6 +38,8 @@ fastapi==0.110.1 # via titiler-core geojson-pydantic==1.0.2 # via titiler-core +greenlet==3.0.3 + # via sqlalchemy h11==0.14.0 # via # httpcore @@ -70,16 +72,19 @@ numpy==1.26.4 # snuggs # titiler-core orjson==3.10.0 +polars==1.1.0 pydantic==2.6.4 # via # fastapi # geojson-pydantic # morecantile + # pydantic-extra-types # pydantic-settings # rio-tiler # titiler-core pydantic-core==2.16.3 # via pydantic +pydantic-extra-types==2.9.0 pydantic-settings==2.2.1 pyparsing==3.1.2 # via snuggs @@ -109,6 +114,7 @@ sniffio==1.3.1 # httpx snuggs==1.4.7 # via rasterio +sqlalchemy==2.0.31 starlette==0.37.2 # via fastapi titiler-core==0.18.0 @@ -117,5 +123,6 @@ typing-extensions==4.11.0 # fastapi # pydantic # pydantic-core + # sqlalchemy # titiler-core uvicorn==0.29.0 diff --git a/api/tests/benchmark_post.lua b/api/tests/benchmark_post.lua new file mode 100644 index 00000000..c71ab840 --- /dev/null +++ b/api/tests/benchmark_post.lua @@ -0,0 +1,28 @@ +-- example HTTP POST script which demonstrates setting the +-- HTTP method, body, and adding a header +-- command: +-- wrk -c 100 -t 10 -d 10s -s benchmark_post.lua 'http://localhost:8000/grid/table?level=4&limit=10&order_by=-population' + + +wrk.method = "POST" +wrk.body = [[ +[ + { + "filter_type": "categorical", + "column_name": "fire", + "operation": "in", + "value": [ + 1,2,3 + ] + }, + { + "filter_type": "numerical", + "column_name": "population", + "operation": "gt", + "value": 10000 + } +] +]] +wrk.headers["Content-Type"] = "application/json" +wrk.headers["accept"] = "application/json" +wrk.headers["Authorization"] = "Bearer 1234" diff --git a/api/tests/conftest.py b/api/tests/conftest.py new file mode 100644 index 00000000..2fce5818 --- /dev/null +++ b/api/tests/conftest.py @@ -0,0 +1,119 @@ +import os +from pathlib import Path + +import numpy as np +import polars as pl +import pytest +import rasterio + +from app.config.config import get_settings + +TEST_ROOT = Path(__file__).resolve().parent + +# Testing settings env variables +os.environ["AUTH_TOKEN"] = "secret" +os.environ["TIFF_PATH"] = str(TEST_ROOT / "data") +os.environ["GRID_TILES_PATH"] = str(TEST_ROOT / "data" / "grid") + +FILES = ["raster.tif", "raster2.tif", "raster3.tif"] +HEADERS = {"Authorization": f"Bearer {get_settings().auth_token}"} + + +@pytest.fixture() +def grid_dataset(setup_data_folder) -> str: + """Create an empty binary file to be used as grid dataset stub + for a level 0 tile. like: + data + └── grid + ├── 0 + │ └── 84395c9ffffffff.arrow + └── meta.json + """ + level = "4" + h3_index = "84395c9ffffffff" + + grid_dataset_path = Path(get_settings().grid_tiles_path) + level_path = grid_dataset_path / level + level_path.mkdir(parents=True) + tile_path = level_path / f"{h3_index}.arrow" + + df = pl.DataFrame( + { + "cell": [ + 618668968382824400, + 619428375900454900, + 619428407452893200, + 619428407943888900, + 619428407676764200, + ], + "landcover": [1, 4, 3, 3, 4], + "population": [100, 200, 1, 900, 900], + } + ) + with open(grid_dataset_path / "meta.json", "w") as f: + f.write("{}") + + with open(tile_path, "wb") as f: + df.write_ipc(f) + + yield h3_index + + tile_path.unlink() + level_path.rmdir() + (grid_dataset_path / "meta.json").unlink() + grid_dataset_path.rmdir() + + +@pytest.fixture() +def setup_data_folder(): + os.mkdir(get_settings().tiff_path) + + yield + + os.rmdir(get_settings().tiff_path) + + +@pytest.fixture() +def tif_file(setup_data_folder): + """Create a test raster file. + + [[0, 1, 0], + [1, 9, 1], + [0, 1, 0]] + + The raster is a 3x3 grid with the upper left corner at 0E, 10N and 1 degree pixel size. + The bbox is BoundingBox(left=0.0, bottom=7.0, right=3.0, top=10.0) + """ + data = np.array([[0, 1, 0], [1, 9, 1], [0, 1, 0]]) + transform = rasterio.transform.from_origin(0, 10, 1, 1) + with rasterio.open( + f"{get_settings().tiff_path}/raster.tif", + "w", + driver="GTiff", + width=data.shape[1], + height=data.shape[0], + count=1, + dtype="uint8", + crs="+proj=latlong", + transform=transform, + ) as dst: + dst.write(data, 1) + + yield + + os.remove(f"{get_settings().tiff_path}/raster.tif") + + +@pytest.fixture() +def setup_empty_files(setup_data_folder): + test_tiff_path = get_settings().tiff_path + + for file in FILES: + # Create empty files writing nothing + with open(f"{test_tiff_path}/{file}", "w") as f: + f.write("") + + yield + + for file in FILES: + os.remove(f"{test_tiff_path}/{file}") diff --git a/api/tests/test_grid.py b/api/tests/test_grid.py new file mode 100644 index 00000000..d352ab25 --- /dev/null +++ b/api/tests/test_grid.py @@ -0,0 +1,207 @@ +import json + +import polars as pl + +from app.models.grid import TableFilters +from tests.conftest import HEADERS +from tests.utils import test_client + + +def test_grid_tile(grid_dataset): + response = test_client.get(f"/grid/tile/{grid_dataset}", headers=HEADERS) + + assert response.status_code == 200 + assert pl.read_ipc(response.read()).to_dict(as_series=False) == { + "cell": [ + 618668968382824400, + 619428375900454900, + 619428407452893200, + 619428407943888900, + 619428407676764200, + ], + "landcover": [1, 4, 3, 3, 4], + "population": [100, 200, 1, 900, 900], + } + + +def test_grid_tile_404(grid_dataset): + response = test_client.get("/grid/tile/8439181ffffffff", headers=HEADERS) + + assert response.status_code == 404 + + +def test_grid_tile_bad_index(grid_dataset): + response = test_client.get("/grid/tile/123", headers=HEADERS) + + assert response.status_code == 422 + assert response.json() == {"detail": "Tile index is not a valid H3 cell"} + + +def test_grid_metadata_fails_gracefully(grid_dataset): + res = test_client.get("/grid/meta", headers=HEADERS) + + assert res.status_code == 500 + assert res.json() == {"detail": "Metadata file is malformed. Please contact developer."} + + +def test_table_filter_numerical_eq_to_sql(): + tf = TableFilters.model_validate( + { + "filters": [{"filter_type": "numerical", "column_name": "foo", "operation": "eq", "value": 10}], + "limit": 10, + "order_by": ["baz"], + } + ) + query = tf.to_sql_query("table") + assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo = 10.0 ORDER BY baz LIMIT 10' + + +def test_table_filter_numerical_gt_to_sql(): + tf = TableFilters.model_validate( + { + "filters": [{"filter_type": "numerical", "column_name": "foo", "operation": "gt", "value": 10}], + "limit": 10, + "order_by": ["baz"], + } + ) + query = tf.to_sql_query("table") + assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo > 10.0 ORDER BY baz LIMIT 10' + + +def test_table_filter_numerical_lt_to_sql(): + tf = TableFilters.model_validate( + { + "filters": [{"filter_type": "numerical", "column_name": "foo", "operation": "lt", "value": 10}], + "limit": 10, + "order_by": ["baz"], + } + ) + query = tf.to_sql_query("table") + assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo < 10.0 ORDER BY baz LIMIT 10' + + +def test_table_filter_numerical_gte_to_sql(): + tf = TableFilters.model_validate( + { + "filters": [{"filter_type": "numerical", "column_name": "foo", "operation": "gte", "value": 10}], + "limit": 10, + "order_by": ["baz"], + } + ) + query = tf.to_sql_query("table") + assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo >= 10.0 ORDER BY baz LIMIT 10' + + +def test_table_filter_numerical_lte_to_sql(): + tf = TableFilters.model_validate( + { + "filters": [{"filter_type": "numerical", "column_name": "foo", "operation": "lte", "value": 10}], + "limit": 10, + "order_by": ["baz"], + } + ) + query = tf.to_sql_query("table") + assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo <= 10.0 ORDER BY baz LIMIT 10' + + +def test_table_filter_numerical_not_eq_to_sql(): + tf = TableFilters.model_validate( + { + "filters": [{"filter_type": "numerical", "column_name": "foo", "operation": "not_eq", "value": 10}], + "limit": 10, + "order_by": ["baz"], + } + ) + query = tf.to_sql_query("table") + assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo != 10.0 ORDER BY baz LIMIT 10' + + +def test_table_filter_categorical_in_to_sql(): + tf = TableFilters.model_validate( + { + "filters": [{"filter_type": "categorical", "column_name": "foo", "operation": "in", "value": [1, 2, 3]}], + "limit": 10, + "order_by": ["baz"], + } + ) + query = tf.to_sql_query("table") + assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo IN (1, 2, 3) ORDER BY baz LIMIT 10' + + +def test_table_filter_categorical_not_in_to_sql(): + tf = TableFilters.model_validate( + { + "filters": [ + {"filter_type": "categorical", "column_name": "foo", "operation": "not_in", "value": [1, 2, 3]} + ], + "limit": 10, + "order_by": ["baz"], + } + ) + query = tf.to_sql_query("table") + assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE (foo NOT IN (1, 2, 3)) ORDER BY baz LIMIT 10' + + +def test_table_filters_order_by_desc(): + tf = TableFilters.model_validate( + { + "filters": [{"filter_type": "numerical", "column_name": "foo", "operation": "gt", "value": 10}], + "limit": 100, + "order_by": ["-baz"], + } + ) + query = tf.to_sql_query("table") + assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo > 10.0 ORDER BY baz DESC NULLS LAST LIMIT 100' + + +def test_table_filters_multiple_order_by(): + tf = TableFilters.model_validate( + { + "filters": [{"filter_type": "numerical", "column_name": "foo", "operation": "gt", "value": 10}], + "limit": 100, + "order_by": ["-baz", "foo", "-bar"], + } + ) + query = tf.to_sql_query("table") + assert ( + query.replace("\n", "") + == 'SELECT * FROM "table" WHERE foo > 10.0 ORDER BY baz DESC NULLS LAST, foo, bar DESC NULLS LAST LIMIT 100' + ) + + +def test_table_filters_multiple_filters(): + tf = TableFilters.model_validate( + { + "filters": [ + {"filter_type": "numerical", "column_name": "foo", "operation": "eq", "value": 10}, + {"filter_type": "categorical", "column_name": "bar", "operation": "in", "value": [1, 2, 3]}, + ], + "limit": 100, + "order_by": ["-baz", "foo"], + } + ) + query = tf.to_sql_query("table") + assert ( + query.replace("\n", "") + == 'SELECT * FROM "table" WHERE foo = 10.0 AND bar IN (1, 2, 3) ORDER BY baz DESC NULLS LAST, foo LIMIT 100' + ) + + +def test_grid_table(grid_dataset): + filters = [ + {"filter_type": "numerical", "column_name": "population", "operation": "lte", "value": 200}, + {"filter_type": "numerical", "column_name": "population", "operation": "gt", "value": 1}, + ] + + response = test_client.post( + "/grid/table?level=4&order_by=-population", headers=HEADERS, content=json.dumps(filters) + ) + assert response.status_code == 200 + assert json.loads(response.read()) == { + "cell": [ + 619428375900454900, + 618668968382824400, + ], + "landcover": [4, 1], + "population": [200, 100], + } diff --git a/api/tests/test_main.py b/api/tests/test_main.py index 8df1e145..301be158 100644 --- a/api/tests/test_main.py +++ b/api/tests/test_main.py @@ -1,127 +1,8 @@ -import os -from pathlib import Path - -import numpy as np -import pytest -import rasterio -from app.config.config import get_settings from fastapi.routing import APIRoute +from tests.conftest import FILES, HEADERS from tests.utils import test_client -TOKEN = get_settings().auth_token -FILES = ["raster.tif", "raster2.tif", "raster3.tif"] -GEOJSON = { - "type": "FeatureCollection", - "features": [ - { - "type": "Feature", - "properties": {}, - "geometry": { - "type": "Polygon", - "coordinates": [ - [ - [3.0, 7.0], - [3.0, 10.0], - [0.0, 10.0], - [0.0, 7.0], - [3.0, 7.0], - ], - ], - }, - } - ], -} - -HEADERS = {"Authorization": f"Bearer {TOKEN}"} - - -@pytest.fixture() -def setup_data_folder(): - os.mkdir(get_settings().tiff_path) - - yield - - os.rmdir(get_settings().tiff_path) - - -@pytest.fixture() -def tif_file(setup_data_folder): - """Create a test raster file. - - [[0, 1, 0], - [1, 9, 1], - [0, 1, 0]] - - The raster is a 3x3 grid with the upper left corner at 0E, 10N and 1 degree pixel size. - The bbox is BoundingBox(left=0.0, bottom=7.0, right=3.0, top=10.0) - """ - data = np.array([[0, 1, 0], [1, 9, 1], [0, 1, 0]]) - transform = rasterio.transform.from_origin(0, 10, 1, 1) - with rasterio.open( - f"{get_settings().tiff_path}/raster.tif", - "w", - driver="GTiff", - width=data.shape[1], - height=data.shape[0], - count=1, - dtype="uint8", - crs="+proj=latlong", - transform=transform, - ) as dst: - dst.write(data, 1) - - yield - - os.remove(f"{get_settings().tiff_path}/raster.tif") - - -@pytest.fixture() -def h3_dataset(setup_data_folder) -> str: - """Create an empty binary file to be used as grid dataset stub - for a level 0 tile. like: - data - └── grid - ├── 0 - │ └── 84395c9ffffffff.arrow - └── meta.json - """ - level = "4" - h3_index = "84395c9ffffffff" - - grid_dataset_path = Path(get_settings().grid_tiles_path) - level_path = grid_dataset_path / level - level_path.mkdir(parents=True) - tile_path = level_path / f"{h3_index}.arrow" - - with open(grid_dataset_path / "meta.json", "w") as f: - f.write("Not a json") - - with open(tile_path, "wb") as f: - f.write(b"I am an arrow file!") - - yield h3_index - - tile_path.unlink() - level_path.rmdir() - (grid_dataset_path / "meta.json").unlink() - grid_dataset_path.rmdir() - - -@pytest.fixture() -def setup_empty_files(setup_data_folder): - test_tiff_path = get_settings().tiff_path - - for file in FILES: - # Create empty files writing nothing - with open(f"{test_tiff_path}/{file}", "w") as f: - f.write("") - - yield - - for file in FILES: - os.remove(f"{test_tiff_path}/{file}") - def test_no_token_is_unauthorized(): response = test_client.get("/tifs") @@ -171,86 +52,6 @@ def test_list_files(setup_empty_files): assert response.json() == {"files": FILES} -def test_wrong_file_name_raises_404(setup_data_folder): - response = test_client.post( - "/exact_zonal_stats", headers=HEADERS, params={"raster_filename": "wrong.tif"}, json=GEOJSON - ) - assert response.status_code == 404 - assert response.json() == {"detail": "Raster file wrong.tif does not exist."} - - -def test_no_geojson_raises_422(tif_file): - response = test_client.post("/exact_zonal_stats", headers=HEADERS, params={"raster_filename": "raster.tif"}) - assert response.status_code == 422 - assert response.json() == { - "detail": [ - { - "input": None, - "loc": ["body"], - "msg": "Field required", - "type": "missing", - "url": "https://errors.pydantic.dev/2.6/v/missing", - } - ] - } - - -def test_default_zonal_stats(tif_file): - response = test_client.post( - "/exact_zonal_stats", headers=HEADERS, params={"raster_filename": "raster.tif"}, json=GEOJSON - ) - assert response.status_code == 200 - assert response.json() == {"features": [{"properties": {"max": 9.0, "min": 0.0}, "type": "Feature"}]} - - -def test_custom_zonal_stats(tif_file): - response = test_client.post( - "/exact_zonal_stats", - headers=HEADERS, - params={"raster_filename": "raster.tif", "statistics": ["count"]}, - json=GEOJSON, - ) - assert response.status_code == 200 - assert response.json() == {"features": [{"properties": {"count": 9}, "type": "Feature"}]} - - -def test_nonexistent_statistic_raises_422(tif_file): - response = test_client.post( - "/exact_zonal_stats", - headers=HEADERS, - params={"raster_filename": "raster.tif", "statistics": ["nonexistent"]}, - json=GEOJSON, - ) - assert response.status_code == 422 - - -def test_h3grid(h3_dataset): - response = test_client.get(f"/grid/tile/{h3_dataset}", headers=HEADERS) - - assert response.status_code == 200 - assert response.read() == b"I am an arrow file!" - - -def test_h3grid_404(h3_dataset): - response = test_client.get("/grid/tile/8439181ffffffff", headers=HEADERS) - - assert response.status_code == 404 - - -def test_h3grid_bad_index(h3_dataset): - response = test_client.get("/grid/tile/123", headers=HEADERS) - - assert response.status_code == 422 - assert response.json() == {"detail": "Tile index is not a valid H3 cell"} - - -def test_h3grid_metadata_fails_gracefully(h3_dataset): - res = test_client.get("/grid/meta", headers=HEADERS) - - assert res.status_code == 500 - assert res.json() == {"detail": "Metadata file is malformed. Please contact developer."} - - def test_all_api_routes_require_token(): api_routes = {r.path: r.methods for r in test_client.app.routes if isinstance(r, APIRoute)} diff --git a/api/tests/test_zonal_stats.py b/api/tests/test_zonal_stats.py new file mode 100644 index 00000000..13236dca --- /dev/null +++ b/api/tests/test_zonal_stats.py @@ -0,0 +1,77 @@ +from tests.conftest import HEADERS +from tests.utils import test_client + +GEOJSON = { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "properties": {}, + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [3.0, 7.0], + [3.0, 10.0], + [0.0, 10.0], + [0.0, 7.0], + [3.0, 7.0], + ], + ], + }, + } + ], +} + + +def test_wrong_file_name_raises_404(setup_data_folder): + response = test_client.post( + "/exact_zonal_stats", headers=HEADERS, params={"raster_filename": "wrong.tif"}, json=GEOJSON + ) + assert response.status_code == 404 + assert response.json() == {"detail": "Raster file wrong.tif does not exist."} + + +def test_no_geojson_raises_422(tif_file): + response = test_client.post("/exact_zonal_stats", headers=HEADERS, params={"raster_filename": "raster.tif"}) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "input": None, + "loc": ["body"], + "msg": "Field required", + "type": "missing", + "url": "https://errors.pydantic.dev/2.6/v/missing", + } + ] + } + + +def test_default_zonal_stats(tif_file): + response = test_client.post( + "/exact_zonal_stats", headers=HEADERS, params={"raster_filename": "raster.tif"}, json=GEOJSON + ) + assert response.status_code == 200 + assert response.json() == {"features": [{"properties": {"max": 9.0, "min": 0.0}, "type": "Feature"}]} + + +def test_custom_zonal_stats(tif_file): + response = test_client.post( + "/exact_zonal_stats", + headers=HEADERS, + params={"raster_filename": "raster.tif", "statistics": ["count"]}, + json=GEOJSON, + ) + assert response.status_code == 200 + assert response.json() == {"features": [{"properties": {"count": 9}, "type": "Feature"}]} + + +def test_nonexistent_statistic_raises_422(tif_file): + response = test_client.post( + "/exact_zonal_stats", + headers=HEADERS, + params={"raster_filename": "raster.tif", "statistics": ["nonexistent"]}, + json=GEOJSON, + ) + assert response.status_code == 422 diff --git a/api/tests/utils.py b/api/tests/utils.py index 31d21aff..8aa7242e 100644 --- a/api/tests/utils.py +++ b/api/tests/utils.py @@ -1,4 +1,5 @@ -from app.main import app from fastapi.testclient import TestClient +from app.main import app + test_client = TestClient(app) diff --git a/docker-compose.yml b/docker-compose.yml index 210f1440..ec74b8c1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -20,10 +20,6 @@ services: context: ./api dockerfile: Dockerfile target: development - environment: - - AUTH_TOKEN=secret - - TIFF_PATH=/opt/api/test_data - - GRID_TILES_PATH=/opt/api/test_data/grid networks: - amazonia360-network