Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
BielStela committed Jul 10, 2024
1 parent 23c2b37 commit e3b2867
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 30 deletions.
10 changes: 7 additions & 3 deletions api/app/models/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi import Query
from pydantic import BaseModel, ConfigDict, Field
from pydantic_extra_types.color import Color
from sqlalchemy.sql import column, desc, select, table
from sqlalchemy.sql import column, desc, nullslast, select, table


class LegendTypes(str, Enum):
Expand Down Expand Up @@ -98,7 +98,7 @@ class NumericalFilter(BaseModel):


class TableFilters(BaseModel):
filters: list[Annotated[CategoricalFilter | NumericalFilter, Field(..., discriminator="filter_type")]]
filters: list[Annotated[CategoricalFilter | NumericalFilter, Field(discriminator="filter_type")]]
limit: int = Field(10, lt=1000, description="Number of records")
order_by: Annotated[list[str], Field(Query(..., description="Prepend '-' to column name to make it descending"))]

Expand All @@ -115,6 +115,8 @@ def to_sql_query(self, table_name: str) -> str:
}
filters_to_apply = []
for _filter in self.filters:
if _filter is None:
continue
col = column(_filter.column_name)
param = getattr(col, op_to_python_dunder.get(_filter.operation, _filter.operation))(_filter.value)
filters_to_apply.append(param)
Expand All @@ -123,6 +125,8 @@ def to_sql_query(self, table_name: str) -> str:
.select_from(table(table_name))
.where(*filters_to_apply)
.limit(self.limit)
.order_by(*[desc(column(col[1:])) if col.startswith("-") else column(col) for col in self.order_by])
.order_by(
*[nullslast(desc(column(col[1:]))) if col.startswith("-") else column(col) for col in self.order_by]
)
)
return str(query.compile(compile_kwargs={"literal_binds": True}))
28 changes: 28 additions & 0 deletions api/tests/benchmark_post.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-- example HTTP POST script which demonstrates setting the
-- HTTP method, body, and adding a header
-- command:
-- wrk -c 100 -t 10 -d 10s -s benchmark_post.lua 'http://localhost:8000/grid/table?level=4&limit=10&order_by=-population'


wrk.method = "POST"
wrk.body = [[
[
{
"filter_type": "categorical",
"column_name": "fire",
"operation": "in",
"value": [
1,2,3
]
},
{
"filter_type": "numerical",
"column_name": "population",
"operation": "gt",
"value": 10000
}
]
]]
wrk.headers["Content-Type"] = "application/json"
wrk.headers["accept"] = "application/json"
wrk.headers["Authorization"] = "Bearer 1234"
16 changes: 15 additions & 1 deletion api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path

import numpy as np
import polars as pl
import pytest
import rasterio

Expand Down Expand Up @@ -29,11 +30,24 @@ def grid_dataset(setup_data_folder) -> str:
level_path.mkdir(parents=True)
tile_path = level_path / f"{h3_index}.arrow"

df = pl.DataFrame(
{
"cell": [
618668968382824400,
619428375900454900,
619428407452893200,
619428407943888900,
619428407676764200,
],
"landcover": [1, 4, 3, 3, 4],
"population": [100, 200, 1, 900, 900],
}
)
with open(grid_dataset_path / "meta.json", "w") as f:
f.write("{}")

with open(tile_path, "wb") as f:
f.write(b"I am an arrow file!")
df.write_ipc(f)

yield h3_index

Expand Down
87 changes: 61 additions & 26 deletions api/tests/test_grid.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,49 @@
import json

import polars as pl

from app.models.grid import TableFilters
from tests.conftest import HEADERS
from tests.utils import test_client


def test_grid_tile(grid_dataset):
response = test_client.get(f"/grid/tile/{grid_dataset}", headers=HEADERS)

assert response.status_code == 200
assert pl.read_ipc(response.read()).to_dict(as_series=False) == {
"cell": [
618668968382824400,
619428375900454900,
619428407452893200,
619428407943888900,
619428407676764200,
],
"landcover": [1, 4, 3, 3, 4],
"population": [100, 200, 1, 900, 900],
}


def test_grid_tile_404(grid_dataset):
response = test_client.get("/grid/tile/8439181ffffffff", headers=HEADERS)

assert response.status_code == 404


def test_grid_tile_bad_index(grid_dataset):
response = test_client.get("/grid/tile/123", headers=HEADERS)

assert response.status_code == 422
assert response.json() == {"detail": "Tile index is not a valid H3 cell"}


def test_grid_metadata_fails_gracefully(grid_dataset):
res = test_client.get("/grid/meta", headers=HEADERS)

assert res.status_code == 500
assert res.json() == {"detail": "Metadata file is malformed. Please contact developer."}


def test_table_filter_numerical_eq_to_sql():
tf = TableFilters.model_validate(
{
Expand Down Expand Up @@ -110,7 +151,7 @@ def test_table_filters_order_by_desc():
}
)
query = tf.to_sql_query("table")
assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo > 10.0 ORDER BY baz DESC LIMIT 100'
assert query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo > 10.0 ORDER BY baz DESC NULLS LAST LIMIT 100'


def test_table_filters_multiple_order_by():
Expand All @@ -123,7 +164,8 @@ def test_table_filters_multiple_order_by():
)
query = tf.to_sql_query("table")
assert (
query.replace("\n", "") == 'SELECT * FROM "table" WHERE foo > 10.0 ORDER BY baz DESC, foo, bar DESC LIMIT 100'
query.replace("\n", "")
== 'SELECT * FROM "table" WHERE foo > 10.0 ORDER BY baz DESC NULLS LAST, foo, bar DESC NULLS LAST LIMIT 100'
)


Expand All @@ -141,32 +183,25 @@ def test_table_filters_multiple_filters():
query = tf.to_sql_query("table")
assert (
query.replace("\n", "")
== 'SELECT * FROM "table" WHERE foo = 10.0 AND bar IN (1, 2, 3) ORDER BY baz DESC, foo LIMIT 100'
== 'SELECT * FROM "table" WHERE foo = 10.0 AND bar IN (1, 2, 3) ORDER BY baz DESC NULLS LAST, foo LIMIT 100'
)


def test_h3grid(grid_dataset):
response = test_client.get(f"/grid/tile/{grid_dataset}", headers=HEADERS)
def test_grid_table(grid_dataset):
filters = [
{"filter_type": "numerical", "column_name": "population", "operation": "lte", "value": 200},
{"filter_type": "numerical", "column_name": "population", "operation": "gt", "value": 1},
]

response = test_client.post(
"/grid/table?level=4&order_by=-population", headers=HEADERS, content=json.dumps(filters)
)
assert response.status_code == 200
assert response.read() == b"I am an arrow file!"


def test_h3grid_404(grid_dataset):
response = test_client.get("/grid/tile/8439181ffffffff", headers=HEADERS)

assert response.status_code == 404


def test_h3grid_bad_index(grid_dataset):
response = test_client.get("/grid/tile/123", headers=HEADERS)

assert response.status_code == 422
assert response.json() == {"detail": "Tile index is not a valid H3 cell"}


def test_h3grid_metadata_fails_gracefully(grid_dataset):
res = test_client.get("/grid/meta", headers=HEADERS)

assert res.status_code == 500
assert res.json() == {"detail": "Metadata file is malformed. Please contact developer."}
assert json.loads(response.read()) == {
"cell": [
619428375900454900,
618668968382824400,
],
"landcover": [4, 1],
"population": [200, 100],
}

0 comments on commit e3b2867

Please sign in to comment.