From 2c35a8c0ff3fd195b456c1f49f7567256a341d76 Mon Sep 17 00:00:00 2001
From: Biel Stela
Date: Thu, 26 Sep 2024 11:31:24 +0200
Subject: [PATCH] Filter cells by region and use hex representation of h3 cell
 id (#66)

* Update dataset to use hexadecimal representation of cell id
* Add endpoint to get tiles filtered by geojson
* Speed up requests by caching the geojson cell-filling result and improve queries
* Remove all outputs from notebooks
* Add nbstripout to pre-commit
* Update deps and filter table by geojson
* Add test for table results with geojson filter
* Tidy up the table response model
* Improve response documentation in OpenAPI
* Update compose call to use the compose subcommand in the CI runner
* Small comment
* Improve the benchmark by using multiple geojsons so we can check caching
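
Example request against the new tile-by-region endpoint (a sketch for
reviewers; the host, token, tile index and column are the placeholders used
in the benchmark script, and httpx is already in the API requirements):

    import httpx
    import polars as pl

    feature = {
        "type": "Feature",
        "properties": {},
        "geometry": {
            "type": "Polygon",
            "coordinates": [[[-61.11, 8.66], [-61.11, 8.50], [-60.86, 8.50],
                             [-60.86, 8.66], [-61.11, 8.66]]],
        },
    }
    # POST /grid/tile/{tile_index}: same path as the GET endpoint, but the
    # GeoJSON body restricts the returned cells to the feature's area.
    resp = httpx.post(
        "http://localhost:8000/grid/tile/815f7ffffffffff",
        params={"columns": ["AMIN"]},
        headers={"Authorization": "Bearer <token>"},
        json=feature,
    )
    resp.raise_for_status()
    df = pl.read_ipc(resp.content)  # response body is an Arrow IPC table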
---
 .github/workflows/cicd.yml                    |   8 +-
 .pre-commit-config.yaml                       |   7 +-
 api/app/config/config.py                      |   1 +
 api/app/models/grid.py                        |   9 ++
 api/app/routers/grid.py                       | 119 +++++++++++---
 api/requirements.in                           |   1 +
 api/requirements.txt                          |  84 ++++++----
 api/tests/benchmark_grid_post.lua             | 147 ++++++++++++++++++
 ...mark_post.lua => benchmark_table_post.lua} |   0
 api/tests/conftest.py                         |  35 ++++-
 api/tests/test_grid.py                        | 100 +++++++++---
 api/tests/test_zonal_stats.py                 |   1 -
 science/notebooks/check_combine_results.ipynb |  36 ++---
 science/notebooks/merge_entrega_roberto.ipynb |  19 +--
 14 files changed, 449 insertions(+), 118 deletions(-)
 create mode 100644 api/tests/benchmark_grid_post.lua
 rename api/tests/{benchmark_post.lua => benchmark_table_post.lua} (100%)

diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml
index d48f841a..fd3bfd10 100644
--- a/.github/workflows/cicd.yml
+++ b/.github/workflows/cicd.yml
@@ -20,10 +20,10 @@ jobs:
         uses: docker/setup-buildx-action@v1
 
       - name: Build and run tests
-        run: docker-compose up --build --exit-code-from test test
+        run: docker compose up --build --exit-code-from test test
 
       - name: Clean up
-        run: docker-compose down
+        run: docker compose down
 
   deploy:
     name: Deploy
@@ -40,8 +40,8 @@ jobs:
           script: |
             cd amazonia-360
             git pull --rebase
-            sudo docker-compose down
-            sudo docker-compose up -d api --build
+            sudo docker compose down
+            sudo docker compose up -d api --build
 
   health-check:
     name: Health Check
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9d983358..06290094 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,8 +16,7 @@ repos:
       - id: ruff-format
         types_or: [ python, pyi, jupyter ]
 
-  # check for private keys and passwords!
-  - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.17.0
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.7.1
     hooks:
-      - id: gitleaks-docker
+      - id: nbstripout
diff --git a/api/app/config/config.py b/api/app/config/config.py
index a8882021..c57f6501 100644
--- a/api/app/config/config.py
+++ b/api/app/config/config.py
@@ -11,6 +11,7 @@ class Settings(BaseSettings):
     auth_token: str
     tiff_path: str
     grid_tiles_path: str
+    tile_to_cell_resolution_diff: int = 5
 
 
 @lru_cache
diff --git a/api/app/models/grid.py b/api/app/models/grid.py
index f1ea4b3f..17ce102d 100644
--- a/api/app/models/grid.py
+++ b/api/app/models/grid.py
@@ -136,3 +136,12 @@ def to_sql_query(self, table_name: str) -> str:
             )
         )
         return str(query.compile(compile_kwargs={"literal_binds": True}))
+
+
+class TableResultColumn(BaseModel):
+    column: Annotated[str, Field(title="column", description="Column name")]
+    values: Annotated[list, Field(description="Column values. Check the dataset metadata for type info")]
+
+
+class TableResults(BaseModel):
+    table: list[TableResultColumn]
diff --git a/api/app/routers/grid.py b/api/app/routers/grid.py
index 5394220a..0d1f1d13 100644
--- a/api/app/routers/grid.py
+++ b/api/app/routers/grid.py
@@ -1,47 +1,114 @@
 import logging
 import os
 import pathlib
+from functools import lru_cache
 from typing import Annotated
 
 import h3
+import h3ronpy.polars  # noqa: F401
 import polars as pl
+import shapely
 from fastapi import APIRouter, Depends, HTTPException, Path, Query
-from fastapi.responses import ORJSONResponse
+from fastapi.params import Body
+from fastapi.responses import Response
+from geojson_pydantic import Feature
 from h3 import H3CellError
+from h3ronpy.polars import cells_to_string
+from h3ronpy.polars.vector import geometry_to_cells
 from pydantic import ValidationError
-from starlette.responses import Response
 
 from app.config.config import get_settings
-from app.models.grid import MultiDatasetMeta, TableFilters
+from app.models.grid import MultiDatasetMeta, TableFilters, TableResults
 
-log = logging.getLogger("uvicorn.error")
+log = logging.getLogger("uvicorn.error")  # Show the logs in the uvicorn runner logs
 
 grid_router = APIRouter()
 
+tile_exception_responses = {
+    400: {"description": "Column does not exist or tile_index is not a valid h3 index."},
+    404: {"description": "Tile does not exist or is empty"},
+}
+
+
+class ArrowIPCResponse(Response):  # noqa: D101
+    media_type = "application/octet-stream"
+
+
+def get_tile(tile_index: str, columns: list[str]) -> tuple[pl.LazyFrame, int]:
+    """Return the tile read from the filesystem, filtered to the requested columns,
+    together with the resolution of the tile index."""
+    try:
+        z = h3.api.basic_str.h3_get_resolution(tile_index)
+    except (H3CellError, ValueError):
+        raise HTTPException(status_code=400, detail="Tile index is not a valid H3 cell") from None
+    tile_path = os.path.join(get_settings().grid_tiles_path, f"{z}/{tile_index}.arrow")
+    if not os.path.exists(tile_path):
+        raise HTTPException(status_code=404, detail=f"Tile {tile_path} not found")
+    tile = pl.scan_ipc(tile_path).select(["cell", *columns])
+    return tile, z
+
+
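+# Filling a region with cells is the expensive part of a filtered request, so
+# the result is memoized: repeated requests with the same geometry and
+# resolution hit the lru_cache instead of recomputing the cells.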
+ """ + cells = cells_to_string(geometry_to_cells(geometry, cell_resolution)) + return pl.LazyFrame({"cell": cells}) + @grid_router.get( "/tile/{tile_index}", summary="Get a grid tile", + response_class=ArrowIPCResponse, + response_description="Arrow IPC table", + responses=tile_exception_responses, ) -async def grid_tile( +def grid_tile( tile_index: Annotated[str, Path(description="The `h3` index of the tile")], columns: list[str] = Query( [], description="Colum/s to include in the tile. If empty, it returns only cell indexes." ), -) -> Response: +) -> ArrowIPCResponse: """Get a tile of h3 cells with specified data columns""" + tile, _ = get_tile(tile_index, columns) try: - z = h3.api.basic_str.h3_get_resolution(tile_index) - except H3CellError: - raise HTTPException(status_code=400, detail="Tile index is not a valid H3 cell") from None - tile_path = os.path.join(get_settings().grid_tiles_path, f"{z}/{tile_index}.arrow") - if not os.path.exists(tile_path): - raise HTTPException(status_code=404, detail=f"Tile {tile_path} not found") + tile_buffer = tile.collect().write_ipc(None) + # we don't know if the column requested are correct until we call .collect() + except pl.exceptions.ColumnNotFoundError: + raise HTTPException(status_code=400, detail="One or more of the specified columns is not valid") from None + return ArrowIPCResponse(tile_buffer.getvalue()) + + +@grid_router.post( + "/tile/{tile_index}", + summary="Get a grid tile with cells contained inside the GeoJSON", + response_class=ArrowIPCResponse, + response_description="Arrow IPC table", + responses=tile_exception_responses, +) +def grid_tile_in_area( + tile_index: Annotated[str, Path(description="The `h3` index of the tile")], + geojson: Annotated[Feature, Body(description="GeoJSON feature used to filter the cells.")], + columns: list[str] = Query( + [], description="Colum/s to include in the tile. If empty, it returns only cell indexes." 
+ ), +) -> ArrowIPCResponse: + """Get a tile of h3 cells that are inside the polygon""" + tile, tile_index_res = get_tile(tile_index, columns) + cell_res = tile_index_res + get_settings().tile_to_cell_resolution_diff + geom = shapely.from_geojson(geojson.model_dump_json()) + cells = cells_in_geojson(geom, cell_res) try: - tile_file = pl.read_ipc(tile_path, columns=["cell", *columns]).write_ipc(None) + tile = tile.join(cells, on="cell").collect() + # we don't know if the column requested are correct until we call .collect() except pl.exceptions.ColumnNotFoundError: raise HTTPException(status_code=400, detail="One or more of the specified columns is not valid") from None - return Response(tile_file.getvalue(), media_type="application/octet-stream") + if tile.is_empty(): + raise HTTPException(status_code=404, detail="No data in region") + tile_buffer = tile.write_ipc(None) + return ArrowIPCResponse(tile_buffer.getvalue()) @grid_router.get( @@ -67,23 +134,31 @@ async def grid_dataset_metadata() -> MultiDatasetMeta: def read_table( level: Annotated[int, Query(..., description="Tile level at which the query will be computed")], filters: TableFilters = Depends(), -) -> ORJSONResponse: + geojson: Feature | None = None, +) -> TableResults: """Query tile dataset and return table data""" files_path = pathlib.Path(get_settings().grid_tiles_path) / str(level) if not files_path.exists(): raise HTTPException(404, detail=f"Level {level} does not exist") from None - lf = pl.scan_ipc(files_path.glob("*.arrow")) + + lf = pl.scan_ipc(list(files_path.glob("*.arrow"))) + + if geojson is not None: + cell_res = level + get_settings().tile_to_cell_resolution_diff + geom = shapely.from_geojson(geojson.model_dump_json()) + cells = cells_in_geojson(geom, cell_res) + lf = lf.join(cells, on="cell") + query = filters.to_sql_query("frame") log.debug(query) + try: res = pl.SQLContext(frame=lf).execute(query).collect() - except pl.exceptions.ColumnNotFoundError as e: - # bad column in order by clause + except pl.exceptions.ColumnNotFoundError as e: # bad column in order by clause log.exception(e) raise HTTPException(status_code=400, detail="One or more of the specified columns is not valid") from None - - except pl.exceptions.ComputeError as e: - # possibly raise if wrong type in compare. I'm not aware of other sources of ComputeError + except pl.exceptions.ComputeError as e: # raised if wrong type in compare. 
log.exception(e) raise HTTPException(status_code=422, detail=str(e)) from None - return ORJSONResponse(res.to_dict(as_series=False)) + + return TableResults(table=[{"column": k, "values": v} for k, v in res.to_dict(as_series=False).items()]) diff --git a/api/requirements.in b/api/requirements.in index a2bf54c0..2d6a81c2 100644 --- a/api/requirements.in +++ b/api/requirements.in @@ -8,3 +8,4 @@ h3 pydantic-extra-types polars sqlalchemy +h3ronpy \ No newline at end of file diff --git a/api/requirements.txt b/api/requirements.txt index 50c7a96b..471937ae 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -2,20 +2,20 @@ # uv pip compile requirements.in -o requirements.txt affine==2.4.0 # via rasterio -annotated-types==0.6.0 +annotated-types==0.7.0 # via pydantic -anyio==4.3.0 +anyio==4.4.0 # via # httpx # starlette -attrs==23.2.0 +attrs==24.2.0 # via # morecantile # rasterio # rio-tiler -cachetools==5.3.3 +cachetools==5.5.0 # via rio-tiler -certifi==2024.2.2 +certifi==2024.8.30 # via # httpcore # httpx @@ -31,49 +31,62 @@ click-plugins==1.1.1 # via rasterio cligj==0.7.2 # via rasterio -color-operations==0.1.3 +color-operations==0.1.5 # via rio-tiler -exactextract==0.2.0.dev0 -fastapi==0.110.1 - # via titiler-core -geojson-pydantic==1.0.2 +exactextract==0.2.0 + # via -r requirements.in +fastapi==0.114.1 + # via + # -r requirements.in + # titiler-core +geojson-pydantic==1.1.1 # via titiler-core -greenlet==3.0.3 +greenlet==3.1.0 # via sqlalchemy h11==0.14.0 # via # httpcore # uvicorn h3==3.7.7 + # via -r requirements.in +h3ronpy==0.21.0 + # via -r requirements.in httpcore==1.0.5 # via httpx -httpx==0.27.0 +httpx==0.27.2 # via rio-tiler -idna==3.6 +idna==3.8 # via # anyio # httpx -jinja2==3.1.3 +jinja2==3.1.4 # via titiler-core markupsafe==2.1.5 # via jinja2 -morecantile==5.3.0 +morecantile==5.4.2 # via # rio-tiler # titiler-core -numexpr==2.10.0 +numexpr==2.10.1 # via rio-tiler numpy==1.26.4 # via # color-operations + # h3ronpy # numexpr + # pyarrow # rasterio # rio-tiler + # shapely # snuggs # titiler-core -orjson==3.10.0 -polars==1.1.0 -pydantic==2.6.4 +orjson==3.10.7 + # via -r requirements.in +polars==1.7.0 + # via -r requirements.in +pyarrow==17.0.0 + # via h3ronpy +pydantic==2.9.1 # via # fastapi # geojson-pydantic @@ -82,29 +95,33 @@ pydantic==2.6.4 # pydantic-settings # rio-tiler # titiler-core -pydantic-core==2.16.3 +pydantic-core==2.23.3 # via pydantic pydantic-extra-types==2.9.0 -pydantic-settings==2.2.1 -pyparsing==3.1.2 + # via -r requirements.in +pydantic-settings==2.5.2 + # via -r requirements.in +pyparsing==3.1.4 # via snuggs pyproj==3.6.1 # via morecantile -pystac==1.10.0 +pystac==1.10.1 # via rio-tiler python-dateutil==2.9.0.post0 # via pystac python-dotenv==1.0.1 # via pydantic-settings -rasterio==1.3.9 +rasterio==1.3.11 # via # rio-tiler # titiler-core -rio-tiler==6.4.5 +rio-tiler==6.7.0 # via titiler-core -setuptools==69.2.0 +setuptools==74.1.2 # via rasterio -simplejson==3.19.2 +shapely==2.0.6 + # via h3ronpy +simplejson==3.19.3 # via titiler-core six==1.16.0 # via python-dateutil @@ -114,15 +131,18 @@ sniffio==1.3.1 # httpx snuggs==1.4.7 # via rasterio -sqlalchemy==2.0.31 -starlette==0.37.2 +sqlalchemy==2.0.34 + # via -r requirements.in +starlette==0.38.5 # via fastapi -titiler-core==0.18.0 -typing-extensions==4.11.0 +titiler-core==0.18.6 + # via -r requirements.in +typing-extensions==4.12.2 # via # fastapi # pydantic # pydantic-core # sqlalchemy # titiler-core -uvicorn==0.29.0 +uvicorn==0.30.6 + # via -r requirements.in diff --git 
a/api/tests/benchmark_grid_post.lua b/api/tests/benchmark_grid_post.lua new file mode 100644 index 00000000..ea2083cf --- /dev/null +++ b/api/tests/benchmark_grid_post.lua @@ -0,0 +1,147 @@ +-- command: +-- wrk -c 100 -t 10 -d 10s -s benchmark_grid_post.lua 'http://localhost:8000/grid/tile/815f7ffffffffff?columns=AMIN' + +local geojsons = { + [[ + { + "type": "Feature", + "properties": {}, + "geometry": { + "coordinates": [ + [ + [ + -61.113268179996055, + 8.666717320892204 + ], + [ + -61.113268179996055, + 8.505177617822142 + ], + [ + -60.86538798013957, + 8.505177617822142 + ], + [ + -60.86538798013957, + 8.666717320892204 + ], + [ + -61.113268179996055, + 8.666717320892204 + ] + ] + ], + "type": "Polygon" + } + } + ]], + [[ + { + "type": "Feature", + "properties": {}, + "geometry": { + "coordinates": [ + [ + [ + -66.98965634041855, + -2.552105344245007 + ], + [ + -66.98965634041855, + -6.931424712822178 + ], + [ + -60.673596725229004, + -6.931424712822178 + ], + [ + -60.673596725229004, + -2.552105344245007 + ], + [ + -66.98965634041855, + -2.552105344245007 + ] + ] + ], + "type": "Polygon" + } + } + ]], + [[ + { + "type": "Feature", + "properties": {}, + "geometry": { + "coordinates": [ + [ + [ + -59.40141593993765, + -0.8180702598489091 + ], + [ + -59.40141593993765, + -3.8038880006152453 + ], + [ + -56.08276971246181, + -3.8038880006152453 + ], + [ + -56.08276971246181, + -0.8180702598489091 + ], + [ + -59.40141593993765, + -0.8180702598489091 + ] + ] + ], + "type": "Polygon" + } + } + ]], + [[ + { + "type": "Feature", + "properties": {}, + "geometry": { + "coordinates": [ + [ + [ + -68.36016539573357, + -3.4797077655746023 + ], + [ + -68.36016539573357, + -10.328634044400019 + ], + [ + -60.34168576692953, + -10.328634044400019 + ], + [ + -60.34168576692953, + -3.4797077655746023 + ], + [ + -68.36016539573357, + -3.4797077655746023 + ] + ] + ], + "type": "Polygon" + } + } + ]] +} + + +request = function() + wrk.method = "POST" + wrk.body = geojsons[math.random(1, #geojsons)] + wrk.headers["Content-Type"] = "application/json" + wrk.headers["accept"] = "application/json" + wrk.headers["Authorization"] = "Bearer 1234" + return wrk.format() +end diff --git a/api/tests/benchmark_post.lua b/api/tests/benchmark_table_post.lua similarity index 100% rename from api/tests/benchmark_post.lua rename to api/tests/benchmark_table_post.lua diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 2fce5818..8910dd20 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -1,3 +1,4 @@ +import json import os from pathlib import Path @@ -19,6 +20,30 @@ HEADERS = {"Authorization": f"Bearer {get_settings().auth_token}"} +@pytest.fixture() +def geojson() -> str: + """This geojson contains the cell 895f4261e03ffff in `grid_dataset`""" + s = json.dumps( + { + "type": "Feature", + "properties": {}, + "geometry": { + "coordinates": [ + [ + [-61.11, 8.66], + [-61.11, 8.50], + [-60.86, 8.50], + [-60.86, 8.66], + [-61.11, 8.66], + ] + ], + "type": "Polygon", + }, + } + ) + return s + + @pytest.fixture() def grid_dataset(setup_data_folder) -> str: """Create an empty binary file to be used as grid dataset stub @@ -40,11 +65,11 @@ def grid_dataset(setup_data_folder) -> str: df = pl.DataFrame( { "cell": [ - 618668968382824400, - 619428375900454900, - 619428407452893200, - 619428407943888900, - 619428407676764200, + "895f4261e03ffff", + "865f00007ffffff", + "865f0000fffffff", + "865f00017ffffff", + "865f0001fffffff", ], "landcover": [1, 4, 3, 3, 4], "population": [100, 200, 1, 900, 900], diff --git 
a/api/tests/test_grid.py b/api/tests/test_grid.py index b00f7337..b75c254c 100644 --- a/api/tests/test_grid.py +++ b/api/tests/test_grid.py @@ -15,11 +15,11 @@ def test_grid_tile(grid_dataset): assert response.status_code == 200 assert pl.read_ipc(response.read()).to_dict(as_series=False) == { "cell": [ - 618668968382824400, - 619428375900454900, - 619428407452893200, - 619428407943888900, - 619428407676764200, + "895f4261e03ffff", + "865f00007ffffff", + "865f0000fffffff", + "865f00017ffffff", + "865f0001fffffff", ], "landcover": [1, 4, 3, 3, 4], "population": [100, 200, 1, 900, 900], @@ -32,11 +32,11 @@ def test_grid_tile_empty_column_param(grid_dataset): assert response.status_code == 200 assert pl.read_ipc(response.read()).to_dict(as_series=False) == { "cell": [ - 618668968382824400, - 619428375900454900, - 619428407452893200, - 619428407943888900, - 619428407676764200, + "895f4261e03ffff", + "865f00007ffffff", + "865f0000fffffff", + "865f00017ffffff", + "865f0001fffffff", ], } @@ -211,20 +211,76 @@ def test_table_filters_multiple_filters(): def test_grid_table(grid_dataset): - filters = [ - {"filter_type": "numerical", "column_name": "population", "operation": "lte", "value": 200}, - {"filter_type": "numerical", "column_name": "population", "operation": "gt", "value": 1}, - ] + body = { + "filters": [ + {"filter_type": "numerical", "column_name": "population", "operation": "lte", "value": 200}, + {"filter_type": "numerical", "column_name": "population", "operation": "gt", "value": 1}, + ] + } - response = test_client.post( - "/grid/table?level=4&order_by=-population", headers=HEADERS, content=json.dumps(filters) - ) + response = test_client.post("/grid/table?level=4&order_by=-population", headers=HEADERS, content=json.dumps(body)) assert response.status_code == 200 assert json.loads(response.read()) == { - "cell": [ - 619428375900454900, - 618668968382824400, + "table": [ + {"column": "cell", "values": ["865f00007ffffff", "895f4261e03ffff"]}, + {"column": "landcover", "values": [4, 1]}, + {"column": "population", "values": [200, 100]}, + ] + } + + +def test_grid_table_geojson(grid_dataset, geojson): + body = { + "filters": [ + {"filter_type": "numerical", "column_name": "population", "operation": "lte", "value": 200}, + {"filter_type": "numerical", "column_name": "population", "operation": "gt", "value": 1}, ], - "landcover": [4, 1], - "population": [200, 100], + "geojson": json.loads(geojson), + } + response = test_client.post("/grid/table?level=4&order_by=-population", headers=HEADERS, content=json.dumps(body)) + assert response.status_code == 200 + assert json.loads(response.read()) == { + "table": [ + {"column": "cell", "values": ["895f4261e03ffff"]}, + {"column": "landcover", "values": [1]}, + {"column": "population", "values": [100]}, + ] + } + + +def test_grid_tile_post_geojson(grid_dataset, geojson): + response = test_client.post( + f"/grid/tile/{grid_dataset}", + params={"columns": ["landcover", "population"]}, + headers=HEADERS, + content=geojson, + ) + assert response.status_code == 200 + assert pl.read_ipc(response.read()).to_dict(as_series=False) == { + "cell": ["895f4261e03ffff"], + "landcover": [1], + "population": [100], } + + +def test_grid_tile_post_geojson_404(grid_dataset, geojson): + response = test_client.post( + "/grid/tile/8439181ffffffff", + params={"columns": ["landcover", "population"]}, + headers=HEADERS, + content=geojson, + ) + + assert response.status_code == 404 + + +def test_grid_tile_post_wrong_column(grid_dataset, geojson): + response = 
test_client.post( + f"/grid/tile/{grid_dataset}", + params={"columns": ["I DO NOT EXIST"]}, + headers=HEADERS, + content=geojson, + ) + + assert response.status_code == 400 + assert response.json() == {"detail": "One or more of the specified columns is not valid"} diff --git a/api/tests/test_zonal_stats.py b/api/tests/test_zonal_stats.py index 13236dca..ef3078b1 100644 --- a/api/tests/test_zonal_stats.py +++ b/api/tests/test_zonal_stats.py @@ -42,7 +42,6 @@ def test_no_geojson_raises_422(tif_file): "loc": ["body"], "msg": "Field required", "type": "missing", - "url": "https://errors.pydantic.dev/2.6/v/missing", } ] } diff --git a/science/notebooks/check_combine_results.ipynb b/science/notebooks/check_combine_results.ipynb index 3fe61ffa..415aadc9 100644 --- a/science/notebooks/check_combine_results.ipynb +++ b/science/notebooks/check_combine_results.ipynb @@ -3,18 +3,17 @@ { "cell_type": "code", "execution_count": null, - "id": "e84373c7-a5e7-47c8-95a3-d2db7ade2e29", + "id": "0", "metadata": {}, "outputs": [], "source": [ - "import polars as pl\n", - "import polars.selectors as cs" + "import polars as pl" ] }, { "cell_type": "code", "execution_count": null, - "id": "750642ad-f9fc-434e-86f8-783cc41d533c", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -25,16 +24,16 @@ { "cell_type": "code", "execution_count": null, - "id": "d5012a5d-ea51-4b01-8ccf-055db2feb3ec", + "id": "2", "metadata": {}, "outputs": [], "source": [ - "schema = {\"cell\":pl.UInt64, \"b\":pl.Float32, \"c\":pl.String}\n", + "schema = {\"cell\": pl.UInt64, \"b\": pl.Float32, \"c\": pl.String}\n", "df = pl.DataFrame(schema=schema)\n", "\n", - "join1 = pl.DataFrame({\"cell\": [1,2,3], \"b\": [9., 9., 9.]}, schema_overrides=schema)\n", - "join2 = pl.DataFrame({\"cell\": [1,2,3], \"c\": [\"a\", \"b\", \"c\"]}, schema_overrides=schema)\n", - "join3 = pl.DataFrame({\"cell\": [4,5,6], \"c\": [\"a\", \"b\", \"c\"]}, schema_overrides=schema)\n", + "join1 = pl.DataFrame({\"cell\": [1, 2, 3], \"b\": [9.0, 9.0, 9.0]}, schema_overrides=schema)\n", + "join2 = pl.DataFrame({\"cell\": [1, 2, 3], \"c\": [\"a\", \"b\", \"c\"]}, schema_overrides=schema)\n", + "join3 = pl.DataFrame({\"cell\": [4, 5, 6], \"c\": [\"a\", \"b\", \"c\"]}, schema_overrides=schema)\n", "\n", "tojoin = [join1, join2, join3]\n", "\n", @@ -46,24 +45,23 @@ { "cell_type": "code", "execution_count": null, - "id": "af677cc8-34c5-4259-84c2-1d90a5bf3040", + "id": "3", "metadata": {}, "outputs": [], "source": [ "import polars as pl\n", "\n", "# Define the initial DataFrame\n", - "df = pl.DataFrame({\n", - " \"cell\": [1, 2, 3, 1, 2, 3, 4, 5, 6],\n", - " \"b\": [9.0, 9.0, 9.0, None, None, None, None, None, None],\n", - " \"c\": [None, None, None, \"a\", \"b\", \"c\", \"a\", \"b\", \"c\"]\n", - "})\n", + "df = pl.DataFrame(\n", + " {\n", + " \"cell\": [1, 2, 3, 1, 2, 3, 4, 5, 6],\n", + " \"b\": [9.0, 9.0, 9.0, None, None, None, None, None, None],\n", + " \"c\": [None, None, None, \"a\", \"b\", \"c\", \"a\", \"b\", \"c\"],\n", + " }\n", + ")\n", "\n", "# Perform the group by and aggregation\n", - "agg_df = df.groupby(\"cell\").agg([\n", - " pl.col(\"b\").max().alias(\"b\"),\n", - " pl.col(\"c\").max().alias(\"c\")\n", - "])\n", + "agg_df = df.groupby(\"cell\").agg([pl.col(\"b\").max().alias(\"b\"), pl.col(\"c\").max().alias(\"c\")])\n", "\n", "# Sort the resulting DataFrame by the 'cell' column\n", "result_df = agg_df.sort(\"cell\")\n", diff --git a/science/notebooks/merge_entrega_roberto.ipynb b/science/notebooks/merge_entrega_roberto.ipynb index e7913753..f1ea486d 
100644 --- a/science/notebooks/merge_entrega_roberto.ipynb +++ b/science/notebooks/merge_entrega_roberto.ipynb @@ -6,9 +6,9 @@ "metadata": {}, "outputs": [], "source": [ - "import polars as pl\n", "from pathlib import Path\n", - "import h3ronpy.polars" + "\n", + "import polars as pl" ] }, { @@ -17,7 +17,7 @@ "metadata": {}, "outputs": [], "source": [ - "csvs = list(Path(\"../raw/ENTREGA UNO MUESTRAS HEXA CSV 18072024/\").glob(\"*.CSV\"))" + "csvs = list(Path(\"../data/raw/ENTREGA UNO MUESTRAS HEXA CSV 18072024\").glob(\"*.CSV\"))" ] }, { @@ -60,7 +60,8 @@ "OVERVIEW_LEVEL = CELLS_RES - 5\n", "\n", "df = df.with_columns(\n", - " pl.col(\"cell\").h3.change_resolution(OVERVIEW_LEVEL).alias(\"tile_id\") # type: ignore[attr-defined]\n", + " pl.col(\"cell\").h3.change_resolution(OVERVIEW_LEVEL).h3.cells_to_string().alias(\"tile_id\"), # type: ignore[attr-defined]\n", + " pl.col(\"cell\").h3.cells_to_string(),\n", ")\n", "partition_dfs = df.partition_by([\"tile_id\"], as_dict=True, include_key=False)" ] @@ -84,12 +85,12 @@ "for tile_group, tile_df in partition_dfs.items():\n", " if tile_df.shape[0] == 0: # todo: skip empty tiles ?\n", " continue\n", - " tile_id = hex(tile_group[0])[2:]\n", - " filename = Path(\"grid/1\") / (tile_id + \".arrow\")\n", + " tile_id = tile_group[0]\n", + " filename = Path(\"../data/processed/grid/1\") / (tile_id + \".arrow\")\n", " if tile_id in seen_tiles:\n", - " tile_df = pl.concat(\n", - " [pl.read_ipc(filename), tile_df], how=\"vertical_relaxed\"\n", - " ).unique(subset=[\"cell\"])\n", + " tile_df = pl.concat([pl.read_ipc(filename), tile_df], how=\"vertical_relaxed\").unique(\n", + " subset=[\"cell\"]\n", + " )\n", " tile_df.write_parquet(filename)\n", " n_cells += len(tile_df)\n", " else:\n",