diff --git a/api/app/models/grid.py b/api/app/models/grid.py index 0b8d2a3f..47b74a58 100644 --- a/api/app/models/grid.py +++ b/api/app/models/grid.py @@ -6,7 +6,7 @@ from fastapi import Query from pydantic import BaseModel, ConfigDict, Field from pydantic_extra_types.color import Color -from sqlalchemy.sql import column, desc, select, table +from sqlalchemy.sql import column, desc, nullslast, select, table class LegendTypes(str, Enum): @@ -98,7 +98,7 @@ class NumericalFilter(BaseModel): class TableFilters(BaseModel): - filters: list[Annotated[CategoricalFilter | NumericalFilter, Field(..., discriminator="filter_type")]] + filters: list[Annotated[CategoricalFilter | NumericalFilter, Field(discriminator="filter_type")]] limit: int = Field(10, lt=1000, description="Number of records") order_by: Annotated[list[str], Field(Query(..., description="Prepend '-' to column name to make it descending"))] @@ -115,6 +115,8 @@ def to_sql_query(self, table_name: str) -> str: } filters_to_apply = [] for _filter in self.filters: + if _filter is None: + continue col = column(_filter.column_name) param = getattr(col, op_to_python_dunder.get(_filter.operation, _filter.operation))(_filter.value) filters_to_apply.append(param) @@ -123,6 +125,8 @@ def to_sql_query(self, table_name: str) -> str: .select_from(table(table_name)) .where(*filters_to_apply) .limit(self.limit) - .order_by(*[desc(column(col[1:])) if col.startswith("-") else column(col) for col in self.order_by]) + .order_by( + *[nullslast(desc(column(col[1:]))) if col.startswith("-") else column(col) for col in self.order_by] + ) ) return str(query.compile(compile_kwargs={"literal_binds": True})) diff --git a/api/tests/benchmark_post.lua b/api/tests/benchmark_post.lua new file mode 100644 index 00000000..c71ab840 --- /dev/null +++ b/api/tests/benchmark_post.lua @@ -0,0 +1,28 @@ +-- example HTTP POST script which demonstrates setting the +-- HTTP method, body, and adding a header +-- command: +-- wrk -c 100 -t 10 -d 10s -s benchmark_post.lua 'http://localhost:8000/grid/table?level=4&limit=10&order_by=-population' + + +wrk.method = "POST" +wrk.body = [[ +[ + { + "filter_type": "categorical", + "column_name": "fire", + "operation": "in", + "value": [ + 1,2,3 + ] + }, + { + "filter_type": "numerical", + "column_name": "population", + "operation": "gt", + "value": 10000 + } +] +]] +wrk.headers["Content-Type"] = "application/json" +wrk.headers["accept"] = "application/json" +wrk.headers["Authorization"] = "Bearer 1234" diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 35274636..8ec13016 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -2,6 +2,7 @@ from pathlib import Path import numpy as np +import polars as pl import pytest import rasterio @@ -29,11 +30,24 @@ def grid_dataset(setup_data_folder) -> str: level_path.mkdir(parents=True) tile_path = level_path / f"{h3_index}.arrow" + df = pl.DataFrame( + { + "cell": [ + 618668968382824400, + 619428375900454900, + 619428407452893200, + 619428407943888900, + 619428407676764200, + ], + "landcover": [1, 4, 3, 3, 4], + "population": [100, 200, 1, 900, 900], + } + ) with open(grid_dataset_path / "meta.json", "w") as f: f.write("{}") with open(tile_path, "wb") as f: - f.write(b"I am an arrow file!") + df.write_ipc(f) yield h3_index diff --git a/api/tests/test_grid.py b/api/tests/test_grid.py index cb365748..d352ab25 100644 --- a/api/tests/test_grid.py +++ b/api/tests/test_grid.py @@ -1,8 +1,49 @@ +import json + +import polars as pl + from app.models.grid import TableFilters from tests.conftest import HEADERS from tests.utils import test_client +def test_grid_tile(grid_dataset): + response = test_client.get(f"/grid/tile/{grid_dataset}", headers=HEADERS) + + assert response.status_code == 200 + assert pl.read_ipc(response.read()).to_dict(as_series=False) == { + "cell": [ + 618668968382824400, + 619428375900454900, + 619428407452893200, + 619428407943888900, + 619428407676764200, + ], + "landcover": [1, 4, 3, 3, 4], + "population": [100, 200, 1, 900, 900], + } + + +def test_grid_tile_404(grid_dataset): + response = test_client.get("/grid/tile/8439181ffffffff", headers=HEADERS) + + assert response.status_code == 404 + + +def test_grid_tile_bad_index(grid_dataset): + response = test_client.get("/grid/tile/123", headers=HEADERS) + + assert response.status_code == 422 + assert response.json() == {"detail": "Tile index is not a valid H3 cell"} + + +def test_grid_metadata_fails_gracefully(grid_dataset): + res = test_client.get("/grid/meta", headers=HEADERS) + + assert res.status_code == 500 + assert res.json() == {"detail": "Metadata file is malformed. Please contact developer."} + + def test_table_filter_numerical_eq_to_sql(): tf = TableFilters.model_validate( { @@ -110,7 +151,7 @@ def test_table_filters_order_by_desc(): } ) query = tf.to_sql_query("table") - assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo > 10.0 ORDER BY baz DESC LIMIT 100' + assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo > 10.0 ORDER BY baz DESC NULLS LAST LIMIT 100' def test_table_filters_multiple_order_by(): @@ -123,7 +164,8 @@ def test_table_filters_multiple_order_by(): ) query = tf.to_sql_query("table") assert ( - query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo > 10.0 ORDER BY baz DESC, foo, bar DESC LIMIT 100' + query.replace("\n", "") + == 'SELECT * FROM "table" WHERE foo > 10.0 ORDER BY baz DESC NULLS LAST, foo, bar DESC NULLS LAST LIMIT 100' ) @@ -141,32 +183,25 @@ def test_table_filters_multiple_filters(): query = tf.to_sql_query("table") assert ( query.replace("\n", "") - == 'SELECT * FROM "table" WHERE foo = 10.0 AND bar IN (1, 2, 3) ORDER BY baz DESC, foo LIMIT 100' + == 'SELECT * FROM "table" WHERE foo = 10.0 AND bar IN (1, 2, 3) ORDER BY baz DESC NULLS LAST, foo LIMIT 100' ) -def test_h3grid(grid_dataset): - response = test_client.get(f"/grid/tile/{grid_dataset}", headers=HEADERS) +def test_grid_table(grid_dataset): + filters = [ + {"filter_type": "numerical", "column_name": "population", "operation": "lte", "value": 200}, + {"filter_type": "numerical", "column_name": "population", "operation": "gt", "value": 1}, + ] + response = test_client.post( + "/grid/table?level=4&order_by=-population", headers=HEADERS, content=json.dumps(filters) + ) assert response.status_code == 200 - assert response.read() == b"I am an arrow file!" - - -def test_h3grid_404(grid_dataset): - response = test_client.get("/grid/tile/8439181ffffffff", headers=HEADERS) - - assert response.status_code == 404 - - -def test_h3grid_bad_index(grid_dataset): - response = test_client.get("/grid/tile/123", headers=HEADERS) - - assert response.status_code == 422 - assert response.json() == {"detail": "Tile index is not a valid H3 cell"} - - -def test_h3grid_metadata_fails_gracefully(grid_dataset): - res = test_client.get("/grid/meta", headers=HEADERS) - - assert res.status_code == 500 - assert res.json() == {"detail": "Metadata file is malformed. Please contact developer."} + assert json.loads(response.read()) == { + "cell": [ + 619428375900454900, + 618668968382824400, + ], + "landcover": [4, 1], + "population": [200, 100], + }