Skip to content

Commit

Permalink
Add column filter to tile endpoint (#60)
Browse files Browse the repository at this point in the history
* query builder and model for table filters json.

* add a couple of tests for the query builder

* add missing __init__ 🤦

* fix tests to new model

* add more test cases for the all the ops and changes order by to use '-' as desc indicator in column name

* tidy up names and tests

* remove duckdb from requirements

* Adds columns query param to grid tile endpoint

* fix column param default and tests

* tidyup api documentation

* set correct status code for incorrect columns in filters
  • Loading branch information
BielStela authored Jul 24, 2024
1 parent 0702788 commit e74b5c6
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 36 deletions.
33 changes: 19 additions & 14 deletions api/app/models/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ class CategoricalLegend(BaseModel):


class DatasetMeta(BaseModel):
var_name: str = Field(description="Column name")
var_dtype: str = Field(description="Column dtype. ")
var_name: str = Field(description="Column name.")
var_dtype: str = Field(description="Column dtype.")
nodata: str
description: str
aggregation_method: str = Field(description="Aggregation method used to compute the overview levels")
lineage: list[str] | None = Field(default=None, description="Source data used to compute this dataset")
aggregation_method: str = Field(description="Aggregation method used to compute the overview levels.")
lineage: list[str] | None = Field(default=None, description="Source data used to compute this dataset.")
legend: CategoricalLegend | NumericalLegend = Field(discriminator="legend_type")


Expand All @@ -60,8 +60,8 @@ class H3GridInfo(BaseModel):


class MultiDatasetMeta(BaseModel):
datasets: list[DatasetMeta] = Field(description="Variables represented in this dataset")
h3_grid_info: list[H3GridInfo] = Field(description="H3 related information")
datasets: list[DatasetMeta] = Field(description="Variables represented in this dataset.")
h3_grid_info: list[H3GridInfo] = Field(description="H3 related information.")


# ===============================================
Expand All @@ -85,22 +85,27 @@ class CategoricalOperators(str, Enum):

class CategoricalFilter(BaseModel):
filter_type: Literal["categorical"]
column_name: str = Field(description="Name of the column to which the filter will apply")
operation: CategoricalOperators = Field()
value: list[int] = Field(description="Value to compare with")
column_name: str = Field(description="Name of the column to which the filter will apply.")
operation: CategoricalOperators
value: list[int] = Field(description="Value to compare with.")


class NumericalFilter(BaseModel):
filter_type: Literal["numerical"]
column_name: str = Field(description="Name of the column to which the filter will apply")
operation: NumericalOperators = Field(description="Operation to use in compare")
value: float = Field(description="Value to compare with")
column_name: str = Field(description="Name of the column to which the filter will apply.")
operation: NumericalOperators = Field(description="Operation to use in compare.")
value: float = Field(description="Value to compare with.")


class TableFilters(BaseModel):
filters: list[Annotated[CategoricalFilter | NumericalFilter, Field(discriminator="filter_type")]]
limit: int = Field(10, lt=1000, description="Number of records")
order_by: Annotated[list[str], Field(Query(..., description="Prepend '-' to column name to make it descending"))]
limit: int = Field(Query(10, lt=1000, description="Number of records."))
order_by: Annotated[
list[str],
Field(
Query(..., description="Prepend '-' to column name to make it descending"),
),
]

def to_sql_query(self, table_name: str) -> str:
"""Compile model to sql query"""
Expand Down
45 changes: 25 additions & 20 deletions api/app/routers/grid.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import logging
import os
from pathlib import Path
import pathlib
from typing import Annotated

import h3
import polars as pl
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse, ORJSONResponse
from fastapi import APIRouter, Depends, HTTPException, Path, Query
from fastapi.responses import ORJSONResponse
from h3 import H3CellError
from pydantic import ValidationError
from starlette.responses import Response

from app.config.config import get_settings
from app.models.grid import MultiDatasetMeta, TableFilters
Expand All @@ -20,31 +21,35 @@

@grid_router.get(
"/tile/{tile_index}",
responses={200: {"description": "Get a grid tile"}, 404: {"description": "Not found"}},
response_model=None,
summary="Get a grid tile",
)
async def grid_tile(tile_index: str) -> FileResponse:
"""Request a tile of h3 cells
:raises HTTPException 404: Item not found
:raises HTTPException 422: H3 index is not valid
"""
async def grid_tile(
tile_index: Annotated[str, Path(description="The `h3` index of the tile")],
columns: list[str] = Query(
[], description="Colum/s to include in the tile. If empty, it returns only cell indexes."
),
) -> Response:
"""Get a tile of h3 cells with specified data columns"""
try:
z = h3.api.basic_str.h3_get_resolution(tile_index)
except H3CellError:
raise HTTPException(status_code=422, detail="Tile index is not a valid H3 cell") from None

tile_file = os.path.join(get_settings().grid_tiles_path, f"{z}/{tile_index}.arrow")
if not os.path.exists(tile_file):
raise HTTPException(status_code=404, detail=f"Tile {tile_file} not found")
return FileResponse(tile_file, media_type="application/octet-stream")
raise HTTPException(status_code=400, detail="Tile index is not a valid H3 cell") from None
tile_path = os.path.join(get_settings().grid_tiles_path, f"{z}/{tile_index}.arrow")
if not os.path.exists(tile_path):
raise HTTPException(status_code=404, detail=f"Tile {tile_path} not found")
try:
tile_file = pl.read_ipc(tile_path, columns=["cell", *columns]).write_ipc(None)
except pl.exceptions.ColumnNotFoundError:
raise HTTPException(status_code=400, detail="One or more of the specified columns is not valid") from None
return Response(tile_file.getvalue(), media_type="application/octet-stream")


@grid_router.get(
"/meta",
summary="Dataset metadata",
)
async def grid_dataset_metadata() -> MultiDatasetMeta:
"""Dataset metadata"""
"""Get the grid dataset metadata"""
file = os.path.join(get_settings().grid_tiles_path, "meta.json")
with open(file) as f:
raw = f.read()
Expand All @@ -64,7 +69,7 @@ def read_table(
filters: TableFilters = Depends(),
) -> ORJSONResponse:
"""Query tile dataset and return table data"""
files_path = Path(get_settings().grid_tiles_path) / str(level)
files_path = pathlib.Path(get_settings().grid_tiles_path) / str(level)
if not files_path.exists():
raise HTTPException(404, detail=f"Level {level} does not exist") from None
lf = pl.scan_ipc(files_path.glob("*.arrow"))
Expand All @@ -75,7 +80,7 @@ def read_table(
except pl.exceptions.ColumnNotFoundError as e:
# bad column in order by clause
log.exception(e)
raise HTTPException(status_code=404, detail=f"Column '{e}' not found in dataset") from None
raise HTTPException(status_code=400, detail="One or more of the specified columns is not valid") from None

except pl.exceptions.ComputeError as e:
# possibly raise if wrong type in compare. I'm not aware of other sources of ComputeError
Expand Down
27 changes: 25 additions & 2 deletions api/tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@


def test_grid_tile(grid_dataset):
response = test_client.get(f"/grid/tile/{grid_dataset}", headers=HEADERS)
response = test_client.get(
f"/grid/tile/{grid_dataset}", params={"columns": ["landcover", "population"]}, headers=HEADERS
)

assert response.status_code == 200
assert pl.read_ipc(response.read()).to_dict(as_series=False) == {
Expand All @@ -24,6 +26,27 @@ def test_grid_tile(grid_dataset):
}


def test_grid_tile_empty_column_param(grid_dataset):
response = test_client.get(f"/grid/tile/{grid_dataset}", headers=HEADERS)

assert response.status_code == 200
assert pl.read_ipc(response.read()).to_dict(as_series=False) == {
"cell": [
618668968382824400,
619428375900454900,
619428407452893200,
619428407943888900,
619428407676764200,
],
}


def test_grid_tile_wrong_column(grid_dataset):
response = test_client.get(f"/grid/tile/{grid_dataset}", params={"columns": ["NOEXIST"]}, headers=HEADERS)
assert response.status_code == 400
assert response.json() == {"detail": "One or more of the specified columns is not valid"}


def test_grid_tile_404(grid_dataset):
response = test_client.get("/grid/tile/8439181ffffffff", headers=HEADERS)

Expand All @@ -33,7 +56,7 @@ def test_grid_tile_404(grid_dataset):
def test_grid_tile_bad_index(grid_dataset):
response = test_client.get("/grid/tile/123", headers=HEADERS)

assert response.status_code == 422
assert response.status_code == 400
assert response.json() == {"detail": "Tile index is not a valid H3 cell"}


Expand Down

0 comments on commit e74b5c6

Please sign in to comment.