diff --git a/api/Dockerfile b/api/Dockerfile index b95b4b07..df066a80 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -21,7 +21,7 @@ RUN pip install --no-cache-dir -r requirements.txt EXPOSE 8000 -CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--root-path", "/api/"] FROM production as development diff --git a/api/app/models/grid.py b/api/app/models/grid.py index af246846..8edc8621 100644 --- a/api/app/models/grid.py +++ b/api/app/models/grid.py @@ -50,7 +50,6 @@ class DatasetMeta(BaseModel): nodata: str | None = Field(default=None, description="Nodata value used in grid") description: str = Field(description="Human readable indicator description.") unit: str | None = Field(description="Unit of the measurement") - lineage: list[str] | None = Field(default=None, description="Source data used to compute this dataset.") legend: CategoricalLegend | NumericalLegend = Field(discriminator="legend_type") diff --git a/api/app/routers/grid.py b/api/app/routers/grid.py index b135e6ac..1bfd9928 100644 --- a/api/app/routers/grid.py +++ b/api/app/routers/grid.py @@ -36,6 +36,22 @@ class ArrowIPCResponse(Response): # noqa: D101 media_type = "application/octet-stream" +def colum_filter( + columns: list[str] = Query( + [], description="Column/s to include in the tile. If empty, it returns only cell indexes." + ), +): + return columns + + +def feature_filter(geojson: Annotated[Feature, Body(description="GeoJSON feature used to filter the cells.")]): + return geojson + + +ColumnDep = Annotated[list[str], Depends(colum_filter)] +FeatureDep = Annotated[Feature, Depends(feature_filter)] + + def get_tile( tile_index: Annotated[str, Path(description="The `h3` index of the tile")], columns: list[str] ) -> tuple[pl.LazyFrame, int]: @@ -85,9 +101,7 @@ def polars_to_string_ipc(df: pl.DataFrame) -> bytes: ) def grid_tile( tile_index: Annotated[str, Path(description="The `h3` index of the tile")], - columns: list[str] = Query( - [], description="Colum/s to include in the tile. If empty, it returns only cell indexes." - ), + columns: ColumnDep, ) -> ArrowIPCResponse: """Get a tile of h3 cells with specified data columns""" tile, _ = get_tile(tile_index, columns) @@ -107,11 +121,7 @@ def grid_tile( responses=tile_exception_responses, ) def grid_tile_in_area( - tile_index: Annotated[str, Path(description="The `h3` index of the tile")], - geojson: Annotated[Feature, Body(description="GeoJSON feature used to filter the cells.")], - columns: list[str] = Query( - [], description="Colum/s to include in the tile. If empty, it returns only cell indexes." - ), + tile_index: Annotated[str, Path(description="The `h3` index of the tile")], geojson: FeatureDep, columns: ColumnDep ) -> ArrowIPCResponse: """Get a tile of h3 cells that are inside the polygon""" tile, tile_index_res = get_tile(tile_index, columns) @@ -128,12 +138,9 @@ def grid_tile_in_area( return ArrowIPCResponse(polars_to_string_ipc(tile)) -@grid_router.get( - "/meta", - summary="Dataset metadata", -) -async def grid_dataset_metadata() -> MultiDatasetMeta: - """Get the grid dataset metadata""" +@lru_cache +def load_meta() -> MultiDatasetMeta: + """Load the metadata file and validate it""" file = os.path.join(get_settings().grid_tiles_path, "meta.json") with open(file) as f: raw = f.read() @@ -147,6 +154,53 @@ async def grid_dataset_metadata() -> MultiDatasetMeta: return meta +@grid_router.get( + "/meta", + summary="Dataset metadata", +) +async def grid_dataset_metadata() -> MultiDatasetMeta: + """Get the grid dataset metadata""" + return load_meta() + + +@grid_router.post( + "/meta", + summary="Dataset metadata for feature selection", +) +async def grid_dataset_metadata_in_area( + geojson: FeatureDep, + columns: ColumnDep, + level: Annotated[int, Query(..., description="Tile level at which the query will be computed")] = 1, +) -> MultiDatasetMeta: + """Get the grid dataset metadata with updated min and max for the area""" + meta = load_meta().model_dump() + + files_path = pathlib.Path(get_settings().grid_tiles_path) / str(level) + if not files_path.exists(): + raise HTTPException(404, detail=f"Level {level} does not exist") from None + + lf = pl.scan_ipc(list(files_path.glob("*.arrow"))) + if columns: + lf = lf.select(["cell", *columns]) + if geojson is not None: + cell_res = level + get_settings().tile_to_cell_resolution_diff + geom = shapely.from_geojson(geojson.model_dump_json()) + cells = cells_in_geojson(geom, cell_res) + lf = lf.join(cells, on="cell") + + maxs = lf.select(pl.selectors.numeric().max()).collect() + mins = lf.select(pl.selectors.numeric().min()).collect() + + for dataset in meta["datasets"]: + column = dataset["var_name"] + if dataset["legend"]["legend_type"] == "categorical": + continue + stats = dataset["legend"]["stats"][0] + stats["min"] = mins.select(pl.col(column)).item() + stats["max"] = maxs.select(pl.col(column)).item() + return MultiDatasetMeta.model_validate(meta) + + @grid_router.post("/table") def read_table( level: Annotated[int, Query(..., description="Tile level at which the query will be computed")], diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 8910dd20..46ed7f9c 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -9,6 +9,41 @@ from app.config.config import get_settings +METADATA = { + "datasets": [ + { + "var_name": "landcover", + "var_dtype": "Int32", + "label": "foo", + "description": "foo", + "unit": "", + "legend": { + "legend_type": "categorical", + "entries": [{"value": 1, "color": "#ffffff", "label": "all"}], + }, + }, + { + "var_name": "population", + "var_dtype": "Int32", + "label": "bar", + "description": "bar", + "unit": "count", + "legend": { + "legend_type": "continuous", + "colormap_name": "viridis", + "stats": [{"level": 1, "min": 1, "max": 900}], + }, + }, + ], + "h3_grid_info": [ + { + "level": 1, + "h3_cells_resolution": 6, + "h3_cells_count": 5, + } + ], +} + TEST_ROOT = Path(__file__).resolve().parent # Testing settings env variables @@ -76,7 +111,7 @@ def grid_dataset(setup_data_folder) -> str: } ) with open(grid_dataset_path / "meta.json", "w") as f: - f.write("{}") + f.write(json.dumps(METADATA)) with open(tile_path, "wb") as f: df.write_ipc(f) diff --git a/api/tests/test_grid.py b/api/tests/test_grid.py index bfda7dd1..034bb65b 100644 --- a/api/tests/test_grid.py +++ b/api/tests/test_grid.py @@ -3,7 +3,7 @@ import polars as pl from app.models.grid import TableFilters -from tests.conftest import HEADERS +from tests.conftest import HEADERS, METADATA from tests.utils import test_client @@ -60,11 +60,11 @@ def test_grid_tile_bad_index(grid_dataset): assert response.json() == {"detail": "Tile index is not a valid H3 cell"} -def test_grid_metadata_fails_gracefully(grid_dataset): +def test_grid_metadata(grid_dataset): res = test_client.get("/grid/meta", headers=HEADERS) - assert res.status_code == 500 - assert res.json() == {"detail": "Metadata file is malformed. Please contact developer."} + assert res.status_code == 200 + assert len(res.json()["datasets"]) == 2 def test_table_filter_numerical_eq_to_sql(): @@ -219,7 +219,7 @@ def test_grid_table(grid_dataset): } response = test_client.post("/grid/table?level=4&order_by=-population", headers=HEADERS, content=json.dumps(body)) - assert response.status_code == 200 + assert response.status_code == 200, response.content assert json.loads(response.read()) == { "table": [ {"column": "landcover", "values": [4, 1]}, @@ -284,3 +284,25 @@ def test_grid_tile_post_wrong_column(grid_dataset, geojson): assert response.status_code == 400 assert response.json() == {"detail": "One or more of the specified columns is not valid"} + + +def test_grid_metadata_filter(grid_dataset, geojson): + response = test_client.post( + "/grid/meta", + params={"level": 4}, + headers=HEADERS, + content=geojson, + ) + assert response.status_code == 200, response.content + + meta = response.json() + + assert len(meta["datasets"]) == 2 + population = [ds for ds in meta["datasets"] if ds["var_name"] == "population"][0] + + assert population["legend"]["stats"][0]["max"] == 100 + assert population["legend"]["stats"][0]["min"] == 100 + + + +