From 1cb29e3ce7f24954a14054be51d375f9851d533c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 25 Nov 2024 11:14:53 +0100 Subject: [PATCH 1/5] build: add devcontainer setup (#1725) Add devcontainer configuration with special customizations for VS Code. --------- Co-authored-by: Enrique Gonzalez Paredes --- .devcontainer/.vscode/launch.json | 24 +++++++++++++++ .devcontainer/Dockerfile | 5 ++++ .devcontainer/devcontainer.json | 49 +++++++++++++++++++++++++++++++ .devcontainer/setup.sh | 10 +++++++ .gitignore | 2 +- 5 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 .devcontainer/.vscode/launch.json create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/devcontainer.json create mode 100755 .devcontainer/setup.sh diff --git a/.devcontainer/.vscode/launch.json b/.devcontainer/.vscode/launch.json new file mode 100644 index 0000000000..f682b56388 --- /dev/null +++ b/.devcontainer/.vscode/launch.json @@ -0,0 +1,24 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File (just my code)", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": true + }, + { + "name": "Python: Current File (all)", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": false + } + ] +} diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000000..414f2d0292 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,5 @@ +FROM mcr.microsoft.com/devcontainers/python:1-3.10-bookworm +RUN apt-get update \ + && export DEBIAN_FRONTEND=noninteractive && apt-get install -y libboost-dev \ + && apt-get clean && rm -rf /var/cache/apt/* && rm -rf /var/lib/apt/lists/* && rm -rf /tmp/* +RUN curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR="/bin" sh diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000000..7dc4b2f08c --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,49 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/python +{ + "name": "Python 3", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "build": { + "dockerfile": "Dockerfile" + }, + + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Use 'postCreateCommand' to run commands after the container is created. + "postCreateCommand": "bash .devcontainer/setup.sh", + + "containerEnv": { + "PRE_COMMIT_HOME": "/workspaces/gt4py/.caches/pre-commit" + }, + + // Configure tool-specific properties. + "customizations": { + // Configure properties specific to VS Code. + "vscode": { + // Set *default* container specific settings.json values on container create. + "settings": { + "python.formatting.provider": "ruff", + "python.testing.pytestEnabled": true, + "python.defaultInterpreterPath": "/workspaces/gt4py/.venv/bin/python", + "files.insertFinalNewline": true, + "python.terminal.activateEnvironment": true, + "cmake.ignoreCMakeListsMissing": true + }, + "extensions": [ + "charliermarsh.ruff", + "donjayamanne.githistory", + "github.vscode-github-actions", + "lextudio.restructuredtext", + "ms-python.python", + "ms-vsliveshare.vsliveshare", + "swyddfa.esbonio" + ] + } + } + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} diff --git a/.devcontainer/setup.sh b/.devcontainer/setup.sh new file mode 100755 index 0000000000..d23dda9dea --- /dev/null +++ b/.devcontainer/setup.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +ln -sfn /workspaces/gt4py/.devcontainer/.vscode /workspaces/gt4py/.vscode +uv venv .venv +source .venv/bin/activate +uv pip install -r requirements-dev.txt +uv pip install -e . +uv pip install -i https://test.pypi.org/simple/ atlas4py +pre-commit install --install-hooks +deactivate diff --git a/.gitignore b/.gitignore index 5792b8a9b7..b1c8ed26e9 100644 --- a/.gitignore +++ b/.gitignore @@ -159,5 +159,5 @@ venv.bak/ ### Others ### .obsidian - coverage.json +.caches From d7f55522beacfc77c12964f6bbb1962899d8821d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 25 Nov 2024 14:22:24 +0100 Subject: [PATCH 2/5] feat[next]: remove NeighborTableOffsetProvider, use gtx.as_connectivity (#1729) User-facing change: use `gtx.as_connectivity` to create a connectivity/neighbor table instead of `NeighborTableOffsetProvider` which is deprecated (and the backward-compatible mechanism broken for some use-cases). The internal concepts of `Connectivity` and `NeighborTable` are updated. `ConnectivityType` is introduced which contains the compile-time info of a `Connectivity`. See ADR 19. Additionally, the compile-time info is used (instead of the run-time connectivities) in many places of the toolchain when possible. --- .gitpod/.vscode/launch.json | 13 +- .../0008-Mapping_Domain_to_Cpp-Backend.md | 2 +- docs/development/ADRs/0019-Connectivities.md | 55 +++++ docs/user/next/QuickstartGuide.md | 6 +- .../exercises/2_divergence_exercise.ipynb | 4 +- .../2_divergence_exercise_solution.ipynb | 4 +- .../exercises/3_gradient_exercise.ipynb | 4 +- .../3_gradient_exercise_solution.ipynb | 4 +- .../workshop/exercises/4_curl_exercise.ipynb | 4 +- .../exercises/4_curl_exercise_solution.ipynb | 4 +- .../exercises/5_vector_laplace_exercise.ipynb | 10 +- .../5_vector_laplace_exercise_solution.ipynb | 10 +- .../8_diffusion_exercise_solution.ipynb | 8 +- docs/user/next/workshop/slides/slides_2.ipynb | 10 +- src/gt4py/_core/definitions.py | 10 +- src/gt4py/next/__init__.py | 6 +- src/gt4py/next/common.py | 170 ++++++++++---- src/gt4py/next/constructors.py | 24 +- src/gt4py/next/embedded/nd_array_field.py | 35 ++- src/gt4py/next/ffront/decorator.py | 47 ++-- src/gt4py/next/ffront/experimental.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 30 +-- src/gt4py/next/iterator/embedded.py | 215 +++++++++++------- .../next/iterator/ir_utils/domain_utils.py | 26 +-- src/gt4py/next/iterator/runtime.py | 10 +- .../iterator/transforms/collapse_tuple.py | 6 +- src/gt4py/next/iterator/transforms/cse.py | 6 +- .../iterator/transforms/fuse_as_fieldop.py | 9 +- .../next/iterator/transforms/global_tmps.py | 4 +- .../next/iterator/transforms/inline_scalar.py | 4 +- .../next/iterator/transforms/pass_manager.py | 29 ++- .../transforms/pass_manager_legacy.py | 14 +- .../next/iterator/transforms/unroll_reduce.py | 28 +-- .../next/iterator/type_system/inference.py | 34 +-- .../iterator/type_system/type_synthesizer.py | 48 ++-- src/gt4py/next/otf/arguments.py | 54 +---- .../codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py | 76 +------ .../codegens/gtfn/gtfn_module.py | 47 ++-- .../codegens/gtfn/itir_to_gtfn_ir.py | 31 +-- .../runners/dace_common/dace_backend.py | 21 +- .../runners/dace_common/utility.py | 15 +- .../runners/dace_fieldview/gtir_dataflow.py | 75 +++--- .../runners/dace_fieldview/gtir_sdfg.py | 33 ++- .../runners/dace_fieldview/workflow.py | 6 +- .../runners/dace_iterator/__init__.py | 53 +++-- .../runners/dace_iterator/itir_to_sdfg.py | 45 ++-- .../runners/dace_iterator/itir_to_tasklet.py | 97 ++++---- .../runners/dace_iterator/utility.py | 10 +- .../runners/dace_iterator/workflow.py | 6 +- .../next/program_processors/runners/gtfn.py | 16 +- .../program_processors/runners/roundtrip.py | 16 +- .../next/type_system/type_specifications.py | 1 + .../feature_tests/dace/test_orchestration.py | 86 ++++--- .../ffront_tests/ffront_test_utils.py | 91 +++++--- .../ffront_tests/test_execution.py | 36 +-- .../ffront_tests/test_external_local_field.py | 12 +- .../ffront_tests/test_gt4py_builtins.py | 18 +- .../test_temporaries_with_sizes.py | 2 +- .../iterator_tests/test_builtins.py | 40 +--- .../test_strided_offset_provider.py | 9 +- .../ffront_tests/test_ffront_fvm_nabla.py | 64 +++--- .../multi_feature_tests/fvm_nabla_setup.py | 56 +++-- .../iterator_tests/test_fvm_nabla.py | 114 ++++------ .../test_with_toy_connectivity.py | 38 ++-- tests/next_tests/toy_connectivity.py | 7 + tests/next_tests/unit_tests/conftest.py | 25 +- .../embedded_tests/test_nd_array_field.py | 15 +- .../test_embedded_field_with_list.py | 10 +- .../iterator_tests/test_runtime_domain.py | 10 +- .../iterator_tests/test_type_inference.py | 34 +-- .../transforms_tests/test_cse.py | 14 +- .../transforms_tests/test_domain_inference.py | 13 +- .../transforms_tests/test_fuse_as_fieldop.py | 13 +- .../transforms_tests/test_global_tmps.py | 8 +- .../transforms_tests/test_prune_casts.py | 6 +- .../transforms_tests/test_unroll_reduce.py | 69 ++++-- .../gtfn_tests/test_itir_to_gtfn_ir.py | 4 +- .../runners_tests/dace_tests/test_dace.py | 24 +- .../dace_tests/test_gtir_to_sdfg.py | 134 ++++++----- .../unit_tests/test_constructors.py | 14 +- 80 files changed, 1293 insertions(+), 1170 deletions(-) create mode 100644 docs/development/ADRs/0019-Connectivities.md diff --git a/.gitpod/.vscode/launch.json b/.gitpod/.vscode/launch.json index f682b56388..b25a182648 100644 --- a/.gitpod/.vscode/launch.json +++ b/.gitpod/.vscode/launch.json @@ -6,7 +6,7 @@ "configurations": [ { "name": "Python: Current File (just my code)", - "type": "python", + "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", @@ -14,11 +14,20 @@ }, { "name": "Python: Current File (all)", - "type": "python", + "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", "justMyCode": false + }, + { + "name": "Python: Debug Tests", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "purpose": ["debug-test"], + "console": "integratedTerminal", + "justMyCode": true } ] } diff --git a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md index a1ee8575d2..1ce83431ee 100644 --- a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md +++ b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md @@ -20,7 +20,7 @@ The Python embedded execution for Iterator IR keeps track of the current locatio ### Python side -On the Python side, we label dimensions of fields with the location type, e.g. `Edge` or `Vertex`. The domain uses `named_ranges` that uses the same location types to express _where_ to iterate, e.g. `named_range(Vertex, range(0, 100))` is an iteration over the `Vertex` dimension, no order in the domain is required. Additionally, the `Connectivity` (aka `NeighborTableOffsetProvider` in the current implementation) describes the mapping between location types. +On the Python side, we label dimensions of fields with the location type, e.g. `Edge` or `Vertex`. The domain uses `named_ranges` that uses the same location types to express _where_ to iterate, e.g. `named_range(Vertex, range(0, 100))` is an iteration over the `Vertex` dimension, no order in the domain is required. Additionally, the `Connectivity` describes the mapping between location types. ### C++ side diff --git a/docs/development/ADRs/0019-Connectivities.md b/docs/development/ADRs/0019-Connectivities.md new file mode 100644 index 0000000000..76e85e49a6 --- /dev/null +++ b/docs/development/ADRs/0019-Connectivities.md @@ -0,0 +1,55 @@ +--- +tags: [] +--- + +# [Connectivities] + +- **Status**: valid +- **Authors**: Hannes Vogt (@havogt) +- **Created**: 2024-11-08 +- **Updated**: 2024-11-08 + +The representation of Connectivities (neighbor tables, `NeighborTableOffsetProvider`) and their identifier (offset tag, `FieldOffset`, etc.) was extended and modified based on the needs of different parts of the toolchain. Here we outline the ideas for consolidating the different closely-related concepts. + +## History + +In the early days of Iterator IR (ITIR), an `offset` was a literal in the IR. Its meaning was only provided at execution time by a mapping from `offset` tag to an entity that we labelled `OffsetProvider`. We had mainly 2 kinds of `OffsetProvider`: a `Dimension` representing a Cartesian shift and a `NeighborTableOffsetProvider` for unstructured shifts. Since the type of `offset` needs to be known for compilation (strided for Cartesian, lookup-table for unstructured), this prevents a clean interface for ahead-of-time compilation. +For the frontend type-checking we later introduce a `FieldOffset` which contained type information of the mapped dimensions. +For (field-view) embedded we introduced a `ConnectivityField` (now `Connectivity`) which could be generated from the OffsetProvider information. + +These different concepts had overlap but were not 1-to-1 replacements. + +## Decision + +We update and introduce the following concepts + +### Conceptual definitions + +**Connectivity** is a mapping from index (or product of indices) to index. It covers 1-to-1 mappings, e.g. Cartesian shifts, NeighborTables (2D mappings) and dynamic Cartesian shifts. + +**NeighborConnectivity** is a 2D mapping of the N neighbors of a Location A to a Location B. + +**NeighborTable** is a _NeighborConnectivity_ backed by a buffer. + +**ConnectivityType**, **NeighborConnectivityType** contains all information that is needed for compilation. + +### Full definitions + +See `next.common` module + +Note: Currently, the compiled backends supports only `NeighborConnectivity`s that are `NeighborTable`s. We do not yet encode this in the type and postpone discussion to the point where we support alternative implementations (e.g. `StridedNeighborConnectivity`). + +## Which parts of the toolchain use which concept? + +### Embedded + +Embedded execution of field-view supports any kind of `Connectivity`. +Embedded execution of iterator (local) view supports only `NeighborConnectivity`s. + +### IR transformations and compiled backends + +All transformations and code-generation should use `ConnectivityType`, not the `Connectivity` which contains the runtime mapping. + +Note, currently the `global_tmps` pass uses runtime information, therefore this is not strictly enforced. + +The only supported `Connectivity`s in compiled backends (currently) are `NeighborTable`s. diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index 81604c7770..2cb6647519 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -155,8 +155,6 @@ This section approaches the pseudo-laplacian by introducing the required APIs pr - [Using reductions on connected mesh elements](#Using-reductions-on-connected-mesh-elements) - [Implementing the actual pseudo-laplacian](#Implementing-the-pseudo-laplacian) -+++ - #### Defining the mesh and its connectivities The examples related to unstructured meshes use the mesh below. The edges (in blue) and the cells (in red) are numbered with zero-based indices. @@ -237,7 +235,7 @@ E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim,E2CDim)) Note that the field offset does not contain the actual connectivity table, that's provided through an _offset provider_: ```{code-cell} ipython3 -E2C_offset_provider = gtx.NeighborTableOffsetProvider(edge_to_cell_table, EdgeDim, CellDim, 2) +E2C_offset_provider = gtx.as_connectivity([EdgeDim, E2CDim], codomain=CellDim, data=edge_to_cell_table, skip_value=-1) ``` The field operator `nearest_cell_to_edge` below shows an example of applying this transform. There is a little twist though: the subscript in `E2C[0]` means that only the value of the first connected cell is taken, the second (if exists) is ignored. @@ -385,7 +383,7 @@ As explained in the section outline, the pseudo-laplacian needs the cell-to-edge C2EDim = gtx.Dimension("C2E", kind=gtx.DimensionKind.LOCAL) C2E = gtx.FieldOffset("C2E", source=EdgeDim, target=(CellDim, C2EDim)) -C2E_offset_provider = gtx.NeighborTableOffsetProvider(cell_to_edge_table, CellDim, EdgeDim, 3) +C2E_offset_provider = gtx.as_connectivity([CellDim, C2EDim], codomain=EdgeDim, data=cell_to_edge_table, skip_value=-1) ``` **Weights of edge differences:** diff --git a/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb b/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb index 50349e52b0..b0a1980d0f 100644 --- a/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb +++ b/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "5dbd2f62", "metadata": {}, "outputs": [], @@ -113,7 +113,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " divergence_gt4py = gtx.zeros(cell_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb b/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb index 6baac2b8c0..573ee6a44e 100644 --- a/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "5dbd2f62", "metadata": {}, "outputs": [], @@ -118,7 +118,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " divergence_gt4py = gtx.zeros(cell_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb index c8914120d3..2b422b1823 100644 --- a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb +++ b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "84b02762", "metadata": {}, "outputs": [], @@ -110,7 +110,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " gradient_gt4py_x = gtx.zeros(cell_domain, allocator=backend)\n", " gradient_gt4py_y = gtx.zeros(cell_domain, allocator=backend)\n", diff --git a/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb b/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb index 5e940a4b71..85044b989f 100644 --- a/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "84b02762", "metadata": {}, "outputs": [], @@ -123,7 +123,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " gradient_gt4py_x = gtx.zeros(cell_domain, allocator=backend)\n", " gradient_gt4py_y = gtx.zeros(cell_domain, allocator=backend)\n", diff --git a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb index 4a6b37baf7..dc321f1bdd 100644 --- a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb +++ b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5b6ffc9e", "metadata": {}, "outputs": [], @@ -134,7 +134,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " curl_gt4py = gtx.zeros(vertex_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb b/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb index 065cf02de7..251fe8239a 100644 --- a/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb @@ -107,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5b6ffc9e", "metadata": {}, "outputs": [], @@ -139,7 +139,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " curl_gt4py = gtx.zeros(vertex_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb b/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb index 832375a86b..30f568de6f 100644 --- a/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb +++ b/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -272,10 +272,10 @@ " edge_orientation_cell.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", - " e2v_connectivity = gtx.NeighborTableOffsetProvider(e2v_table, E, V, 2, has_skip_values=False)\n", - " e2c_connectivity = gtx.NeighborTableOffsetProvider(e2c_table, E, C, 2, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", + " e2v_connectivity = gtx.as_connectivity([E, E2VDim], codomain=V, data=e2v_table)\n", + " e2c_connectivity = gtx.as_connectivity([E, E2CDim], codomain=C, data=e2c_table)\n", "\n", " laplacian_gt4py = gtx.zeros(edge_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb b/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb index be846d199d..eaeb8c7b02 100644 --- a/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb @@ -249,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -293,10 +293,10 @@ " edge_orientation_cell.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", - " e2v_connectivity = gtx.NeighborTableOffsetProvider(e2v_table, E, V, 2, has_skip_values=False)\n", - " e2c_connectivity = gtx.NeighborTableOffsetProvider(e2c_table, E, C, 2, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", + " e2v_connectivity = gtx.as_connectivity([E, E2VDim], codomain=V, data=e2v_table)\n", + " e2c_connectivity = gtx.as_connectivity([E, E2CDim], codomain=C, data=e2c_table)\n", "\n", " laplacian_gt4py = gtx.zeros(edge_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb b/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb index d4bcdb33d5..b278cee26d 100644 --- a/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb @@ -118,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -156,10 +156,8 @@ " dt,\n", " )\n", "\n", - " e2c2v_connectivity = gtx.NeighborTableOffsetProvider(\n", - " e2c2v_table, E, V, 4, has_skip_values=False\n", - " )\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " e2c2v_connectivity = gtx.as_connectivity([E, E2C2VDim], codomain=V, data=e2c2v_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " diffusion_step(\n", " u,\n", diff --git a/docs/user/next/workshop/slides/slides_2.ipynb b/docs/user/next/workshop/slides/slides_2.ipynb index 1e8925087f..c6967df4b2 100644 --- a/docs/user/next/workshop/slides/slides_2.ipynb +++ b/docs/user/next/workshop/slides/slides_2.ipynb @@ -281,17 +281,19 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "6d30a5e1", "metadata": {}, "outputs": [], "source": [ - "E2C_offset_provider = gtx.NeighborTableOffsetProvider(e2c_table, Edge, Cell, 2)" + "E2C_offset_provider = gtx.as_connectivity(\n", + " [Edge, E2CDim], codomain=Cell, data=e2c_table, skip_value=-1\n", + ")" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "d62f6c98", "metadata": {}, "outputs": [ @@ -311,7 +313,7 @@ " return cell_field(E2C[0]) # 0th index to isolate edge dimension\n", "\n", "\n", - "@gtx.program # uses skip_values, therefore we cannot use embedded\n", + "@gtx.program\n", "def run_nearest_cell_to_edge(\n", " cell_field: gtx.Field[Dims[Cell], float64], edge_field: gtx.Field[Dims[Edge], float64]\n", "):\n", diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 9d07b2eb79..8f62788b8f 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -439,13 +439,21 @@ def ndim(self) -> int: ... @property def shape(self) -> tuple[int, ...]: ... + @property + def strides(self) -> tuple[int, ...]: ... + @property def dtype(self) -> Any: ... + @property + def itemsize(self) -> int: ... + def item(self) -> Any: ... def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ... + def any(self) -> bool: ... + def __getitem__(self, item: Any) -> NDArrayObject: ... def __abs__(self) -> NDArrayObject: ... @@ -496,4 +504,4 @@ def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __xor__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 80bb276c70..4fa9215706 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -20,6 +20,7 @@ from . import common, ffront, iterator, program_processors from .common import ( + Connectivity, Dimension, DimensionKind, Dims, @@ -39,8 +40,7 @@ from .ffront.fbuiltins import * # noqa: F403 [undefined-local-with-import-star] explicitly reexport all from fbuiltins.__all__ from .ffront.fbuiltins import FieldOffset from .iterator.embedded import ( - NeighborTableOffsetProvider, - StridedNeighborOffsetProvider, + NeighborTableOffsetProvider, # TODO(havogt): deprecated index_field, np_as_located_field, ) @@ -61,6 +61,7 @@ "Dimension", "DimensionKind", "Field", + "Connectivity", "GridType", "domain", "Domain", @@ -75,7 +76,6 @@ "as_connectivity", # from iterator "NeighborTableOffsetProvider", - "StridedNeighborOffsetProvider", "index_field", "np_as_located_field", # from ffront diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 4aa0dd03aa..9b2870e1c0 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -18,7 +18,6 @@ from collections.abc import Mapping, Sequence import numpy as np -import numpy.typing as npt from gt4py._core import definitions as core_defs from gt4py.eve import utils @@ -95,7 +94,7 @@ def __str__(self) -> str: def __call__(self, val: int) -> NamedIndex: return NamedIndex(self, val) - def __add__(self, offset: int) -> ConnectivityField: + def __add__(self, offset: int) -> Connectivity: # TODO(sf-n): just to avoid circular import. Move or refactor the FieldOffset to avoid this. from gt4py.next.ffront import fbuiltins @@ -104,7 +103,7 @@ def __add__(self, offset: int) -> ConnectivityField: dimension_to_implicit_offset(self.value), source=self, target=(self,) )[offset] - def __sub__(self, offset: int) -> ConnectivityField: + def __sub__(self, offset: int) -> Connectivity: return self + (-offset) @@ -678,6 +677,9 @@ def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... + # TODO(havogt) + # This property is wrong, because for a function field we would not know to which NDArrayObject we want to convert + # at the very least, we need to take an allocator and rename this to `as_ndarray`. @property def ndarray(self) -> core_defs.NDArrayObject: ... @@ -688,7 +690,7 @@ def __str__(self) -> str: def asnumpy(self) -> np.ndarray: ... @abc.abstractmethod - def premap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... + def premap(self, index_field: Connectivity | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod def restrict(self, item: AnyIndexSpec) -> Self: ... @@ -700,8 +702,8 @@ def as_scalar(self) -> core_defs.ScalarT: ... @abc.abstractmethod def __call__( self, - index_field: ConnectivityField | fbuiltins.FieldOffset, - *args: ConnectivityField | fbuiltins.FieldOffset, + index_field: Connectivity | fbuiltins.FieldOffset, + *args: Connectivity | fbuiltins.FieldOffset, ) -> Field: ... @abc.abstractmethod @@ -811,12 +813,64 @@ def remapping(cls) -> ConnectivityKind: return cls.ALTER_DIMS | cls.ALTER_STRUCT +@dataclasses.dataclass(frozen=True) +class ConnectivityType: # TODO(havogt): would better live in type_specifications but would have to solve a circular import + domain: tuple[Dimension, ...] + codomain: Dimension + skip_value: Optional[core_defs.IntegralScalar] + dtype: core_defs.DType + + @property + def has_skip_values(self) -> bool: + return self.skip_value is not None + + +@dataclasses.dataclass(frozen=True) +class NeighborConnectivityType(ConnectivityType): + # TODO(havogt): refactor towards encoding this information in the local dimensions of the ConnectivityType.domain + max_neighbors: int + + @property + def source_dim(self) -> Dimension: + return self.domain[0] + + @property + def neighbor_dim(self) -> Dimension: + return self.domain[1] + + @runtime_checkable # type: ignore[misc] # DimT should be covariant, but then it breaks in other places -class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): +class Connectivity(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod - def codomain(self) -> DimT: ... + def codomain(self) -> DimT: + """ + The `codomain` is the set of all indices in a certain `Dimension`. + + We use the `Dimension` itself to describe the (infinite) set of all indices. + + Note: + We could restrict the infinite codomain to only the indices that are actually contained in the mapping. + Currently, this would just complicate implementation as we do not use this information. + """ + + def __gt_type__(self) -> ConnectivityType: + if is_neighbor_connectivity(self): + return NeighborConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + max_neighbors=self.ndarray.shape[1], + ) + else: + return ConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + ) @property def kind(self) -> ConnectivityKind: @@ -831,61 +885,61 @@ def skip_value(self) -> Optional[core_defs.IntegralScalar]: ... # Operators def __abs__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __neg__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __invert__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __eq__(self, other: Any) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __ne__(self, other: Any) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __add__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __radd__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __sub__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rsub__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __mul__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rmul__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __truediv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rtruediv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __floordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rfloordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __pow__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __and__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __or__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __xor__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") # Utility function to construct a `Field` from different buffer representations. @@ -911,38 +965,58 @@ def _connectivity( domain: Optional[DomainLike] = None, dtype: Optional[core_defs.DType] = None, skip_value: Optional[core_defs.IntegralScalar] = None, -) -> ConnectivityField: +) -> Connectivity: raise NotImplementedError -@runtime_checkable -class Connectivity(Protocol): - max_neighbors: int - has_skip_values: bool - origin_axis: Dimension - neighbor_axis: Dimension - index_type: type[int] | type[np.int32] | type[np.int64] +class NeighborConnectivity(Connectivity, Protocol): + # TODO(havogt): work towards encoding this properly in the type + def __gt_type__(self) -> NeighborConnectivityType: ... + - def mapped_index( - self, cur_index: int | np.integer, neigh_index: int | np.integer - ) -> Optional[int | np.integer]: - """Return neighbor index.""" +def is_neighbor_connectivity(obj: Any) -> TypeGuard[NeighborConnectivity]: + if not isinstance(obj, Connectivity): + return False + domain_dims = obj.domain.dims + return ( + len(domain_dims) == 2 + and domain_dims[0].kind is DimensionKind.HORIZONTAL + and domain_dims[1].kind is DimensionKind.LOCAL + ) -@runtime_checkable -class NeighborTable(Connectivity, Protocol): - table: npt.NDArray +class NeighborTable( + NeighborConnectivity, Protocol +): # TODO(havogt): try to express by inheriting from NdArrayConnectivityField (but this would require a protocol to move it out of `embedded.nd_array_field`) + @property + def ndarray(self) -> core_defs.NDArrayObject: + # Note that this property is currently already there from inheriting from `Field`, + # however this seems wrong, therefore we explicitly introduce it here (or it should come + # implicitly from the `NdArrayConnectivityField` protocol). + ... -OffsetProviderElem: TypeAlias = Dimension | Connectivity +def is_neighbor_table(obj: Any) -> TypeGuard[NeighborTable]: + return is_neighbor_connectivity(obj) and hasattr(obj, "ndarray") + + +OffsetProviderElem: TypeAlias = Dimension | NeighborConnectivity +OffsetProviderTypeElem: TypeAlias = Dimension | NeighborConnectivityType OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem] +OffsetProviderType: TypeAlias = Mapping[Tag, OffsetProviderTypeElem] + + +def offset_provider_to_type(offset_provider: OffsetProvider) -> OffsetProviderType: + return { + k: v.__gt_type__() if isinstance(v, Connectivity) else v for k, v in offset_provider.items() + } DomainDimT = TypeVar("DomainDimT", bound="Dimension") @dataclasses.dataclass(frozen=True, eq=False) -class CartesianConnectivity(ConnectivityField[Dims[DomainDimT], DimT]): +class CartesianConnectivity(Connectivity[Dims[DomainDimT], DimT]): domain_dim: DomainDimT codomain: DimT offset: int = 0 @@ -981,7 +1055,7 @@ def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: return core_defs.Int32DType() # type: ignore[return-value] # This is a workaround to make this class concrete, since `codomain` is an - # abstract property of the `ConnectivityField` Protocol. + # abstract property of the `Connectivity` Protocol. if not TYPE_CHECKING: @functools.cached_property @@ -1024,9 +1098,9 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa def premap( self, - index_field: ConnectivityField | fbuiltins.FieldOffset, - *args: ConnectivityField | fbuiltins.FieldOffset, - ) -> ConnectivityField: + index_field: Connectivity | fbuiltins.FieldOffset, + *args: Connectivity | fbuiltins.FieldOffset, + ) -> Connectivity: raise NotImplementedError() __call__ = premap diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index dd52559e85..7b39511674 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -290,22 +290,24 @@ def as_connectivity( *, allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, device: Optional[core_defs.Device] = None, - skip_value: Optional[core_defs.IntegralScalar] = None, + skip_value: core_defs.IntegralScalar | eve.NothingType | None = eve.NOTHING, # TODO: copy=False -) -> common.ConnectivityField: +) -> common.Connectivity: """ - Construct a connectivity field from the given domain, codomain, and data. + Construct a `Connectivity` from the given domain, codomain, and data. Arguments: - domain: The domain of the connectivity field. It can be either a `common.DomainLike` object or a + domain: The domain of the connectivity. It can be either a `common.DomainLike` object or a sequence of `common.Dimension` objects. - codomain: The codomain dimension of the connectivity field. + codomain: The codomain dimension of the connectivity. data: The data used to construct the connectivity field. - dtype: The data type of the connectivity field. If not provided, it will be inferred from the data. - allocator: The allocator used to allocate the buffer for the connectivity field. If not provided, + dtype: The data type of the connectivity. If not provided, it will be inferred from the data. + allocator: The allocator used to allocate the buffer for the connectivity. If not provided, a default allocator will be used. - device: The device on which the connectivity field will be allocated. If not provided, the default + device: The device on which the connectivity will be allocated. If not provided, the default device will be used. + skip_value: The value that signals missing entries in the neighbor table. Defaults to the default + skip value if it is found in data, otherwise to `None` (= no skip value). Returns: The constructed connectivity field. @@ -313,9 +315,15 @@ def as_connectivity( Raises: ValueError: If the domain or codomain is invalid, or if the shape of the data does not match the domain shape. """ + if skip_value is eve.NOTHING: + skip_value = ( + common._DEFAULT_SKIP_VALUE if (data == common._DEFAULT_SKIP_VALUE).any() else None + ) + assert ( skip_value is None or skip_value == common._DEFAULT_SKIP_VALUE ) # TODO(havogt): not yet configurable + skip_value = cast(Optional[core_defs.IntegralScalar], skip_value) if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9ff5feaaee..e15fb4266a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -36,7 +36,6 @@ exceptions as embedded_exceptions, ) from gt4py.next.ffront import experimental, fbuiltins -from gt4py.next.iterator import embedded as itir_embedded try: @@ -189,10 +188,10 @@ def from_array( def premap( self: NdArrayField, - *connectivities: common.ConnectivityField | fbuiltins.FieldOffset, + *connectivities: common.Connectivity | fbuiltins.FieldOffset, ) -> NdArrayField: """ - Rearrange the field content using the provided connectivity fields as index mappings. + Rearrange the field content using the provided connectivities (index mappings). This operation is conceptually equivalent to a regular composition of mappings `f∘c`, being `c` the `connectivity` argument and `f` the `self` data field. @@ -206,7 +205,7 @@ def premap( argument used in the right hand side of the operator should therefore have the same product of dimensions `c: S × T → A × B`. Such a mapping can also be expressed as a pair of mappings `c1: S × T → A` and `c2: S × T → B`, and this - is actually the only supported form in GT4Py because `ConnectivityField` instances + is actually the only supported form in GT4Py because `Connectivity` instances can only deal with a single dimension in its codomain. This approach makes connectivities reusable for any combination of dimensions in a field domain and matches the NumPy advanced indexing API, which basically is a @@ -261,15 +260,15 @@ def premap( """ # noqa: RUF002 # TODO(egparedes): move docstring to the `premap` builtin function when it exists - conn_fields: list[common.ConnectivityField] = [] + conn_fields: list[common.Connectivity] = [] codomains_counter: collections.Counter[common.Dimension] = collections.Counter() for connectivity in connectivities: - # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField - if not isinstance(connectivity, common.ConnectivityField): + # For neighbor reductions, a FieldOffset is passed instead of an actual Connectivity + if not isinstance(connectivity, common.Connectivity): assert isinstance(connectivity, fbuiltins.FieldOffset) connectivity = connectivity.as_connectivity_field() - assert isinstance(connectivity, common.ConnectivityField) + assert isinstance(connectivity, common.Connectivity) # Current implementation relies on skip_value == -1: # if we assume the indexed array has at least one element, @@ -318,8 +317,8 @@ def premap( def __call__( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: return functools.reduce( lambda field, current_index_field: field.premap(current_index_field), @@ -460,7 +459,7 @@ def _dace_descriptor(self) -> Any: @dataclasses.dataclass(frozen=True) class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__ - common.ConnectivityField[common.DimsT, common.DimT], + common.Connectivity[common.DimsT, common.DimT], NdArrayField[common.DimsT, core_defs.IntegralScalar], ): _codomain: common.DimT @@ -579,7 +578,7 @@ def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField: __getitem__ = restrict -def _domain_premap(data: NdArrayField, *connectivities: common.ConnectivityField) -> NdArrayField: +def _domain_premap(data: NdArrayField, *connectivities: common.Connectivity) -> NdArrayField: """`premap` implementation transforming only the field domain not the data (i.e. translation and relocation).""" new_domain = data.domain for connectivity in connectivities: @@ -668,7 +667,7 @@ def _reshuffling_premap( ) -def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField) -> NdArrayField: +def _remapping_premap(data: NdArrayField, connectivity: common.Connectivity) -> NdArrayField: new_dims = {*connectivity.domain.dims} - {connectivity.codomain} if repeated_dims := (new_dims & {*data.domain.dims}): raise ValueError(f"Remapped field will contain repeated dimensions '{repeated_dims}'.") @@ -693,7 +692,7 @@ def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField if restricted_connectivity_domain != connectivity.domain else connectivity ) - assert isinstance(restricted_connectivity, common.ConnectivityField) + assert isinstance(restricted_connectivity, common.Connectivity) # 2- then compute the index array new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start @@ -971,7 +970,7 @@ def _concat_where( return cls_.from_array(result_array, domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[has-type] +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] def _make_reduction( @@ -996,15 +995,15 @@ def _builtin_op( offset_definition = current_offset_provider[ axis.value ] # assumes offset and local dimension have same name - assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) + assert common.is_neighbor_table(offset_definition) new_domain = common.Domain(*[nr for nr in field.domain if nr.dim != axis]) broadcast_slice = tuple( - slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis + slice(None) if d in [axis, offset_definition.domain.dims[0]] else xp.newaxis for d in field.domain.dims ) masked_array = xp.where( - xp.asarray(offset_definition.table[broadcast_slice]) != common._DEFAULT_SKIP_VALUE, + xp.asarray(offset_definition.ndarray[broadcast_slice]) != common._DEFAULT_SKIP_VALUE, field.ndarray, initial_value_op(field), ) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index dc2421e1d2..9ce07d01bb 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -30,7 +30,6 @@ embedded as next_embedded, errors, ) -from gt4py.next.common import Connectivity, Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( field_operator_ast as foast, @@ -82,15 +81,15 @@ class Program: definition_stage: ffront_stages.ProgramDefinition backend: Optional[next_backend.Backend] - connectivities: Optional[dict[str, Connectivity]] + connectivities: Optional[common.OffsetProviderType] = None @classmethod def from_function( cls, definition: types.FunctionType, backend: Optional[next_backend], - grid_type: Optional[GridType] = None, - connectivities: Optional[dict[str, Connectivity]] = None, + grid_type: Optional[common.GridType] = None, + connectivities: Optional[common.OffsetProviderType] = None, ) -> Program: program_def = ffront_stages.ProgramDefinition(definition=definition, grid_type=grid_type) return cls(definition_stage=program_def, backend=backend, connectivities=connectivities) @@ -140,10 +139,10 @@ def _frontend_transforms(self) -> next_backend.Transforms: def with_backend(self, backend: next_backend.Backend) -> Program: return dataclasses.replace(self, backend=backend) - def with_connectivities(self, connectivities: dict[str, Connectivity]) -> Program: + def with_connectivities(self, connectivities: common.OffsetProviderType) -> Program: return dataclasses.replace(self, connectivities=connectivities) - def with_grid_type(self, grid_type: GridType) -> Program: + def with_grid_type(self, grid_type: common.GridType) -> Program: return dataclasses.replace( self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) @@ -199,7 +198,7 @@ def itir(self) -> itir.FencilDefinition: return self._frontend_transforms.past_to_itir(no_args_past).data @functools.cached_property - def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderElem]: + def _implicit_offset_provider(self) -> dict[str, common.Dimension]: """ Add all implicit offset providers. @@ -226,9 +225,7 @@ def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderEle ) return implicit_offset_provider - def __call__( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any - ) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: offset_provider = offset_provider | self._implicit_offset_provider if self.backend is None: warnings.warn( @@ -287,19 +284,17 @@ def definition(self) -> str: def with_backend(self, backend: next_backend.Backend) -> FrozenProgram: return self.__class__(program=self.program, backend=backend) - def with_grid_type(self, grid_type: GridType) -> FrozenProgram: + def with_grid_type(self, grid_type: common.GridType) -> FrozenProgram: return self.__class__( program=dataclasses.replace(self.program, grid_type=grid_type), backend=self.backend ) def jit( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any + self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any ) -> stages.CompiledProgram: return self.backend.jit(self.program, *args, offset_provider=offset_provider, **kwargs) - def __call__( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any - ) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: args, kwargs = signature.convert_to_positional(self.program, *args, **kwargs) if not self._compiled_program: @@ -328,7 +323,7 @@ class ProgramFromPast(Program): past_stage: ffront_stages.PastProgramDefinition - def __call__(self, *args: Any, offset_provider: dict[str, Dimension], **kwargs: Any) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: if self.backend is None: raise NotImplementedError( "Programs created from a PAST node (without a function definition) can not be executed in embedded mode" @@ -350,7 +345,7 @@ def __post_init__(self): class ProgramWithBoundArgs(Program): bound_args: dict[str, typing.Union[float, int, bool]] = None - def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): + def __call__(self, *args, offset_provider: common.OffsetProvider, **kwargs): type_ = self.past_stage.past_node.type new_type = ts_ffront.ProgramType( definition=ts.FunctionType( @@ -436,7 +431,7 @@ def program( *, # `NOTHING` -> default backend, `None` -> no backend (embedded execution) backend: next_backend.Backend | eve.NOTHING = eve.NOTHING, - grid_type: Optional[GridType] = None, + grid_type: Optional[common.GridType] = None, frozen: bool = False, ) -> Program | FrozenProgram | Callable[[types.FunctionType], Program | FrozenProgram]: """ @@ -506,7 +501,7 @@ def from_function( cls, definition: types.FunctionType, backend: Optional[next_backend.Backend], - grid_type: Optional[GridType] = None, + grid_type: Optional[common.GridType] = None, *, operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, operator_attributes: Optional[dict[str, Any]] = None, @@ -557,7 +552,7 @@ def __gt_type__(self) -> ts.CallableType: def with_backend(self, backend: next_backend.Backend) -> FieldOperator: return dataclasses.replace(self, backend=backend) - def with_grid_type(self, grid_type: GridType) -> FieldOperator: + def with_grid_type(self, grid_type: common.GridType) -> FieldOperator: return dataclasses.replace( self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) @@ -688,33 +683,33 @@ def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast. def scan_operator( definition: types.FunctionType, *, - axis: Dimension, + axis: common.Dimension, forward: bool, init: core_defs.Scalar, backend: Optional[str], - grid_type: GridType, + grid_type: common.GridType, ) -> FieldOperator[foast.ScanOperator]: ... @typing.overload def scan_operator( *, - axis: Dimension, + axis: common.Dimension, forward: bool, init: core_defs.Scalar, backend: Optional[str], - grid_type: GridType, + grid_type: common.GridType, ) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... def scan_operator( definition: Optional[types.FunctionType] = None, *, - axis: Dimension, + axis: common.Dimension, forward: bool = True, init: core_defs.Scalar = 0.0, backend=eve.NOTHING, - grid_type: GridType = None, + grid_type: common.GridType = None, ) -> ( FieldOperator[foast.ScanOperator] | Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]] diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index 8a94c20832..bd22aebe57 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -14,7 +14,7 @@ @BuiltInFunction -def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.ConnectivityField: +def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivity: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index d932431b51..b60fa63f95 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -16,7 +16,6 @@ import numpy as np from numpy import float32, float64, int32, int64 -import gt4py.next as gtx from gt4py._core import definitions as core_defs from gt4py.next import common from gt4py.next.common import Dimension, Field # noqa: F401 [unused-import] for TYPE_BUILTINS @@ -55,7 +54,7 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.DimensionType elif t is FieldOffset: return ts.OffsetType - elif t is common.ConnectivityField: + elif t is common.Connectivity: return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType @@ -321,7 +320,7 @@ def __post_init__(self) -> None: def __gt_type__(self) -> ts.OffsetType: return ts.OffsetType(source=self.source, target=self.target) - def __getitem__(self, offset: int) -> common.ConnectivityField: + def __getitem__(self, offset: int) -> common.Connectivity: """Serve as a connectivity factory.""" from gt4py.next import embedded # avoid circular import @@ -330,22 +329,19 @@ def __getitem__(self, offset: int) -> common.ConnectivityField: assert current_offset_provider is not None offset_definition = current_offset_provider[self.value] - connectivity: common.ConnectivityField + connectivity: common.Connectivity if isinstance(offset_definition, common.Dimension): connectivity = common.CartesianConnectivity(offset_definition, offset) - elif isinstance( - offset_definition, (gtx.NeighborTableOffsetProvider, common.ConnectivityField) - ): - unrestricted_connectivity = self.as_connectivity_field() - assert unrestricted_connectivity.domain.ndim > 1 + elif isinstance(offset_definition, common.Connectivity): + assert common.is_neighbor_connectivity(offset_definition) named_index = common.NamedIndex(self.target[-1], offset) - connectivity = unrestricted_connectivity[named_index] + connectivity = offset_definition[named_index] else: raise NotImplementedError() return connectivity - def as_connectivity_field(self) -> common.ConnectivityField: + def as_connectivity_field(self) -> common.Connectivity: """Convert to connectivity field using the offset providers in current embedded execution context.""" from gt4py.next import embedded # avoid circular import @@ -356,18 +352,8 @@ def as_connectivity_field(self) -> common.ConnectivityField: cache_key = id(offset_definition) if (connectivity := self._cache.get(cache_key, None)) is None: - if isinstance(offset_definition, common.ConnectivityField): + if isinstance(offset_definition, common.Connectivity): connectivity = offset_definition - elif isinstance(offset_definition, gtx.NeighborTableOffsetProvider): - connectivity = gtx.as_connectivity( - domain=self.target, - codomain=self.source, - data=offset_definition.table, - dtype=offset_definition.index_type, - skip_value=( - common._DEFAULT_SKIP_VALUE if offset_definition.has_skip_values else None - ), - ) else: raise NotImplementedError() diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 6221c95522..3c63ffef30 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -93,77 +93,113 @@ class SparseTag(Tag): ... -class NeighborTableOffsetProvider: +@xtyping.deprecated("Use a 'Connectivity' instead.") +def NeighborTableOffsetProvider( + table: core_defs.NDArrayObject, + origin_axis: common.Dimension, + neighbor_axis: common.Dimension, + max_neighbors: int, + has_skip_values=True, +) -> common.Connectivity: + return common._connectivity( + table, + codomain=neighbor_axis, + domain={ + origin_axis: table.shape[0], + common.Dimension( + value="_DummyLocalDim", kind=common.DimensionKind.LOCAL + ): max_neighbors, + }, + skip_value=common._DEFAULT_SKIP_VALUE if has_skip_values else None, + ) + + +# TODO(havogt): complete implementation and make available for fieldview embedded +@dataclasses.dataclass(frozen=True) +class StridedConnectivityField(common.Connectivity): + domain_dims: tuple[common.Dimension, common.Dimension] + codomain_dim: common.Dimension + _max_neighbors: int + def __init__( self, - table: core_defs.NDArrayObject, - origin_axis: common.Dimension, - neighbor_axis: common.Dimension, + domain_dims: Sequence[common.Dimension], + codomain_dim: common.Dimension, max_neighbors: int, - has_skip_values=True, - ) -> None: - self.table = table - self.origin_axis = origin_axis - self.neighbor_axis = neighbor_axis - assert not hasattr(table, "shape") or table.shape[1] == max_neighbors - self.max_neighbors = max_neighbors - self.has_skip_values = has_skip_values - self.index_type = table.dtype - - def mapped_index( - self, primary: common.IntIndex, neighbor_idx: common.IntIndex - ) -> common.IntIndex: - res = self.table[(primary, neighbor_idx)] - assert common.is_int_index(res) - return res + ): + object.__setattr__(self, "domain_dims", tuple(domain_dims)) + object.__setattr__(self, "codomain_dim", codomain_dim) + object.__setattr__(self, "_max_neighbors", max_neighbors) - if dace: - # Extension of NeighborTableOffsetProvider adding SDFGConvertible support in GT4Py Programs - def _dace_data_ptr(self) -> int: - obj = self.table - if dace.dtypes.is_array(obj): - if hasattr(obj, "__array_interface__"): - return obj.__array_interface__["data"][0] - if hasattr(obj, "__cuda_array_interface__"): - return obj.__cuda_array_interface__["data"][0] - raise ValueError("Unsupported data container.") - - def _dace_descriptor(self) -> dace.data.Data: - return dace.data.create_datadescriptor(self.table) - else: + @property + def __gt_origin__(self) -> xtyping.Never: + raise NotImplementedError + + def __gt_type__(self) -> common.NeighborConnectivityType: + return common.NeighborConnectivityType( + domain=self.domain_dims, + codomain=self.codomain_dim, + max_neighbors=self._max_neighbors, + skip_value=self.skip_value, + dtype=self.dtype, + ) - def _dace_data_ptr(self) -> NoReturn: # type: ignore[misc] - raise NotImplementedError( - "data_ptr is only supported when the 'dace' module is available." - ) + @property + def domain(self) -> common.Domain: + return common.Domain( + dims=self.domain_dims, + ranges=(common.UnitRange.infinite(), common.unit_range(self._max_neighbors)), + ) - def _dace_descriptor(self) -> NoReturn: # type: ignore[misc] - raise NotImplementedError( - "__descriptor__ is only supported when the 'dace' module is available." - ) + @property + def codomain(self) -> common.Dimension: + return self.codomain_dim - data_ptr = _dace_data_ptr - __descriptor__ = _dace_descriptor + @property + def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: + return core_defs.Int32DType() # type: ignore[return-value] + @property + def ndarray(self) -> core_defs.NDArrayObject: + raise NotImplementedError -class StridedNeighborOffsetProvider: - def __init__( + def asnumpy(self) -> np.ndarray: + raise NotImplementedError + + def premap(self, index_field: common.Connectivity | fbuiltins.FieldOffset) -> common.Field: + raise NotImplementedError + + def restrict( # type: ignore[override] self, - origin_axis: common.Dimension, - neighbor_axis: common.Dimension, - max_neighbors: int, - has_skip_values=True, - ) -> None: - self.origin_axis = origin_axis - self.neighbor_axis = neighbor_axis - self.max_neighbors = max_neighbors - self.has_skip_values = has_skip_values - self.index_type = int + item: common.AnyIndexSpec, + ) -> common.Field: + if not isinstance(item, tuple) or (isinstance(item, tuple) and not len(item) == 2): + raise NotImplementedError() # TODO(havogt): add proper slicing + index = item[0] * self._max_neighbors + item[1] # type: ignore[operator, call-overload] + return ConstantField(index) + + def as_scalar(self) -> xtyping.Never: + raise NotImplementedError() + + def __call__( + self, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, + ) -> common.Field: + raise NotImplementedError() - def mapped_index( - self, primary: common.IntIndex, neighbor_idx: common.IntIndex - ) -> common.IntIndex: - return primary * self.max_neighbors + neighbor_idx + __getitem__ = restrict # type: ignore[assignment] + + def inverse_image( + self, image_range: common.UnitRange | common.NamedRange + ) -> Sequence[common.NamedRange]: + raise NotImplementedError + + @property + def skip_value( + self, + ) -> None: + return None # Offsets @@ -597,10 +633,11 @@ def execute_shift( new_entry[i] = 0 else: offset_implementation = offset_provider[tag] - assert isinstance(offset_implementation, common.Connectivity) - cur_index = pos[offset_implementation.origin_axis.value] + assert common.is_neighbor_connectivity(offset_implementation) + source_dim = offset_implementation.__gt_type__().source_dim + cur_index = pos[source_dim.value] assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ + if offset_implementation[cur_index, index].as_scalar() in [ None, common._DEFAULT_SKIP_VALUE, ]: @@ -620,22 +657,22 @@ def execute_shift( else: raise AssertionError() return new_pos - else: - assert isinstance(offset_implementation, common.Connectivity) - assert offset_implementation.origin_axis.value in pos + elif common.is_neighbor_connectivity(offset_implementation): + source_dim = offset_implementation.__gt_type__().source_dim + assert source_dim.value in pos new_pos = pos.copy() - new_pos.pop(offset_implementation.origin_axis.value) - cur_index = pos[offset_implementation.origin_axis.value] + new_pos.pop(source_dim.value) + cur_index = pos[source_dim.value] assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ + if offset_implementation[cur_index, index].as_scalar() in [ None, common._DEFAULT_SKIP_VALUE, ]: return None else: - new_index = offset_implementation.mapped_index(cur_index, index) + new_index = offset_implementation[cur_index, index].as_scalar() assert new_index is not None - new_pos[offset_implementation.neighbor_axis.value] = int(new_index) + new_pos[offset_implementation.codomain.value] = int(new_index) return new_pos @@ -1196,8 +1233,8 @@ def as_scalar(self) -> core_defs.IntegralScalar: def premap( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1322,8 +1359,8 @@ def asnumpy(self) -> np.ndarray: def premap( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1428,10 +1465,12 @@ def __gt_type__(self) -> itir_ts.ListType: assert isinstance(offset_tag, str) element_type = type_translation.from_value(self.values[0]) assert isinstance(element_type, ts.DataType) - return itir_ts.ListType( - element_type=element_type, - offset_type=common.Dimension(value=offset_tag, kind=common.DimensionKind.LOCAL), - ) + offset_provider = embedded_context.offset_provider.get() + assert offset_provider is not None + connectivity = offset_provider[offset_tag] + assert common.is_neighbor_connectivity(connectivity) + local_dim = connectivity.__gt_type__().neighbor_dim + return itir_ts.ListType(element_type=element_type, offset_type=local_dim) @dataclasses.dataclass(frozen=True) @@ -1457,11 +1496,11 @@ def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[offset_str] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return _List( values=tuple( shifted.deref() - for i in range(connectivity.max_neighbors) + for i in range(connectivity.__gt_type__().max_neighbors) if (shifted := it.shift(offset_str, i)).can_deref() ), offset=offset, @@ -1533,11 +1572,11 @@ def deref(self) -> Any: offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[self.list_offset] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return _List( values=tuple( shifted.deref() - for i in range(connectivity.max_neighbors) + for i in range(connectivity.__gt_type__().max_neighbors) if ( shifted := self.it.shift(*self.offsets, SparseTag(self.list_offset), i) ).can_deref() @@ -1671,9 +1710,9 @@ def _dimension_to_tag(domain: Domain) -> dict[Tag, range]: return {k.value if isinstance(k, common.Dimension) else k: v for k, v in domain.items()} -def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: +def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProviderType) -> None: if isinstance(domain, runtime.CartesianDomain): - if any(isinstance(o, common.Connectivity) for o in offset_provider.values()): + if any(isinstance(o, common.ConnectivityType) for o in offset_provider_type.values()): raise RuntimeError( "Got a 'CartesianDomain', but found a 'Connectivity' in 'offset_provider', expected 'UnstructuredDomain'." ) @@ -1770,10 +1809,10 @@ def _fieldspec_list_to_value( offset_type = type_.offset_type assert isinstance(offset_type, common.Dimension) connectivity = offset_provider[offset_type.value] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return domain.insert( len(domain), - common.named_range((offset_type, connectivity.max_neighbors)), + common.named_range((offset_type, connectivity.__gt_type__().max_neighbors)), ), type_.element_type return domain, type_ @@ -1809,7 +1848,7 @@ def closure( ) -> None: assert embedded_context.within_valid_context() offset_provider = embedded_context.offset_provider.get() - _validate_domain(domain_, offset_provider) + _validate_domain(domain_, common.offset_provider_to_type(offset_provider)) domain: dict[Tag, range] = _dimension_to_tag(domain_) if not (isinstance(out, common.Field) or is_tuple_of_field(out)): raise TypeError("'Out' needs to be a located field.") diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 8f842e1c13..f5625b509c 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -12,7 +12,6 @@ import functools from typing import Any, Literal, Mapping, Optional -import gt4py.next as gtx from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -23,20 +22,19 @@ def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> di """ Extract horizontal domain sizes from an `offset_provider`. - Considers the shape of the neighbor table to get the size of each `origin_axis` and the maximum - value inside the neighbor table to get the size of each `neighbor_axis`. + Considers the shape of the neighbor table to get the size of each `source_dim` and the maximum + value inside the neighbor table to get the size of each `codomain`. """ sizes = dict[str, int]() for provider in offset_provider.values(): - if isinstance(provider, gtx.NeighborTableOffsetProvider): - assert provider.origin_axis.kind == gtx.DimensionKind.HORIZONTAL - assert provider.neighbor_axis.kind == gtx.DimensionKind.HORIZONTAL - sizes[provider.origin_axis.value] = max( - sizes.get(provider.origin_axis.value, 0), provider.table.shape[0] + if common.is_neighbor_connectivity(provider): + conn_type = provider.__gt_type__() + sizes[conn_type.source_dim.value] = max( + sizes.get(conn_type.source_dim.value, 0), provider.ndarray.shape[0] ) - sizes[provider.neighbor_axis.value] = max( - sizes.get(provider.neighbor_axis.value, 0), - provider.table.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject + sizes[conn_type.codomain.value] = max( + sizes.get(conn_type.codomain.value, 0), + provider.ndarray.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject ) return sizes @@ -114,7 +112,7 @@ def translate( new_ranges[current_dim] = SymbolicRange.translate( self.ranges[current_dim], val.value ) - elif isinstance(nbt_provider, common.Connectivity): + elif common.is_neighbor_connectivity(nbt_provider): # unstructured shift assert ( isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) @@ -132,8 +130,8 @@ def translate( for k, v in _max_domain_sizes_by_location_type(offset_provider).items() } - old_dim = nbt_provider.origin_axis - new_dim = nbt_provider.neighbor_axis + old_dim = nbt_provider.__gt_type__().source_dim + new_dim = nbt_provider.__gt_type__().codomain assert new_dim not in new_ranges or old_dim == new_dim diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index ad85d154cb..d42f961202 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -12,7 +12,7 @@ import functools import types from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import devtools @@ -127,7 +127,9 @@ def fendef( ) -def _deduce_domain(domain: dict[common.Dimension, range], offset_provider: dict[str, Any]): +def _deduce_domain( + domain: dict[common.Dimension, range], offset_provider_type: common.OffsetProviderType +): if isinstance(domain, UnstructuredDomain): domain_builtin = builtins.unstructured_domain elif isinstance(domain, CartesianDomain): @@ -135,7 +137,7 @@ def _deduce_domain(domain: dict[common.Dimension, range], offset_provider: dict[ else: domain_builtin = ( builtins.unstructured_domain - if any(isinstance(o, common.Connectivity) for o in offset_provider.values()) + if any(isinstance(o, common.ConnectivityType) for o in offset_provider_type.values()) else builtins.cartesian_domain ) @@ -160,7 +162,7 @@ def impl(out, *inps): elif isinstance(dom, dict): # if passed as a dict, we need to convert back to builtins for interpretation by the backends assert offset_provider is not None - dom = _deduce_domain(dom, offset_provider) + dom = _deduce_domain(dom, common.offset_provider_to_type(offset_provider)) closure(dom, self.fundef_dispatcher, out, [*inps]) return impl diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index f84714e779..e71a24127f 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -105,7 +105,7 @@ def apply( *, ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, - offset_provider: Optional[common.OffsetProvider] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, within_stencil: Optional[bool] = None, # manually passing flags is mostly for allowing separate testing of the modes flags: Optional[Flag] = None, @@ -126,7 +126,7 @@ def apply( `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` """ flags = flags or cls.flags - offset_provider = offset_provider or {} + offset_provider_type = offset_provider_type or {} if isinstance(node, (ir.Program, ir.FencilDefinition)): within_stencil = False @@ -138,7 +138,7 @@ def apply( if not ignore_tuple_size: node = itir_type_inference.infer( node, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, allow_undeclared_symbols=allow_undeclared_symbols, ) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 38ea1fd53d..824adfdd8d 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -411,7 +411,7 @@ def apply( cls, node: ProgramOrExpr, within_stencil: bool | None = None, - offset_provider: common.OffsetProvider | None = None, + offset_provider_type: common.OffsetProviderType | None = None, ) -> ProgramOrExpr: is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) if is_program: @@ -422,9 +422,9 @@ def apply( within_stencil is not None ), "The expression's context must be specified using `within_stencil`." - offset_provider = offset_provider or {} + offset_provider_type = offset_provider_type or {} node = itir_type_inference.infer( - node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program + node, offset_provider_type=offset_provider_type, allow_undeclared_symbols=not is_program ) return cls().visit(node, within_stencil=within_stencil) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index da238733da..9076bf2d3f 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -11,6 +11,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import ( @@ -89,7 +90,7 @@ class FuseAsFieldOp(eve.NodeTranslator): ) >>> print( ... FuseAsFieldOp.apply( - ... nested_as_fieldop, offset_provider={}, allow_undeclared_symbols=True + ... nested_as_fieldop, offset_provider_type={}, allow_undeclared_symbols=True ... ) ... ) as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) @@ -134,12 +135,14 @@ def apply( cls, node: itir.Program, *, - offset_provider, + offset_provider_type: common.OffsetProviderType, uids: Optional[eve_utils.UIDGenerator] = None, allow_undeclared_symbols=False, ): node = type_inference.infer( - node, offset_provider=offset_provider, allow_undeclared_symbols=allow_undeclared_symbols + node, + offset_provider_type=offset_provider_type, + allow_undeclared_symbols=allow_undeclared_symbols, ) if not uids: diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 90f8a6cded..a6d39883e3 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -187,7 +187,9 @@ def create_global_tmps( arguments into temporaries. """ program = infer_domain.infer_program(program, offset_provider) - program = type_inference.infer(program, offset_provider=offset_provider) + program = type_inference.infer( + program, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) if not uids: uids = eve_utils.UIDGenerator(prefix="__tmp") diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py index c6e2c38b90..87b576d14d 100644 --- a/src/gt4py/next/iterator/transforms/inline_scalar.py +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -17,8 +17,8 @@ class InlineScalar(eve.NodeTranslator): @classmethod - def apply(cls, program: itir.Program, offset_provider: common.OffsetProvider): - program = itir_inference.infer(program, offset_provider=offset_provider) + def apply(cls, program: itir.Program, offset_provider_type: common.OffsetProviderType): + program = itir_inference.infer(program, offset_provider_type=offset_provider_type) return cls().visit(program) def visit_Expr(self, node: itir.Expr): diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 52a452155a..ec6f89685a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -43,8 +43,8 @@ def __call__( def apply_common_transforms( ir: itir.Program | itir.FencilDefinition, *, + offset_provider=None, # TODO(havogt): should be replaced by offset_provider_type, but global_tmps currently relies on runtime info extract_temporaries=False, - offset_provider=None, unroll_reduce=False, common_subexpression_elimination=True, force_inline_lambda_args=False, @@ -56,7 +56,12 @@ def apply_common_transforms( #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, ) -> itir.Program: + # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps + if offset_provider_type is None: + offset_provider_type = common.offset_provider_to_type(offset_provider) + # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this if isinstance(ir, itir.FencilDefinition): ir = fencil_to_program.FencilToProgram.apply(ir) @@ -75,7 +80,7 @@ def apply_common_transforms( # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) - ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( ir, # type: ignore[arg-type] # always an itir.Program offset_provider=offset_provider, @@ -89,15 +94,15 @@ def apply_common_transforms( inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply(inlined, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program - inlined = InlineScalar.apply(inlined, offset_provider=offset_provider) + inlined = CollapseTuple.apply(inlined, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type) # This pass is required to run after CollapseTuple as otherwise we can not inline # expressions like `tuple_get(make_tuple(as_fieldop(stencil)(...)))` where stencil returns # a list. Such expressions must be inlined however because no backend supports such # field operators right now. inlined = fuse_as_fieldop.FuseAsFieldOp.apply( - inlined, uids=mergeasfop_uids, offset_provider=offset_provider + inlined, uids=mergeasfop_uids, offset_provider_type=offset_provider_type ) if inlined == ir: @@ -108,19 +113,21 @@ def apply_common_transforms( # breaks in test_zero_dim_tuple_arg as trivial tuple_get is not inlined if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) + ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) ir = MergeLet().visit(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True) if extract_temporaries: - ir = infer(ir, inplace=True, offset_provider=offset_provider) + ir = infer(ir, inplace=True, offset_provider_type=offset_provider_type) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. if unconditionally_collapse_tuples: - ir = CollapseTuple.apply(ir, ignore_tuple_size=True, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + ir = CollapseTuple.apply( + ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type + ) # type: ignore[assignment] # always an itir.Program ir = NormalizeShifts().visit(ir) @@ -129,7 +136,7 @@ def apply_common_transforms( if unroll_reduce: for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) if unrolled == ir: break ir = unrolled # type: ignore[assignment] # still a `itir.Program` @@ -156,6 +163,8 @@ def apply_fieldview_transforms( ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) - ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # type is still `itir.Program` + ir = CollapseTuple.apply( + ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) # type: ignore[assignment] # type is still `itir.Program` ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py index 792bb421f1..94c962e92d 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py +++ b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py @@ -10,6 +10,7 @@ from typing import Callable, Optional from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet @@ -75,8 +76,13 @@ def apply_common_transforms( Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, ) -> itir.Program: assert isinstance(ir, itir.FencilDefinition) + # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps + if offset_provider_type is None: + offset_provider_type = common.offset_provider_to_type(offset_provider) + ir = fencil_to_program.FencilToProgram().apply(ir) icdlv_uids = eve_utils.UIDGenerator() @@ -109,7 +115,7 @@ def apply_common_transforms( # is constant-folded the surrounding tuple_get calls can be removed. inlined = CollapseTuple.apply( inlined, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, # TODO(tehrengruber): disabled since it increases compile-time too much right now flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, ) @@ -134,7 +140,7 @@ def apply_common_transforms( ir = CollapseTuple.apply( ir, ignore_tuple_size=True, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, # TODO(tehrengruber): disabled since it increases compile-time too much right now flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, ) @@ -149,7 +155,7 @@ def apply_common_transforms( if unroll_reduce: for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) if unrolled == ir: break ir = unrolled @@ -164,7 +170,7 @@ def apply_common_transforms( ir = ScanEtaReduction().visit(ir) if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program + ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[type-var] # always an itir.Program ir = MergeLet().visit(ir) ir = InlineLambdas.apply( diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index ec9c3efb2b..042a86cd8e 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -64,16 +64,16 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]: def _get_connectivity( applied_reduce_node: itir.FunCall, - offset_provider: dict[str, common.Dimension | common.Connectivity], -) -> common.Connectivity: + offset_provider_type: common.OffsetProviderType, +) -> common.NeighborConnectivityType: """Return single connectivity that is compatible with the arguments of the reduce.""" if not cpm.is_applied_reduce(applied_reduce_node): raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") - connectivities: list[common.Connectivity] = [] + connectivities: list[common.NeighborConnectivityType] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): - conn = offset_provider[o] - assert isinstance(conn, common.Connectivity) + conn = offset_provider_type[o] + assert isinstance(conn, common.NeighborConnectivityType) connectivities.append(conn) if not connectivities: @@ -120,15 +120,15 @@ class UnrollReduce(PreserveLocationVisitor, NodeTranslator): uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) @classmethod - def apply(cls, node: itir.Node, **kwargs) -> itir.Node: - return cls().visit(node, **kwargs) - - def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr: - offset_provider = kwargs["offset_provider"] - assert offset_provider is not None - connectivity = _get_connectivity(node, offset_provider) - max_neighbors = connectivity.max_neighbors - has_skip_values = connectivity.has_skip_values + def apply(cls, node: itir.Node, offset_provider_type: common.OffsetProviderType) -> itir.Node: + return cls().visit(node, offset_provider_type=offset_provider_type) + + def _visit_reduce( + self, node: itir.FunCall, offset_provider_type: common.OffsetProviderType + ) -> itir.Expr: + connectivity_type = _get_connectivity(node, offset_provider_type) + max_neighbors = connectivity_type.max_neighbors + has_skip_values = connectivity_type.has_skip_values acc = itir.SymRef(id=self.uids.sequential_id(prefix="_acc")) offset = itir.SymRef(id=self.uids.sequential_id(prefix="_i")) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 66d8345b94..987eb0f308 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -155,7 +155,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): >>> square_func_type_synthesizer = type_synthesizer.TypeSynthesizer( ... type_synthesizer=lambda base: power(base, int_type) ... ) - >>> square_func_type_synthesizer(float_type, offset_provider={}) + >>> square_func_type_synthesizer(float_type, offset_provider_type={}) ScalarType(kind=, shape=None) Note that without a corresponding call the function itself can not be fully typed and as such @@ -169,7 +169,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): ... node=square_func, ... store_inferred_type_in_node=True, ... ) - >>> o_type_synthesizer(float_type, offset_provider={}) + >>> o_type_synthesizer(float_type, offset_provider_type={}) ScalarType(kind=, shape=None) >>> square_func.type == ts.FunctionType( ... pos_only_args=[float_type], pos_or_kw_args={}, kw_only_args={}, returns=float_type @@ -225,13 +225,15 @@ def on_type_ready(self, cb: Callable[[ts.TypeSpec], None]) -> None: def __call__( self, *args: type_synthesizer.TypeOrTypeSynthesizer, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> Union[ts.TypeSpec, ObservableTypeSynthesizer]: assert all( isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args ), "ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer" - return_type_or_synthesizer = self.type_synthesizer(*args, offset_provider=offset_provider) + return_type_or_synthesizer = self.type_synthesizer( + *args, offset_provider_type=offset_provider_type + ) # return type is a typing rule by itself if isinstance(return_type_or_synthesizer, type_synthesizer.TypeSynthesizer): @@ -250,18 +252,18 @@ def __call__( def _get_dimensions_from_offset_provider( - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> dict[str, common.Dimension]: dimensions: dict[str, common.Dimension] = {} - for offset_name, provider in offset_provider.items(): + for offset_name, provider in offset_provider_type.items(): dimensions[offset_name] = common.Dimension( value=offset_name, kind=common.DimensionKind.LOCAL ) if isinstance(provider, common.Dimension): dimensions[provider.value] = provider - elif isinstance(provider, common.Connectivity): - dimensions[provider.origin_axis.value] = provider.origin_axis - dimensions[provider.neighbor_axis.value] = provider.neighbor_axis + elif isinstance(provider, common.NeighborConnectivityType): + dimensions[provider.source_dim.value] = provider.source_dim + dimensions[provider.codomain.value] = provider.codomain return dimensions @@ -318,7 +320,7 @@ class ITIRTypeInference(eve.NodeTranslator): PRESERVED_ANNEX_ATTRS = ("domain",) - offset_provider: common.OffsetProvider + offset_provider_type: common.OffsetProviderType #: Mapping from a dimension name to the actual dimension instance. dimensions: dict[str, common.Dimension] #: Allow sym refs to symbols that have not been declared. Mostly used in testing. @@ -329,7 +331,7 @@ def apply( cls, node: T, *, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, inplace: bool = False, allow_undeclared_symbols: bool = False, ) -> T: @@ -340,7 +342,7 @@ def apply( node: The :class:`itir.Node` to infer the types of. Keyword Arguments: - offset_provider: Offset provider dictionary. + offset_provider_type: Offset provider dictionary. inplace: Write types directly to the given ``node`` instead of returning a copy. allow_undeclared_symbols: Allow references to symbols that don't have a corresponding declaration. This is useful for testing or inference on partially inferred sub-nodes. @@ -403,9 +405,9 @@ def apply( ) instance = cls( - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, dimensions=( - _get_dimensions_from_offset_provider(offset_provider) + _get_dimensions_from_offset_provider(offset_provider_type) | _get_dimensions_from_types( node.pre_walk_values() .if_isinstance(itir.Node) @@ -540,7 +542,7 @@ def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.Stenc for input_ in inputs ] stencil_returns = stencil_type_synthesizer( - *stencil_args, offset_provider=self.offset_provider + *stencil_args, offset_provider_type=self.offset_provider_type ) return it_ts.StencilClosureType( @@ -632,7 +634,7 @@ def visit_FunCall( fun = self.visit(node.fun, ctx=ctx) args = self.visit(node.args, ctx=ctx) - result = fun(*args, offset_provider=self.offset_provider) + result = fun(*args, offset_provider_type=self.offset_provider_type) if isinstance(result, ObservableTypeSynthesizer): assert not result.node diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 43c4465576..5be9ed7438 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -35,20 +35,20 @@ class TypeSynthesizer: - isinstance checks to determine if an object is actually (meant to be) a type synthesizer and not just any callable. - writing simple type synthesizers without cluttering the signature with the additional - offset_provider argument that is only needed by some. + offset_provider_type argument that is only needed by some. """ type_synthesizer: Callable[..., TypeOrTypeSynthesizer] def __post_init__(self): - if "offset_provider" not in inspect.signature(self.type_synthesizer).parameters: + if "offset_provider_type" not in inspect.signature(self.type_synthesizer).parameters: synthesizer = self.type_synthesizer - self.type_synthesizer = lambda *args, offset_provider: synthesizer(*args) + self.type_synthesizer = lambda *args, offset_provider_type: synthesizer(*args) def __call__( - self, *args: TypeOrTypeSynthesizer, offset_provider: common.OffsetProvider + self, *args: TypeOrTypeSynthesizer, offset_provider_type: common.OffsetProviderType ) -> TypeOrTypeSynthesizer: - return self.type_synthesizer(*args, offset_provider=offset_provider) + return self.type_synthesizer(*args, offset_provider_type=offset_provider_type) TypeOrTypeSynthesizer = Union[ts.TypeSpec, TypeSynthesizer] @@ -212,7 +212,7 @@ def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) - def lift(stencil: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def apply_lift( - *its: it_ts.IteratorType, offset_provider: common.OffsetProvider + *its: it_ts.IteratorType, offset_provider_type: common.OffsetProviderType ) -> it_ts.IteratorType: assert all(isinstance(it, it_ts.IteratorType) for it in its) stencil_args = [ @@ -224,7 +224,7 @@ def apply_lift( ) for it in its ] - stencil_return_type = stencil(*stencil_args, offset_provider=offset_provider) + stencil_return_type = stencil(*stencil_args, offset_provider_type=offset_provider_type) assert isinstance(stencil_return_type, ts.DataType) position_dims = its[0].position_dims if its else [] @@ -282,7 +282,7 @@ def as_fieldop( stencil: TypeSynthesizer, domain: Optional[it_ts.DomainType] = None, *, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> TypeSynthesizer: # In case we don't have a domain argument to `as_fieldop` we can not infer the exact result # type. In order to still allow some passes which don't need this information to run before the @@ -308,7 +308,7 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, ) assert isinstance(stencil_return, ts.DataType) return type_info.apply_to_primitive_constituents( @@ -328,8 +328,10 @@ def scan( assert isinstance(direction, ts.ScalarType) and direction.kind == ts.ScalarKind.BOOL @TypeSynthesizer - def apply_scan(*its: it_ts.IteratorType, offset_provider: common.OffsetProvider) -> ts.DataType: - result = scan_pass(init, *its, offset_provider=offset_provider) + def apply_scan( + *its: it_ts.IteratorType, offset_provider_type: common.OffsetProviderType + ) -> ts.DataType: + result = scan_pass(init, *its, offset_provider_type=offset_provider_type) assert isinstance(result, ts.DataType) return result @@ -340,12 +342,12 @@ def apply_scan(*its: it_ts.IteratorType, offset_provider: common.OffsetProvider) def map_(op: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def applied_map( - *args: it_ts.ListType, offset_provider: common.OffsetProvider + *args: it_ts.ListType, offset_provider_type: common.OffsetProviderType ) -> it_ts.ListType: assert len(args) > 0 assert all(isinstance(arg, it_ts.ListType) for arg in args) arg_el_types = [arg.element_type for arg in args] - el_type = op(*arg_el_types, offset_provider=offset_provider) + el_type = op(*arg_el_types, offset_provider_type=offset_provider_type) assert isinstance(el_type, ts.DataType) return it_ts.ListType(element_type=el_type) @@ -355,15 +357,17 @@ def applied_map( @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @TypeSynthesizer - def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider): + def applied_reduce(*args: it_ts.ListType, offset_provider_type: common.OffsetProviderType): assert all(isinstance(arg, it_ts.ListType) for arg in args) - return op(init, *(arg.element_type for arg in args), offset_provider=offset_provider) + return op( + init, *(arg.element_type for arg in args), offset_provider_type=offset_provider_type + ) return applied_reduce @_register_builtin_type_synthesizer -def shift(*offset_literals, offset_provider: common.OffsetProvider) -> TypeSynthesizer: +def shift(*offset_literals, offset_provider_type: common.OffsetProviderType) -> TypeSynthesizer: @TypeSynthesizer def apply_shift( it: it_ts.IteratorType | ts.DeferredType, @@ -379,19 +383,19 @@ def apply_shift( assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( offset_axis.value, common.Dimension ) - provider = offset_provider[offset_axis.value.value] # TODO: naming - if isinstance(provider, common.Dimension): + type_ = offset_provider_type[offset_axis.value.value] + if isinstance(type_, common.Dimension): pass - elif isinstance(provider, common.Connectivity): + elif isinstance(type_, common.NeighborConnectivityType): found = False for i, dim in enumerate(new_position_dims): - if dim.value == provider.origin_axis.value: + if dim.value == type_.source_dim.value: assert not found - new_position_dims[i] = provider.neighbor_axis + new_position_dims[i] = type_.codomain found = True assert found else: - raise NotImplementedError() + raise NotImplementedError(f"{type_} is not a supported Connectivity type.") return it_ts.IteratorType( position_dims=new_position_dims, defined_dims=it.defined_dims, diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 802ad2155f..69d8985beb 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -26,7 +26,6 @@ import typing from typing import Any, Iterable, Iterator, Optional -import numpy as np from typing_extensions import Self from gt4py.next import common @@ -49,47 +48,19 @@ def from_signature(cls, *args: Any, **kwargs: Any) -> Self: return cls(args=args, kwargs=kwargs) -@dataclasses.dataclass(frozen=True) -class CompileTimeConnectivity(common.Connectivity): - """Compile-time standin for a GTX connectivity, retaining everything except the connectivity tables.""" - - max_neighbors: int - has_skip_values: bool - origin_axis: common.Dimension - neighbor_axis: common.Dimension - index_type: type[int] | type[np.int32] | type[np.int64] - - def mapped_index( - self, cur_index: int | np.integer, neigh_index: int | np.integer - ) -> Optional[int | np.integer]: - raise NotImplementedError( - "A CompileTimeConnectivity instance should not call `mapped_index`." - ) - - @classmethod - def from_connectivity(cls, connectivity: common.Connectivity) -> Self: - return cls( - max_neighbors=connectivity.max_neighbors, - has_skip_values=connectivity.has_skip_values, - origin_axis=connectivity.origin_axis, - neighbor_axis=connectivity.neighbor_axis, - index_type=connectivity.index_type, - ) - - @property - def table(self) -> None: - return None - - @dataclasses.dataclass(frozen=True) class CompileTimeArgs: """Compile-time standins for arguments to a GTX program to be used in ahead-of-time compilation.""" args: tuple[ts.TypeSpec, ...] kwargs: dict[str, ts.TypeSpec] - offset_provider: dict[str, common.Connectivity | common.Dimension] + offset_provider: common.OffsetProvider # TODO(havogt): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information column_axis: Optional[common.Dimension] + @property + def offset_provider_type(self) -> common.OffsetProviderType: + return common.offset_provider_to_type(self.offset_provider) + @classmethod def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: """Convert concrete GTX program arguments into their compile-time counterparts.""" @@ -98,8 +69,7 @@ def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: offset_provider = kwargs_copy.pop("offset_provider", {}) return cls( args=compile_args, - offset_provider=offset_provider, # TODO(ricoh): replace with the line below once the temporaries pass is AOT-ready. If unsure, just try it and run the tests. - # offset_provider={k: connectivity_or_dimension(v) for k, v in offset_provider.items()}, # noqa: ERA001 [commented-out-code] + offset_provider=offset_provider, column_axis=kwargs_copy.pop("column_axis", None), kwargs={ k: type_translation.from_value(v) for k, v in kwargs_copy.items() if v is not None @@ -138,18 +108,6 @@ def adapted_jit_to_aot_args_factory() -> ( return toolchain.ArgsOnlyAdapter(jit_to_aot_args) -def connectivity_or_dimension( - some_offset_provider: common.Connectivity | common.Dimension, -) -> CompileTimeConnectivity | common.Dimension: - match some_offset_provider: - case common.Dimension(): - return some_offset_provider - case common.Connectivity(): - return CompileTimeConnectivity.from_connectivity(some_offset_provider) - case _: - raise ValueError - - def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: for element in tuple_arg: match element: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index cc57c137bf..b2aea05641 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -12,7 +12,6 @@ import gt4py.eve as eve from gt4py.eve import NodeTranslator, concepts from gt4py.eve.utils import UIDGenerator -from gt4py.next import common from gt4py.next.program_processors.codegens.gtfn import gtfn_ir, gtfn_ir_common from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ( AssignStmt, @@ -84,54 +83,9 @@ def _is_reduce(node: gtfn_ir.FunCall) -> TypeGuard[gtfn_ir.FunCall]: ) -def _get_connectivity( - applied_reduce_node: gtfn_ir.FunCall, - offset_provider: dict[str, common.Dimension | common.Connectivity], -) -> common.Connectivity: - """Return single connectivity that is compatible with the arguments of the reduce.""" - if not _is_reduce(applied_reduce_node): - raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") - - connectivities: list[common.Connectivity] = [] - for o in _get_partial_offset_tags(applied_reduce_node.args): - conn = offset_provider[o] - assert isinstance(conn, common.Connectivity) - connectivities.append(conn) - - if not connectivities: - raise RuntimeError("Couldn't detect partial shift in any arguments of 'reduce'.") - - if len({(c.max_neighbors, c.has_skip_values) for c in connectivities}) != 1: - # The condition for this check is required but not sufficient: the actual neighbor tables could still be incompatible. - raise RuntimeError("Arguments to 'reduce' have incompatible partial shifts.") - return connectivities[0] - - # TODO: end of code clone -def _make_dense_acess( - shift_call: gtfn_ir.FunCall, nbh_iter: gtfn_ir_common.SymRef -) -> gtfn_ir.FunCall: - return gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="deref"), - args=[ - gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="shift"), args=[*shift_call.args, nbh_iter] - ) - ], - ) - - -def _make_sparse_acess( - field_ref: gtfn_ir_common.SymRef, nbh_iter: gtfn_ir_common.SymRef -) -> gtfn_ir.FunCall: - return gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="tuple_get"), - args=[nbh_iter, gtfn_ir.FunCall(fun=gtfn_ir_common.SymRef(id="deref"), args=[field_ref])], - ) - - class PlugInCurrentIdx(NodeTranslator): def visit_SymRef( self, node: gtfn_ir_common.SymRef @@ -225,32 +179,6 @@ def _expand_symref( ) self.imp_list_ir.append(AssignStmt(lhs=gtfn_ir_common.SymRef(id=red_idx), rhs=rhs)) - def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.SymRef: - offset_provider = kwargs["offset_provider"] - assert offset_provider is not None - - connectivity = _get_connectivity(node, offset_provider) - - args = node.args - # do the following transformations to the node arguments - # dense fields: shift(dense_f, X2Y) -> deref(shift(dense_f, X2Y, nbh_iterator) - # sparse_fields: sparse_f -> tuple_get(nbh_iterator, deref(sparse_f))) - new_args = [] - nbh_iter = gtfn_ir_common.SymRef(id="nbh_iter") - for arg in args: - if isinstance(arg, gtfn_ir.FunCall) and arg.fun.id == "shift": # type: ignore - new_args.append(_make_dense_acess(arg, nbh_iter)) - if isinstance(arg, gtfn_ir_common.SymRef): - new_args.append(_make_sparse_acess(arg, nbh_iter)) - - red_idx = self.uids.sequential_id(prefix="red") - if isinstance(node.fun.args[0], gtfn_ir.Lambda): # type: ignore - self._expand_lambda(node, new_args, red_idx, connectivity.max_neighbors, **kwargs) - elif isinstance(node.fun.args[0], gtfn_ir_common.SymRef): # type: ignore - self._expand_symref(node, new_args, red_idx, connectivity.max_neighbors, **kwargs) - - return gtfn_ir_common.SymRef(id=red_idx) - def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.Expr: if any(isinstance(arg, gtfn_ir.Lambda) for arg in node.args): # do not try to lower constructs that take lambdas as argument to something more readable @@ -278,7 +206,9 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common. self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{lam_idx}"), rhs=expr)) return gtfn_ir_common.SymRef(id=f"{lam_idx}") if _is_reduce(node): - return self.handle_Reduction(node, **kwargs) + raise AssertionError( + "Not implemented. The code-path was removed as it was not actively used and tested." + ) if isinstance(node.fun, gtfn_ir_common.SymRef) and node.fun.id == "make_tuple": tupl_id = self.uids.sequential_id(prefix="tupl") tuple_fun = self.commit_args(node, tupl_id, "make_tuple", **kwargs) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index ce459f7970..f1649112a7 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -82,7 +82,7 @@ def _process_regular_arguments( self, program: itir.FencilDefinition | itir.Program, arg_types: tuple[ts.TypeSpec, ...], - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] @@ -104,22 +104,22 @@ def _process_regular_arguments( ): # translate sparse dimensions to tuple dtype dim_name = dim.value - connectivity = offset_provider[dim_name] - assert isinstance(connectivity, common.Connectivity) + connectivity = offset_provider_type[dim_name] + assert isinstance(connectivity, common.NeighborConnectivityType) size = connectivity.max_neighbors arg = f"gridtools::sid::dimension_to_tuple_like({arg})" arg_exprs.append(arg) return parameters, arg_exprs def _process_connectivity_args( - self, offset_provider: dict[str, common.Connectivity | common.Dimension] + self, offset_provider_type: common.OffsetProviderType ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] - for name, connectivity in offset_provider.items(): - if isinstance(connectivity, common.Connectivity): - if connectivity.index_type not in [np.int32, np.int64]: + for name, connectivity_type in offset_provider_type.items(): + if isinstance(connectivity_type, common.NeighborConnectivityType): + if connectivity_type.dtype.scalar_type not in [np.int32, np.int64]: raise ValueError( "Neighbor table indices must be of type 'np.int32' or 'np.int64'." ) @@ -129,15 +129,8 @@ def _process_connectivity_args( interface.Parameter( name=GENERATED_CONNECTIVITY_PARAM_PREFIX + name.lower(), type_=ts.FieldType( - dims=[ - connectivity.origin_axis, - common.Dimension( - name, kind=common.DimensionKind.LOCAL - ), # TODO(havogt): we should not use the name of the offset as the name of the local dimension - ], - dtype=ts.ScalarType( - type_translation.get_scalar_kind(connectivity.index_type) - ), + dims=list(connectivity_type.domain), + dtype=type_translation.from_dtype(connectivity_type.dtype), ), ) ) @@ -145,19 +138,19 @@ def _process_connectivity_args( # connectivity argument expression nbtbl = ( f"gridtools::fn::sid_neighbor_table::as_neighbor_table<" - f"generated::{connectivity.origin_axis.value}_t, " - f"generated::{name}_t, {connectivity.max_neighbors}" + f"generated::{connectivity_type.source_dim.value}_t, " + f"generated::{name}_t, {connectivity_type.max_neighbors}" f">(std::forward({GENERATED_CONNECTIVITY_PARAM_PREFIX}{name.lower()}))" ) arg_exprs.append( f"gridtools::hymap::keys::make_values({nbtbl})" ) - elif isinstance(connectivity, common.Dimension): + elif isinstance(connectivity_type, common.Dimension): pass else: raise AssertionError( - f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " - f"got '{type(connectivity).__name__}'." + f"Expected offset provider type '{name}' to be a 'NeighborConnectivityType' or 'Dimension', " + f"got '{type(connectivity_type).__name__}'." ) return parameters, arg_exprs @@ -165,7 +158,7 @@ def _process_connectivity_args( def _preprocess_program( self, program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, ) -> itir.Program: apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, @@ -194,7 +187,7 @@ def _preprocess_program( def generate_stencil_source( self, program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], ) -> str: if self.enable_itir_transforms: @@ -204,7 +197,9 @@ def generate_stencil_source( new_program = program gtfn_ir = GTFN_lowering.apply( - new_program, offset_provider=offset_provider, column_axis=column_axis + new_program, + offset_provider_type=common.offset_provider_to_type(offset_provider), + column_axis=column_axis, ) if self.use_imperative_backend: @@ -224,13 +219,13 @@ def __call__( # handle regular parameters and arguments of the program (i.e. what the user defined in # the program) regular_parameters, regular_args_expr = self._process_regular_arguments( - program, inp.args.args, inp.args.offset_provider + program, inp.args.args, inp.args.offset_provider_type ) # handle connectivity parameters and arguments (i.e. what the user provided in the offset # provider) connectivity_parameters, connectivity_args_expr = self._process_connectivity_args( - inp.args.offset_provider + inp.args.offset_provider_type ) # combine into a format that is aligned with what the backend expects diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index bc2bd645e8..129d81d6f9 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -159,7 +159,7 @@ def _collect_dimensions_from_domain( def _collect_offset_definitions( node: itir.Node, grid_type: common.GridType, - offset_provider: dict[str, common.Dimension | common.Connectivity], + offset_provider_type: common.OffsetProviderType, ) -> dict[str, TagDefinition]: used_offset_tags: set[itir.OffsetLiteral] = ( node.walk_values() @@ -167,13 +167,13 @@ def _collect_offset_definitions( .filter(lambda offset_literal: isinstance(offset_literal.value, str)) .getattr("value") ).to_set() - if not used_offset_tags.issubset(set(offset_provider.keys())): + if not used_offset_tags.issubset(set(offset_provider_type.keys())): raise AssertionError("ITIR contains an offset tag without a corresponding offset provider.") offset_definitions = {} - for offset_name, dim_or_connectivity in offset_provider.items(): - if isinstance(dim_or_connectivity, common.Dimension): - dim: common.Dimension = dim_or_connectivity + for offset_name, dim_or_connectivity_type in offset_provider_type.items(): + if isinstance(dim_or_connectivity_type, common.Dimension): + dim: common.Dimension = dim_or_connectivity_type if grid_type == common.GridType.CARTESIAN: # create alias from offset to dimension offset_definitions[dim.value] = TagDefinition(name=Sym(id=dim.value)) @@ -201,12 +201,13 @@ def _collect_offset_definitions( offset_definitions[offset_name] = TagDefinition( name=Sym(id=offset_name), alias=SymRef(id=dim.value) ) - elif isinstance(dim_or_connectivity, common.Connectivity): + elif isinstance( + connectivity_type := dim_or_connectivity_type, common.NeighborConnectivityType + ): assert grid_type == common.GridType.UNSTRUCTURED offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) - connectivity: common.Connectivity = dim_or_connectivity - for dim in [connectivity.origin_axis, connectivity.neighbor_axis]: + for dim in [connectivity_type.source_dim, connectivity_type.codomain]: if dim.kind != common.DimensionKind.HORIZONTAL: raise NotImplementedError() offset_definitions[dim.value] = TagDefinition( @@ -323,7 +324,7 @@ class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): } _unary_op_map: ClassVar[dict[str, str]] = {"not_": "!"} - offset_provider: dict + offset_provider_type: common.OffsetProviderType column_axis: Optional[common.Dimension] grid_type: common.GridType @@ -338,18 +339,18 @@ def apply( cls, node: itir.Program, *, - offset_provider: dict, + offset_provider_type: common.OffsetProviderType, column_axis: Optional[common.Dimension], ) -> Program: if not isinstance(node, itir.Program): raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") - node = itir_type_inference.infer(node, offset_provider=offset_provider) + node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) grid_type = _get_gridtype(node.body) if grid_type == common.GridType.UNSTRUCTURED: node = _CannonicalizeUnstructuredDomain.apply(node) return cls( - offset_provider=offset_provider, column_axis=column_axis, grid_type=grid_type + offset_provider_type=offset_provider_type, column_axis=column_axis, grid_type=grid_type ).visit(node) def visit_Sym(self, node: itir.Sym, **kwargs: Any) -> Sym: @@ -484,8 +485,8 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: if "stencil" in kwargs: shift_offsets = self._collect_offset_or_axis_node(itir.OffsetLiteral, kwargs["stencil"]) for o in shift_offsets: - if o in self.offset_provider and isinstance( - self.offset_provider[o], common.Connectivity + if o in self.offset_provider_type and isinstance( + self.offset_provider_type[o], common.NeighborConnectivityType ): connectivities.append(SymRef(id=o)) return UnstructuredDomain( @@ -679,7 +680,7 @@ def visit_Program(self, node: itir.Program, **kwargs: Any) -> Program: function_definitions = self.visit(node.function_definitions) + extracted_functions offset_definitions = { **_collect_dimensions_from_domain(node.body), - **_collect_offset_definitions(node, self.grid_type, self.offset_provider), + **_collect_offset_definitions(node, self.grid_type, self.offset_provider_type), } return Program( id=SymbolName(node.id), diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index db0df7d121..56ba08015b 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -12,6 +12,7 @@ import dace import numpy as np +from gt4py._core import definitions as core_defs from gt4py.next import common as gtx_common, utils as gtx_utils from . import utility as dace_utils @@ -65,8 +66,8 @@ def _get_args( def _ensure_is_on_device( - connectivity_arg: np.typing.NDArray, device: dace.dtypes.DeviceType -) -> np.typing.NDArray: + connectivity_arg: core_defs.NDArrayObject, device: dace.dtypes.DeviceType +) -> core_defs.NDArrayObject: if device == dace.dtypes.DeviceType.GPU: if not isinstance(connectivity_arg, cp.ndarray): warnings.warn( @@ -78,7 +79,7 @@ def _ensure_is_on_device( def _get_shape_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] + arrays: Mapping[str, dace.data.Array], args: Mapping[str, core_defs.NDArrayObject] ) -> dict[str, int]: shape_args: dict[str, int] = {} for name, value in args.items(): @@ -103,7 +104,7 @@ def _get_shape_args( def _get_stride_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] + arrays: Mapping[str, dace.data.Array], args: Mapping[str, core_defs.NDArrayObject] ) -> dict[str, int]: stride_args = {} for name, value in args.items(): @@ -134,7 +135,7 @@ def get_sdfg_conn_args( sdfg: dace.SDFG, offset_provider: gtx_common.OffsetProvider, on_gpu: bool, -) -> dict[str, np.typing.NDArray]: +) -> dict[str, core_defs.NDArrayObject]: """ Extracts the connectivity tables that are used in the sdfg and ensures that the memory buffers are allocated for the target device. @@ -142,11 +143,11 @@ def get_sdfg_conn_args( device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU connectivity_args = {} - for offset, connectivity in dace_utils.filter_connectivities(offset_provider).items(): - assert isinstance(connectivity, gtx_common.NeighborTable) - param = dace_utils.connectivity_identifier(offset) - if param in sdfg.arrays: - connectivity_args[param] = _ensure_is_on_device(connectivity.table, device) + for offset, connectivity in offset_provider.items(): + if gtx_common.is_neighbor_table(connectivity): + param = dace_utils.connectivity_identifier(offset) + if param in sdfg.arrays: + connectivity_args[param] = _ensure_is_on_device(connectivity.ndarray, device) return connectivity_args diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index bc01e2abda..29395a30c1 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -79,19 +79,18 @@ def debug_info( return default -def filter_connectivities( - offset_provider: gtx_common.OffsetProvider, -) -> dict[str, gtx_common.Connectivity]: +def filter_connectivity_types( + offset_provider_type: gtx_common.OffsetProviderType, +) -> dict[str, gtx_common.NeighborConnectivityType]: """ - Filter offset providers of type `Connectivity`. + Filter offset provider types of type `NeighborConnectivityType`. In other words, filter out the cartesian offset providers. - Returns a new dictionary containing only `Connectivity` values. """ return { - offset: table - for offset, table in offset_provider.items() - if isinstance(table, gtx_common.Connectivity) + offset: conn + for offset, conn in offset_provider_type.items() + if isinstance(conn, gtx_common.NeighborConnectivityType) } diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 73b6e2ed4c..74142dec66 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -527,14 +527,14 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.args[0], gtir.OffsetLiteral) offset = node.args[0].value assert isinstance(offset, str) - offset_provider = self.subgraph_builder.get_offset_provider(offset) - assert isinstance(offset_provider, gtx_common.Connectivity) + offset_provider = self.subgraph_builder.get_offset_provider_type(offset) + assert isinstance(offset_provider, gtx_common.NeighborConnectivityType) it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) - assert offset_provider.neighbor_axis in it.dimensions - assert offset_provider.origin_axis in it.indices - origin_index = it.indices[offset_provider.origin_axis] + assert offset_provider.codomain in it.dimensions + assert offset_provider.source_dim in it.indices + origin_index = it.indices[offset_provider.source_dim] assert isinstance(origin_index, SymbolExpr) assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) @@ -561,7 +561,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: subset=sbs.Range.from_string( ",".join( it.indices[dim].value # type: ignore[union-attr] - if dim != offset_provider.neighbor_axis + if dim != offset_provider.codomain else f"0:{size}" for dim, size in zip(it.dimensions, field_desc.shape, strict=True) ) @@ -657,7 +657,9 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: tasklet_expression = f"{output_connector} = {fun_python_code}" input_args = [self.visit(arg) for arg in node.args] - input_connectivities: dict[gtx_common.Dimension, gtx_common.Connectivity] = {} + input_connectivity_types: dict[ + gtx_common.Dimension, gtx_common.NeighborConnectivityType + ] = {} for input_arg in input_args: assert isinstance(input_arg.gt_dtype, itir_ts.ListType) assert input_arg.gt_dtype.offset_type is not None @@ -665,11 +667,11 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: if offset_type == _CONST_DIM: # this input argument is the result of `make_const_list` continue - offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) - assert isinstance(offset_provider, gtx_common.Connectivity) - input_connectivities[offset_type] = offset_provider + offset_provider_t = self.subgraph_builder.get_offset_provider_type(offset_type.value) + assert isinstance(offset_provider_t, gtx_common.NeighborConnectivityType) + input_connectivity_types[offset_type] = offset_provider_t - if len(input_connectivities) == 0: + if len(input_connectivity_types) == 0: raise ValueError(f"Missing information on local dimension for map node {node}.") # GT4Py guarantees that all connectivities used to generate lists of neighbors @@ -678,14 +680,14 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: len( set( (conn.has_skip_values, conn.max_neighbors) - for conn in input_connectivities.values() + for conn in input_connectivity_types.values() ) ) != 1 ): raise ValueError("Unexpected arguments to map expression with different neighborhood.") - offset_type, offset_provider = next(iter(input_connectivities.items())) - local_size = offset_provider.max_neighbors + offset_type, offset_provider_type = next(iter(input_connectivity_types.items())) + local_size = offset_provider_type.max_neighbors map_index = dace_gtir_utils.get_map_variable(offset_type) # The dataflow we build in this class has some loose connections on input edges. @@ -717,14 +719,14 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: result, _ = self.sdfg.add_temp_transient((local_size,), dc_dtype) result_node = self.state.add_access(result) - if offset_provider.has_skip_values: + if offset_provider_type.has_skip_values: # In case the `map_` input expressions contain skip values, we use # the connectivity-based offset provider as mask for map computation. connectivity = dace_utils.connectivity_identifier(offset_type.value) connectivity_desc = self.sdfg.arrays[connectivity] connectivity_desc.transient = False - origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) + origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) connectivity_slice = self._construct_local_view( MemletExpr( @@ -733,7 +735,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: element_type=node.type.element_type, offset_type=offset_type ), subset=sbs.Range.from_string( - f"{origin_map_index}, 0:{offset_provider.max_neighbors}" + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" ), ) ) @@ -774,7 +776,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: def _make_reduce_with_skip_values( self, input_expr: ValueExpr | MemletExpr, - offset_provider: gtx_common.Connectivity, + offset_provider_type: gtx_common.NeighborConnectivityType, reduce_init: SymbolExpr, reduce_identity: SymbolExpr, reduce_wcr: str, @@ -792,7 +794,7 @@ def _make_reduce_with_skip_values( corresponding neighbor index in the connectivity table is valid, or the identity value if the neighbor index is missing. """ - origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) + origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) assert ( isinstance(input_expr.gt_dtype, itir_ts.ListType) @@ -815,7 +817,7 @@ def _make_reduce_with_skip_values( f"Found {len(local_dim_indices)} local dimensions in reduce expression, expected one." ) local_dim_index = local_dim_indices[0] - assert desc.shape[local_dim_index] == offset_provider.max_neighbors + assert desc.shape[local_dim_index] == offset_provider_type.max_neighbors # we lower the reduction map with WCR out memlet in a nested SDFG nsdfg = dace.SDFG(name=self.unique_nsdfg_name("reduce_with_skip_values")) @@ -853,7 +855,7 @@ def _make_reduce_with_skip_values( # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does. st_reduce.add_mapped_tasklet( name="reduce_with_skip_values", - map_ranges={"i": f"0:{offset_provider.max_neighbors}"}, + map_ranges={"i": f"0:{offset_provider_type.max_neighbors}"}, inputs={ "__val": dace.Memlet(data="values", subset="i"), "__neighbor_idx": dace.Memlet(data="neighbor_indices", subset="i"), @@ -882,7 +884,7 @@ def _make_reduce_with_skip_values( ) self._add_input_data_edge( connectivity_node, - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider.max_neighbors}"), + sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"), nsdfg_node, "neighbor_indices", ) @@ -910,12 +912,17 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type - offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) - assert isinstance(offset_provider, gtx_common.Connectivity) + offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset_type.value) + assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) - if offset_provider.has_skip_values: + if offset_provider_type.has_skip_values: self._make_reduce_with_skip_values( - input_expr, offset_provider, reduce_init, reduce_identity, reduce_wcr, result_node + input_expr, + offset_provider_type, + reduce_init, + reduce_identity, + reduce_wcr, + result_node, ) else: @@ -1082,16 +1089,16 @@ def _make_dynamic_neighbor_offset( def _make_unstructured_shift( self, it: IteratorExpr, - connectivity: gtx_common.Connectivity, + connectivity: gtx_common.NeighborConnectivityType, offset_table_node: dace.nodes.AccessNode, offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - assert connectivity.neighbor_axis in it.dimensions - neighbor_dim = connectivity.neighbor_axis + assert connectivity.codomain in it.dimensions + neighbor_dim = connectivity.codomain assert neighbor_dim not in it.indices - origin_dim = connectivity.origin_axis + origin_dim = connectivity.source_dim assert origin_dim in it.indices origin_index = it.indices[origin_dim] assert isinstance(origin_index, SymbolExpr) @@ -1132,7 +1139,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: assert isinstance(offset_provider_arg, gtir.OffsetLiteral) offset = offset_provider_arg.value assert isinstance(offset, str) - offset_provider = self.subgraph_builder.get_offset_provider(offset) + offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset) # second argument should be the offset value, which could be a symbolic expression or a dynamic offset offset_expr = ( SymbolExpr(offset_value_arg.value, IndexDType) @@ -1140,8 +1147,8 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: else self.visit(offset_value_arg) ) - if isinstance(offset_provider, gtx_common.Dimension): - return self._make_cartesian_shift(it, offset_provider, offset_expr) + if isinstance(offset_provider_type, gtx_common.Dimension): + return self._make_cartesian_shift(it, offset_provider_type, offset_expr) else: # initially, the storage for the connectivity tables is created as transient; # when the tables are used, the storage is changed to non-transient, @@ -1151,7 +1158,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: offset_table_node = self.state.add_access(offset_table) return self._make_unstructured_shift( - it, offset_provider, offset_table_node, offset_expr + it, offset_provider_type, offset_table_node, offset_expr ) def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index ad8f490f12..52284edfac 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -41,7 +41,7 @@ class DataflowBuilder(Protocol): """Visitor interface to build a dataflow subgraph.""" @abc.abstractmethod - def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: ... + def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: ... @abc.abstractmethod def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: ... @@ -155,7 +155,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): from where to continue building the SDFG. """ - offset_provider: gtx_common.OffsetProvider + offset_provider_type: gtx_common.OffsetProviderType global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") @@ -164,8 +164,8 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="tlet") ) - def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: - return self.offset_provider[offset] + def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: + return self.offset_provider_type[offset] def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] @@ -195,10 +195,10 @@ def _make_array_shape_and_strides( Two lists of symbols, one for the shape and the other for the strides of the array. """ dc_dtype = gtir_builtin_translators.INDEX_DTYPE - neighbor_tables = dace_utils.filter_connectivities(self.offset_provider) + neighbor_table_types = dace_utils.filter_connectivity_types(self.offset_provider_type) shape = [ ( - neighbor_tables[dim.value].max_neighbors + neighbor_table_types[dim.value].max_neighbors if dim.kind == gtx_common.DimensionKind.LOCAL else dace.symbol(dace_utils.field_size_symbol_name(name, i), dc_dtype) ) @@ -374,13 +374,12 @@ def _add_sdfg_params( self.global_symbols[pname] = param.type # add SDFG storage for connectivity tables - for offset, offset_provider in dace_utils.filter_connectivities( - self.offset_provider + for offset, connectivity_type in dace_utils.filter_connectivity_types( + self.offset_provider_type ).items(): - scalar_kind = tt.get_scalar_kind(offset_provider.index_type) - local_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL) + scalar_type = tt.from_dtype(connectivity_type.dtype) gt_type = ts.FieldType( - [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type ) # We store all connectivity tables as transient arrays here; later, while building # the field operator expressions, we change to non-transient (i.e. allocated externally) @@ -585,7 +584,7 @@ def visit_Lambda( } # lower let-statement lambda node as a nested SDFG - lambda_translator = GTIRToSDFG(self.offset_provider, lambda_symbols) + lambda_translator = GTIRToSDFG(self.offset_provider_type, lambda_symbols) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nstate = nsdfg.add_state("lambda") @@ -630,7 +629,7 @@ def _flatten_tuples( ) connectivity_arrays = { dace_utils.connectivity_identifier(offset) - for offset in dace_utils.filter_connectivities(self.offset_provider) + for offset in dace_utils.filter_connectivity_types(self.offset_provider_type) } input_memlets = {} @@ -778,7 +777,7 @@ def visit_SymRef( def build_sdfg_from_gtir( ir: gtir.Program, - offset_provider: gtx_common.OffsetProvider, + offset_provider_type: gtx_common.OffsetProviderType, ) -> dace.SDFG: """ Receives a GTIR program and lowers it to a DaCe SDFG. @@ -788,15 +787,15 @@ def build_sdfg_from_gtir( Args: ir: The GTIR program node to be lowered to SDFG - offset_provider: The definitions of offset providers used by the program node + offset_provider_type: The definitions of offset providers used by the program node Returns: An SDFG in the DaCe canonical form (simplified) """ - ir = gtir_type_inference.infer(ir, offset_provider=offset_provider) + ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) - sdfg_genenerator = GTIRToSDFG(offset_provider) + sdfg_genenerator = GTIRToSDFG(offset_provider_type) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index aa4fd0cd3e..40d44f5ab0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -52,7 +52,9 @@ def generate_sdfg( on_gpu: bool, ) -> dace.SDFG: ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) - sdfg = gtir_sdfg.build_sdfg_from_gtir(ir, offset_provider=offset_provider) + sdfg = gtir_sdfg.build_sdfg_from_gtir( + ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) if auto_opt: gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) @@ -75,7 +77,7 @@ def __call__( sdfg = self.generate_sdfg( program, - inp.args.offset_provider, + inp.args.offset_provider, # TODO(havogt): should be offset_provider_type once the transformation don't require run-time info inp.args.column_axis, auto_opt=self.auto_optimize, on_gpu=(self.device_type == gtx_allocators.CUPY_DEVICE), diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index fc2772027e..ef09cf51cd 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -9,7 +9,7 @@ import dataclasses import warnings from collections import OrderedDict -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from dataclasses import field from inspect import currentframe, getframeinfo from pathlib import Path @@ -38,7 +38,7 @@ def preprocess_program( program: itir.FencilDefinition, - offset_provider: Mapping[str, Any], + offset_provider_type: common.OffsetProviderType, lift_mode: legacy_itir_transforms.LiftMode, symbolic_domain_sizes: Optional[dict[str, str]] = None, temporary_extraction_heuristics: Optional[ @@ -51,13 +51,13 @@ def preprocess_program( common_subexpression_elimination=False, force_inline_lambda_args=True, lift_mode=lift_mode, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, symbolic_domain_sizes=symbolic_domain_sizes, temporary_extraction_heuristics=temporary_extraction_heuristics, unroll_reduce=unroll_reduce, ) - node = itir_type_inference.infer(node, offset_provider=offset_provider) + node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) if isinstance(node, itir.Program): fencil_definition = program_to_fencil.program_to_fencil(node) @@ -72,7 +72,7 @@ def preprocess_program( def build_sdfg_from_itir( program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], - offset_provider: dict[str, Any], + offset_provider_type: common.OffsetProviderType, auto_optimize: bool = False, on_gpu: bool = False, column_axis: Optional[common.Dimension] = None, @@ -109,10 +109,18 @@ def build_sdfg_from_itir( # visit ITIR and generate SDFG program, tmps = preprocess_program( - program, offset_provider, lift_mode, symbolic_domain_sizes, temporary_extraction_heuristics + program, + offset_provider_type, + lift_mode, + symbolic_domain_sizes, + temporary_extraction_heuristics, ) sdfg_genenerator = ItirToSDFG( - list(arg_types), offset_provider, tmps, use_field_canonical_representation, column_axis + list(arg_types), + offset_provider_type, + tmps, + use_field_canonical_representation, + column_axis, ) sdfg = sdfg_genenerator.visit(program) if sdfg is None: @@ -186,14 +194,12 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: raise ValueError( "[DaCe Orchestration] Connectivities -at compile time- are required to generate the SDFG. Use `with_connectivities` method." ) - offset_provider = ( - self.connectivities | self._implicit_offset_provider - ) # tables are None at this point + offset_provider_type = {**self.connectivities, **self._implicit_offset_provider} sdfg = self.backend.executor.step.translation.generate_sdfg( # type: ignore[union-attr] self.itir, arg_types, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, column_axis=kwargs.get("column_axis", None), ) self.sdfg_closure_vars["sdfg.arrays"] = sdfg.arrays # use it in __sdfg_closure__ @@ -238,7 +244,7 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: sdfg.offset_providers_per_input_field = {} itir_tmp = legacy_itir_transforms.apply_common_transforms( - self.itir, offset_provider=offset_provider + self.itir, offset_provider_type=offset_provider_type ) itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) for closure in itir_tmp_fencil.closures: @@ -267,7 +273,7 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ the offset providers are not part of GT4Py Program's arguments. Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. """ - offset_provider = self.connectivities + offset_provider_type = self.connectivities # Define DaCe symbols connectivity_table_size_symbols = { @@ -276,9 +282,9 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ ): dace.symbol( dace_utils.field_size_symbol_name(dace_utils.connectivity_identifier(k), axis) ) - for k, v in offset_provider.items() # type: ignore[union-attr] + for k, v in offset_provider_type.items() # type: ignore[union-attr] for axis in [0, 1] - if hasattr(v, "table") + if isinstance(v, common.NeighborConnectivityType) and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] } @@ -288,9 +294,9 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ ): dace.symbol( dace_utils.field_stride_symbol_name(dace_utils.connectivity_identifier(k), axis) ) - for k, v in offset_provider.items() # type: ignore[union-attr] + for k, v in offset_provider_type.items() # type: ignore[union-attr] for axis in [0, 1] - if hasattr(v, "table") + if isinstance(v, common.NeighborConnectivityType) and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] } @@ -298,8 +304,8 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ # Define the storage location (e.g. CPU, GPU) of the connectivity tables if "storage" not in Program.connectivity_tables_data_descriptors: - for k, v in offset_provider.items(): # type: ignore[union-attr] - if not hasattr(v, "table"): + for k, v in offset_provider_type.items(): # type: ignore[union-attr] + if not isinstance(v, common.NeighborConnectivityType): continue if dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"]: Program.connectivity_tables_data_descriptors["storage"] = ( @@ -311,12 +317,15 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ # Build the closure dictionary closure_dict = {} - for k, v in offset_provider.items(): # type: ignore[union-attr] + for k, v in offset_provider_type.items(): # type: ignore[union-attr] conn_id = dace_utils.connectivity_identifier(k) - if hasattr(v, "table") and conn_id in self.sdfg_closure_vars["sdfg.arrays"]: + if ( + isinstance(v, common.NeighborConnectivityType) + and conn_id in self.sdfg_closure_vars["sdfg.arrays"] + ): if conn_id not in Program.connectivity_tables_data_descriptors: Program.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( - dtype=dace.int64 if v.index_type == np.int64 else dace.int32, + dtype=dace.int64 if v.dtype.scalar_type == np.int64 else dace.int32, shape=[ symbols[dace_utils.field_size_symbol_name(conn_id, 0)], symbols[dace_utils.field_size_symbol_name(conn_id, 1)], diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index a0f4b83d35..823943cfd5 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -7,14 +7,13 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings -from typing import Any, Mapping, Optional, Sequence, cast +from typing import Optional, Sequence, cast import dace from dace.sdfg.state import LoopRegion import gt4py.eve as eve -from gt4py.next import Dimension, DimensionKind -from gt4py.next.common import Connectivity +from gt4py.next import Dimension, DimensionKind, common from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef @@ -91,7 +90,10 @@ def _get_scan_dim( def _make_array_shape_and_strides( - name: str, dims: Sequence[Dimension], offset_provider: Mapping[str, Any], sort_dims: bool + name: str, + dims: Sequence[Dimension], + offset_provider_type: common.OffsetProviderType, + sort_dims: bool, ) -> tuple[list[dace.symbol], list[dace.symbol]]: """ Parse field dimensions and allocate symbols for array shape and strides. @@ -106,10 +108,10 @@ def _make_array_shape_and_strides( """ dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) sorted_dims = dace_utils.get_sorted_dims(dims) if sort_dims else list(enumerate(dims)) - neighbor_tables = dace_utils.filter_connectivities(offset_provider) + connectivity_types = dace_utils.filter_connectivity_types(offset_provider_type) shape = [ ( - neighbor_tables[dim.value].max_neighbors + connectivity_types[dim.value].max_neighbors if dim.kind == DimensionKind.LOCAL # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain else dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) @@ -144,21 +146,21 @@ class ItirToSDFG(eve.NodeVisitor): param_types: list[ts.TypeSpec] storage_types: dict[str, ts.TypeSpec] column_axis: Optional[Dimension] - offset_provider: dict[str, Any] + offset_provider_type: common.OffsetProviderType unique_id: int use_field_canonical_representation: bool def __init__( self, param_types: list[ts.TypeSpec], - offset_provider: dict[str, Connectivity | Dimension], + offset_provider_type: common.OffsetProviderType, tmps: list[itir.Temporary], use_field_canonical_representation: bool, column_axis: Optional[Dimension] = None, ): self.param_types = param_types self.column_axis = column_axis - self.offset_provider = offset_provider + self.offset_provider_type = offset_provider_type self.storage_types = {} self.tmps = tmps self.use_field_canonical_representation = use_field_canonical_representation @@ -166,7 +168,7 @@ def __init__( def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, sort_dimensions: bool): if isinstance(type_, ts.FieldType): shape, strides = _make_array_shape_and_strides( - name, type_.dims, self.offset_provider, sort_dimensions + name, type_.dims, self.offset_provider_type, sort_dimensions ) dtype = dace_utils.as_dace_type(type_.dtype) sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) @@ -255,7 +257,7 @@ def get_output_nodes( # Visit output node again to generate the corresponding tasklet context = Context(sdfg, state, output_symbols_pass.symbol_refs) translator = PythonTaskletCodegen( - self.offset_provider, context, self.use_field_canonical_representation + self.offset_provider_type, context, self.use_field_canonical_representation ) output_nodes = flatten_list(translator.visit(closure.output)) return {node.value.data: node.value for node in output_nodes} @@ -266,7 +268,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): entry_state = program_sdfg.add_state("program_entry", is_start_block=True) # Filter neighbor tables from offset providers. - neighbor_tables = get_used_connectivities(node, self.offset_provider) + connectivity_types = get_used_connectivities(node, self.offset_provider_type) # Add program parameters as SDFG storages. for param, type_ in zip(node.params, self.param_types): @@ -285,11 +287,10 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): last_state = entry_state # Add connectivities as SDFG storages. - for offset, offset_provider in neighbor_tables.items(): - scalar_kind = tt.get_scalar_kind(offset_provider.index_type) - local_dim = Dimension(offset, kind=DimensionKind.LOCAL) + for offset, connectivity_type in connectivity_types.items(): + scalar_type = tt.from_dtype(connectivity_type.dtype) type_ = ts.FieldType( - [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type ) self.add_storage( program_sdfg, @@ -362,7 +363,7 @@ def visit_StencilClosure( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] @@ -568,7 +569,7 @@ def _visit_scan_stencil_closure( ) assert isinstance(node.output, SymRef) - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) assert all( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls @@ -673,7 +674,7 @@ def _visit_scan_stencil_closure( connectivity_arrays = [(scan_sdfg.arrays[name], name) for name in connectivity_names] lambda_context, lambda_outputs = closure_to_tasklet_sdfg( node, - self.offset_provider, + self.offset_provider_type, lambda_domain, input_arrays, connectivity_arrays, @@ -738,7 +739,7 @@ def _visit_parallel_stencil_closure( tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... ], ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) assert all( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls @@ -762,7 +763,7 @@ def _visit_parallel_stencil_closure( context, results = closure_to_tasklet_sdfg( node, - self.offset_provider, + self.offset_provider_type, index_domain, input_arrays, connectivity_arrays, @@ -788,7 +789,7 @@ def _visit_domain( lower_bound = named_range.args[1] upper_bound = named_range.args[2] translator = PythonTaskletCodegen( - self.offset_provider, + self.offset_provider_type, context, self.use_field_canonical_representation, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 991053b4a5..2b2669187a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -19,8 +19,8 @@ import gt4py.eve.codegen from gt4py import eve -from gt4py.next import Dimension -from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value, Connectivity +from gt4py.next import common +from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import FunCall, Lambda from gt4py.next.iterator.type_system import type_specifications as it_ts @@ -187,15 +187,15 @@ def _visit_lift_in_neighbors_reduction( transformer: PythonTaskletCodegen, node: itir.FunCall, node_args: Sequence[IteratorExpr | list[ValueExpr]], - offset_provider: Connectivity, + connectivity_type: common.NeighborConnectivityType, map_entry: dace.nodes.MapEntry, map_exit: dace.nodes.MapExit, neighbor_index_node: dace.nodes.AccessNode, neighbor_value_node: dace.nodes.AccessNode, ) -> list[ValueExpr]: assert transformer.context.reduce_identity is not None - neighbor_dim = offset_provider.neighbor_axis.value - origin_dim = offset_provider.origin_axis.value + neighbor_dim = connectivity_type.codomain.value + origin_dim = connectivity_type.source_dim.value lifted_args: list[IteratorExpr | ValueExpr] = [] for arg in node_args: @@ -232,7 +232,7 @@ def _visit_lift_in_neighbors_reduction( assert isinstance(y, ValueExpr) input_nodes[x] = y.value - neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider) + neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider_type) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] @@ -294,7 +294,7 @@ def _visit_lift_in_neighbors_reduction( memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), ) - if offset_provider.has_skip_values: + if connectivity_type.has_skip_values: # check neighbor validity on if/else inter-state edge # use one branch for connectivity case start_state = lift_context.body.add_state_before( @@ -333,8 +333,8 @@ def builtin_neighbors( assert isinstance(offset_literal, itir.OffsetLiteral) offset_dim = offset_literal.value assert isinstance(offset_dim, str) - offset_provider = transformer.offset_provider[offset_dim] - if not isinstance(offset_provider, Connectivity): + connectivity_type = transformer.offset_provider_type[offset_dim] + if not isinstance(connectivity_type, common.NeighborConnectivityType): raise NotImplementedError( "Neighbor reduction only implemented for connectivity based on neighbor tables." ) @@ -351,7 +351,7 @@ def builtin_neighbors( iterator = transformer.visit(data) assert isinstance(iterator, IteratorExpr) field_desc = iterator.field.desc(transformer.context.body) - origin_index_node = iterator.indices[offset_provider.origin_axis.value] + origin_index_node = iterator.indices[connectivity_type.source_dim.value] assert transformer.context.reduce_identity is not None assert transformer.context.reduce_identity.dtype == iterator.dtype @@ -361,7 +361,7 @@ def builtin_neighbors( sdfg.add_array( neighbor_value_var, dtype=iterator.dtype, - shape=(offset_provider.max_neighbors,), + shape=(connectivity_type.max_neighbors,), transient=True, ) neighbor_value_node = state.add_access(neighbor_value_var, debuginfo=di) @@ -375,7 +375,7 @@ def builtin_neighbors( neighbor_map_index = unique_name(f"{offset_dim}_neighbor_map_idx") me, mx = state.add_map( f"{offset_dim}_neighbor_map", - ndrange={neighbor_map_index: f"0:{offset_provider.max_neighbors}"}, + ndrange={neighbor_map_index: f"0:{connectivity_type.max_neighbors}"}, debuginfo=di, ) @@ -414,7 +414,7 @@ def builtin_neighbors( transformer, lift_node, lift_args, - offset_provider, + connectivity_type, me, mx, neighbor_index_node, @@ -423,13 +423,13 @@ def builtin_neighbors( else: sorted_dims = transformer.get_sorted_field_dimensions(iterator.dimensions) data_access_index = ",".join(f"{dim}_v" for dim in sorted_dims) - connector_neighbor_dim = f"{offset_provider.neighbor_axis.value}_v" + connector_neighbor_dim = f"{connectivity_type.codomain.value}_v" data_access_tasklet = state.add_tasklet( "data_access", code=f"__data = __field[{data_access_index}] " + ( f"if {connector_neighbor_dim} != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if offset_provider.has_skip_values + if connectivity_type.has_skip_values else "" ), inputs={"__field"} | {f"{dim}_v" for dim in iterator.dimensions}, @@ -445,7 +445,7 @@ def builtin_neighbors( ) for dim in iterator.dimensions: connector = f"{dim}_v" - if dim == offset_provider.neighbor_axis.value: + if dim == connectivity_type.codomain.value: state.add_edge( neighbor_index_node, None, @@ -470,7 +470,7 @@ def builtin_neighbors( src_conn="__data", ) - if not offset_provider.has_skip_values: + if not connectivity_type.has_skip_values: return [ValueExpr(neighbor_value_node, iterator.dtype)] else: """ @@ -483,7 +483,7 @@ def builtin_neighbors( sdfg.add_array( neighbor_valid_var, dtype=dace.dtypes.bool, - shape=(offset_provider.max_neighbors,), + shape=(connectivity_type.max_neighbors,), transient=True, ) neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) @@ -572,7 +572,7 @@ def build_if_state(arg, state): symbol_map = copy.deepcopy(transformer.context.symbol_map) node_context = Context(sdfg, state, symbol_map) node_taskgen = PythonTaskletCodegen( - transformer.offset_provider, + transformer.offset_provider_type, node_context, transformer.use_field_canonical_representation, ) @@ -884,21 +884,12 @@ def visit_SymRef(self, node: itir.SymRef): ) +@dataclasses.dataclass class PythonTaskletCodegen(gt4py.eve.codegen.TemplatedGenerator): - offset_provider: dict[str, Any] + offset_provider_type: common.OffsetProviderType context: Context use_field_canonical_representation: bool - def __init__( - self, - offset_provider: dict[str, Any], - context: Context, - use_field_canonical_representation: bool, - ): - self.offset_provider = offset_provider - self.context = context - self.use_field_canonical_representation = use_field_canonical_representation - def get_sorted_field_dimensions(self, dims: Sequence[str]): return sorted(dims) if self.use_field_canonical_representation else dims @@ -914,7 +905,7 @@ def visit_Lambda( ]: func_name = f"lambda_{abs(hash(node)):x}" neighbor_tables = ( - get_used_connectivities(node, self.offset_provider) if use_neighbor_tables else {} + get_used_connectivities(node, self.offset_provider_type) if use_neighbor_tables else {} ) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() @@ -974,7 +965,7 @@ def visit_Lambda( reduce_identity=self.context.reduce_identity, ) lambda_taskgen = PythonTaskletCodegen( - self.offset_provider, + self.offset_provider_type, lambda_context, self.use_field_canonical_representation, ) @@ -1066,7 +1057,7 @@ def _visit_call(self, node: itir.FunCall): store, self.context.body.arrays[store] ) - neighbor_tables = get_used_connectivities(node.fun, self.offset_provider) + neighbor_tables = get_used_connectivities(node.fun, self.offset_provider_type) for offset in neighbor_tables.keys(): var = dace_utils.connectivity_identifier(offset) nsdfg_inputs[var] = dace.Memlet.from_array(var, self.context.body.arrays[var]) @@ -1136,12 +1127,13 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices] assert len(dims_not_indexed) == 1 offset = dims_not_indexed[0] - offset_provider = self.offset_provider[offset] - neighbor_dim = offset_provider.neighbor_axis.value + offset_provider_type = self.offset_provider_type[offset] + assert isinstance(offset_provider_type, common.NeighborConnectivityType) + neighbor_dim = offset_provider_type.codomain.value result_name = unique_var_name() self.context.body.add_array( - result_name, (offset_provider.max_neighbors,), iterator.dtype, transient=True + result_name, (offset_provider_type.max_neighbors,), iterator.dtype, transient=True ) result_array = self.context.body.arrays[result_name] result_node = self.context.state.add_access(result_name, debuginfo=di) @@ -1158,7 +1150,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: # we create a mapped tasklet for array slicing index_name = unique_name(f"_i_{neighbor_dim}") - map_ranges = {index_name: f"0:{offset_provider.max_neighbors}"} + map_ranges = {index_name: f"0:{offset_provider_type.max_neighbors}"} src_subset = ",".join( [f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims] ) @@ -1212,27 +1204,30 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: offset_node = self.visit(tail[1])[0] assert offset_node.dtype in dace.dtypes.INTEGER_TYPES - if isinstance(self.offset_provider[offset_dim], Connectivity): - offset_provider = self.offset_provider[offset_dim] + if isinstance(self.offset_provider_type[offset_dim], common.NeighborConnectivityType): + offset_provider_type = cast( + common.NeighborConnectivityType, self.offset_provider_type[offset_dim] + ) # ensured by condition connectivity = self.context.state.add_access( dace_utils.connectivity_identifier(offset_dim), debuginfo=di ) - shifted_dim = offset_provider.origin_axis.value - target_dim = offset_provider.neighbor_axis.value + shifted_dim_tag = offset_provider_type.source_dim.value + target_dim_tag = offset_provider_type.codomain.value args = [ ValueExpr(connectivity, _INDEX_DTYPE), - ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), + ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" else: - assert isinstance(self.offset_provider[offset_dim], Dimension) + shifted_dim = self.offset_provider_type[offset_dim] + assert isinstance(shifted_dim, common.Dimension) - shifted_dim = self.offset_provider[offset_dim].value - target_dim = shifted_dim - args = [ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node] + shifted_dim_tag = shifted_dim.value + target_dim_tag = shifted_dim_tag + args = [ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} + {internals[1]}" @@ -1241,8 +1236,8 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} - del shifted_index[shifted_dim] - shifted_index[target_dim] = shifted_value + del shifted_index[shifted_dim_tag] + shifted_index[target_dim_tag] = shifted_value return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) @@ -1506,7 +1501,7 @@ def is_scan(node: itir.Node) -> bool: def closure_to_tasklet_sdfg( node: itir.StencilClosure, - offset_provider: dict[str, Any], + offset_provider_type: common.OffsetProviderType, domain: dict[str, str], inputs: Sequence[tuple[str, ts.TypeSpec]], connectivities: Sequence[tuple[dace.ndarray, str]], @@ -1547,7 +1542,9 @@ def closure_to_tasklet_sdfg( body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) context = Context(body, state, symbol_map) - translator = PythonTaskletCodegen(offset_provider, context, use_field_canonical_representation) + translator = PythonTaskletCodegen( + offset_provider_type, context, use_field_canonical_representation + ) args = [itir.SymRef(id=name) for name, _ in inputs] if is_scan(node.stencil): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index d367eb0883..72bb32f003 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -7,21 +7,21 @@ # SPDX-License-Identifier: BSD-3-Clause import itertools -from typing import Any, Mapping +from typing import Any import dace import gt4py.next.iterator.ir as itir from gt4py import eve -from gt4py.next.common import Connectivity +from gt4py.next import common from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.program_processors.runners.dace_common import utility as dace_utils def get_used_connectivities( - node: itir.Node, offset_provider: Mapping[str, Any] -) -> dict[str, Connectivity]: - connectivities = dace_utils.filter_connectivities(offset_provider) + node: itir.Node, offset_provider_type: common.OffsetProviderType +) -> dict[str, common.NeighborConnectivityType]: + connectivities = dace_utils.filter_connectivity_types(offset_provider_type) offset_dims = set(eve.walk_values(node).if_isinstance(itir.OffsetLiteral).getattr("value")) return {offset: connectivities[offset] for offset in offset_dims if offset in connectivities} diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 740f1979cd..653ed4719d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -52,7 +52,7 @@ def generate_sdfg( self, program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], - offset_provider: dict[str, common.Dimension | common.Connectivity], + offset_provider_type: common.OffsetProviderType, column_axis: Optional[common.Dimension], ) -> dace.SDFG: on_gpu = ( @@ -64,7 +64,7 @@ def generate_sdfg( return build_sdfg_from_itir( program, arg_types, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, auto_optimize=self.auto_optimize, on_gpu=on_gpu, column_axis=column_axis, @@ -87,7 +87,7 @@ def __call__( sdfg = self.generate_sdfg( program, inp.args.args, - inp.args.offset_provider, + common.offset_provider_to_type(inp.args.offset_provider), inp.args.column_axis, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 965c6417b2..1f3778f227 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -12,14 +12,12 @@ import diskcache import factory -import numpy.typing as npt import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py.eve import utils from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config -from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind @@ -63,8 +61,8 @@ def decorated_program( def _ensure_is_on_device( - connectivity_arg: npt.NDArray, device: core_defs.DeviceType -) -> npt.NDArray: + connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType +) -> core_defs.NDArrayObject: if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]: import cupy as cp @@ -79,17 +77,17 @@ def _ensure_is_on_device( def extract_connectivity_args( offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType -) -> list[tuple[npt.NDArray, tuple[int, ...]]]: +) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: # note: the order here needs to agree with the order of the generated bindings - args: list[tuple[npt.NDArray, tuple[int, ...]]] = [] + args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [] for name, conn in offset_provider.items(): if isinstance(conn, common.Connectivity): - if not isinstance(conn, common.NeighborTable): + if not common.is_neighbor_table(conn): raise NotImplementedError( "Only 'NeighborTable' connectivities implemented at this point." ) # copying to device here is a fallback for easy testing and might be removed later - conn_arg = _ensure_is_on_device(conn.table, device) + conn_arg = _ensure_is_on_device(conn.ndarray, device) args.append((conn_arg, tuple([0] * 2))) elif isinstance(conn, common.Dimension): pass @@ -125,7 +123,7 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: the program, sorted offset_provider, and column_axis. """ program: itir.FencilDefinition | itir.Program = inp.data - offset_provider: dict[str, Connectivity | Dimension] = inp.args.offset_provider + offset_provider: common.OffsetProvider = inp.args.offset_provider column_axis: Optional[common.Dimension] = inp.args.column_axis program_hash = utils.content_hash( diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 4d518d7fcc..1dd568b95a 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -94,7 +94,7 @@ def fencil_generator( ir: itir.Program | itir.FencilDefinition, debug: bool, use_embedded: bool, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, transforms: itir_transforms.ITIRTransform, ) -> stages.CompiledProgram: """ @@ -111,7 +111,15 @@ def fencil_generator( """ # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism - cache_key = hash((ir, transforms, debug, use_embedded, tuple(offset_provider.items()))) + cache_key = hash( + ( + ir, + transforms, + debug, + use_embedded, + tuple(common.offset_provider_to_type(offset_provider).items()), + ) + ) if cache_key in _FENCIL_CACHE: if debug: print(f"Using cached fencil for key {cache_key}") @@ -151,7 +159,9 @@ def fencil_generator( """ ) - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as source_file: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", encoding="utf-8", delete=False + ) as source_file: source_file_name = source_file.name if debug: print(source_file_name) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 0827d99cdc..fa8c9b9ab1 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -63,6 +63,7 @@ class DimensionType(TypeSpec): @dataclass(frozen=True) class OffsetType(TypeSpec): + # TODO(havogt): replace by ConnectivityType source: func_common.Dimension target: tuple[func_common.Dimension] | tuple[func_common.Dimension, func_common.Dimension] diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 1da34db3c0..f5646c71e4 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -6,30 +6,32 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import numpy as np -from typing import Optional from types import ModuleType +from typing import Optional + +import numpy as np import pytest import gt4py.next as gtx -from gt4py.next import backend as next_backend -from gt4py.next.otf import arguments +from gt4py.next import backend as next_backend, common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + E2V, + E2VDim, + Edge, + Vertex, exec_alloc_descriptor, mesh_descriptor, - Vertex, - Edge, - E2V, ) from next_tests.integration_tests.multi_feature_tests.ffront_tests.test_laplacian import ( lap_program, - laplap_program, lap_ref, + laplap_program, ) + try: import dace from gt4py.next.program_processors.runners.dace import ( @@ -57,25 +59,20 @@ def test_sdfgConvertible_laplap(cartesian_case): in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() out_field = cases.allocate(cartesian_case, laplap_program, "out_field")() - connectivities = {} # Dict of NeighborOffsetProviders, where self.table = None - for k, v in cartesian_case.offset_provider.items(): - if hasattr(v, "table"): - connectivities[k] = arguments.CompileTimeConnectivity( - v.max_neighbors, v.has_skip_values, v.origin_axis, v.neighbor_axis, v.table.dtype - ) - else: - connectivities[k] = v - # Test DaCe closure support @dace.program def sdfg(): tmp_field = xp.empty_like(out_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( cartesian_case.backend - ).with_connectivities(connectivities)(in_field, tmp_field) + ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + in_field, tmp_field + ) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( cartesian_case.backend - ).with_connectivities(connectivities)(tmp_field, out_field) + ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + tmp_field, out_field + ) sdfg() @@ -130,13 +127,13 @@ def sdfg( a, out, offset_provider=offset_provider ) - e2v = gtx.NeighborTableOffsetProvider( - xp.asarray([[0, 1], [1, 2], [2, 0]]), Edge, Vertex, 2, False - ) - connectivities = {} - connectivities["E2V"] = arguments.CompileTimeConnectivity( - e2v.max_neighbors, e2v.has_skip_values, e2v.origin_axis, e2v.neighbor_axis, e2v.table.dtype + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[0, 1], [1, 2], [2, 0]]), + allocator=allocator, ) + connectivities = {"E2V": e2v.__gt_type__()} offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) SDFG = sdfg.to_sdfg(connectivities=connectivities) @@ -144,6 +141,9 @@ def sdfg( a = gtx.as_field([Vertex], xp.asarray([0.0, 1.0, 2.0]), allocator=allocator) out = gtx.zeros({Edge: 3}, allocator=allocator) + e2v_ndarray_copy = ( + e2v.ndarray.copy() + ) # otherwise DaCe complains about the gt4py custom allocated view # This is a low level interface to call the compiled SDFG. # It is not supposed to be used in user code. # The high level interface should be provided by a DaCe Orchestrator, @@ -155,21 +155,21 @@ def sdfg( offset_provider, rows=3, cols=2, - connectivity_E2V=e2v.table, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 0 - ), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 1 - ), + connectivity_E2V=e2v_ndarray_copy, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), ) - e2v_xp = xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table - assert np.allclose(gtx.field_utils.asnumpy(out), gtx.field_utils.asnumpy(a)[e2v_xp[:, 0]]) + e2v_np = e2v.asnumpy() + assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) - e2v = gtx.NeighborTableOffsetProvider( - xp.asarray([[1, 0], [2, 1], [0, 2]]), Edge, Vertex, 2, False + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[1, 0], [2, 1], [0, 2]]), + allocator=allocator, ) + e2v_ndarray_copy = e2v.ndarray.copy() offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) cSDFG( a, @@ -177,17 +177,13 @@ def sdfg( offset_provider, rows=3, cols=2, - connectivity_E2V=e2v.table, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 0 - ), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 1 - ), + connectivity_E2V=e2v_ndarray_copy, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), ) - e2v_xp = xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table - assert np.allclose(gtx.field_utils.asnumpy(out), gtx.field_utils.asnumpy(a)[e2v_xp[:, 0]]) + e2v_np = e2v.asnumpy() + assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) def get_stride_from_numpy_to_dace(numpy_array: np.ndarray, axis: int) -> int: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index c64efb27d2..794dd06709 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -152,7 +152,10 @@ def num_edges(self) -> int: ... def num_levels(self) -> int: ... @property - def offset_provider(self) -> dict[str, common.Connectivity]: ... + def offset_provider(self) -> common.OffsetProvider: ... + + @property + def offset_provider_type(self) -> common.OffsetProviderType: ... def simple_mesh() -> MeshDescriptor: @@ -211,25 +214,40 @@ def simple_mesh() -> MeshDescriptor: assert all(len(row) == 2 for row in e2v_arr) e2v_arr = np.asarray(e2v_arr, dtype=gtx.IndexType) + offset_provider = { + V2E.value: common._connectivity( + v2e_arr, + codomain=Edge, + domain={Vertex: v2e_arr.shape[0], V2EDim: 4}, + skip_value=None, + ), + E2V.value: common._connectivity( + e2v_arr, + codomain=Vertex, + domain={Edge: e2v_arr.shape[0], E2VDim: 2}, + skip_value=None, + ), + C2V.value: common._connectivity( + c2v_arr, + codomain=Vertex, + domain={Cell: c2v_arr.shape[0], C2VDim: 4}, + skip_value=None, + ), + C2E.value: common._connectivity( + c2e_arr, + codomain=Edge, + domain={Cell: c2e_arr.shape[0], C2EDim: 4}, + skip_value=None, + ), + } + return types.SimpleNamespace( name="simple_mesh", num_vertices=num_vertices, num_edges=np.int32(num_edges), num_cells=num_cells, - offset_provider={ - V2E.value: gtx.NeighborTableOffsetProvider( - v2e_arr, Vertex, Edge, 4, has_skip_values=False - ), - E2V.value: gtx.NeighborTableOffsetProvider( - e2v_arr, Edge, Vertex, 2, has_skip_values=False - ), - C2V.value: gtx.NeighborTableOffsetProvider( - c2v_arr, Cell, Vertex, 4, has_skip_values=False - ), - C2E.value: gtx.NeighborTableOffsetProvider( - c2e_arr, Cell, Edge, 4, has_skip_values=False - ), - }, + offset_provider=offset_provider, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) @@ -287,25 +305,40 @@ def skip_value_mesh() -> MeshDescriptor: dtype=gtx.IndexType, ) + offset_provider = { + V2E.value: common._connectivity( + v2e_arr, + codomain=Edge, + domain={Vertex: v2e_arr.shape[0], V2EDim: 5}, + skip_value=common._DEFAULT_SKIP_VALUE, + ), + E2V.value: common._connectivity( + e2v_arr, + codomain=Vertex, + domain={Edge: e2v_arr.shape[0], E2VDim: 2}, + skip_value=None, + ), + C2V.value: common._connectivity( + c2v_arr, + codomain=Vertex, + domain={Cell: c2v_arr.shape[0], C2VDim: 3}, + skip_value=None, + ), + C2E.value: common._connectivity( + c2e_arr, + codomain=Edge, + domain={Cell: c2e_arr.shape[0], C2EDim: 3}, + skip_value=None, + ), + } + return types.SimpleNamespace( name="skip_value_mesh", num_vertices=num_vertices, num_edges=num_edges, num_cells=num_cells, - offset_provider={ - V2E.value: gtx.NeighborTableOffsetProvider( - v2e_arr, Vertex, Edge, 5, has_skip_values=True - ), - E2V.value: gtx.NeighborTableOffsetProvider( - e2v_arr, Edge, Vertex, 2, has_skip_values=False - ), - C2V.value: gtx.NeighborTableOffsetProvider( - c2v_arr, Cell, Vertex, 3, has_skip_values=False - ), - C2E.value: gtx.NeighborTableOffsetProvider( - c2e_arr, Cell, Edge, 3, has_skip_values=False - ), - }, + offset_provider=offset_provider, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index a5453151e6..1a51e3667d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -89,7 +89,7 @@ def testee(a: cases.VField) -> cases.EField: cases.verify_with_default_data( unstructured_case, testee, - ref=lambda a: a[unstructured_case.offset_provider["E2V"].table[:, 0]], + ref=lambda a: a[unstructured_case.offset_provider["E2V"].ndarray[:, 0]], ) @@ -115,16 +115,16 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: cases.verify_with_default_data( unstructured_case, composed_shift_unstructured_flat, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], ) cases.verify_with_default_data( unstructured_case, composed_shift_unstructured_intermediate_result, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], comparison=lambda inp, tmp: np.all(inp == tmp), ) @@ -132,8 +132,8 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: cases.verify_with_default_data( unstructured_case, composed_shift_unstructured, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], ) @@ -583,11 +583,11 @@ def testee(a: cases.VField) -> cases.VField: unstructured_case, testee, ref=lambda a: np.sum( - np.sum(a[unstructured_case.offset_provider["E2V"].table], axis=1, initial=0)[ - unstructured_case.offset_provider["V2E"].table + np.sum(a[unstructured_case.offset_provider["E2V"].ndarray], axis=1, initial=0)[ + unstructured_case.offset_provider["V2E"].ndarray ], axis=1, - where=unstructured_case.offset_provider["V2E"].table != common._DEFAULT_SKIP_VALUE, + where=unstructured_case.offset_provider["V2E"].ndarray != common._DEFAULT_SKIP_VALUE, ), comparison=lambda a, tmp_2: np.all(a == tmp_2), ) @@ -606,8 +606,8 @@ def testee(inp: cases.EField) -> cases.EField: unstructured_case, testee, ref=lambda inp: np.sum( - np.sum(inp[unstructured_case.offset_provider["V2E"].table], axis=1)[ - unstructured_case.offset_provider["E2V"].table + np.sum(inp[unstructured_case.offset_provider["V2E"].ndarray], axis=1)[ + unstructured_case.offset_provider["E2V"].ndarray ], axis=1, ), @@ -627,8 +627,8 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField unstructured_case, testee, ref=lambda a, b: [ - np.sum(a[unstructured_case.offset_provider["V2E"].table], axis=1), - np.sum(b[unstructured_case.offset_provider["V2E"].table], axis=1), + np.sum(a[unstructured_case.offset_provider["V2E"].ndarray], axis=1), + np.sum(b[unstructured_case.offset_provider["V2E"].ndarray], axis=1), ], comparison=lambda a, tmp: (np.all(a[0] == tmp[0]), np.all(a[1] == tmp[1])), ) @@ -649,11 +649,11 @@ def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: unstructured_case, reduce_tuple_element, ref=lambda e, v: np.sum( - e[v2e.table] + np.tile(v, (v2e.max_neighbors, 1)).T, + e[v2e.ndarray] + np.tile(v, (v2e.shape[1], 1)).T, axis=1, initial=0, - where=v2e.table != common._DEFAULT_SKIP_VALUE, - )[unstructured_case.offset_provider["E2V"].table[:, 0]], + where=v2e.ndarray != common._DEFAULT_SKIP_VALUE, + )[unstructured_case.offset_provider["E2V"].ndarray[:, 0]], ) @@ -780,7 +780,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: tmp = neighbor_sum(b(V2E) if 2 < 3 else a(V2E), axis=V2EDim) return tmp - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, testee, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 37f4ee2cd1..33832fb5f0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -33,11 +33,11 @@ def testee( ) # multiplication with shifted `ones` because reduction of only non-shifted field with local dimension is not supported inp = unstructured_case.as_field( - [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].ndarray ) ones = cases.allocate(unstructured_case, testee, "ones").strategy(cases.ConstInitializer(1))() - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify( unstructured_case, testee, @@ -57,7 +57,7 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 return neighbor_sum(inp, axis=V2EDim) inp = unstructured_case.as_field( - [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].ndarray ) cases.verify( @@ -65,7 +65,7 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 testee, inp, out=cases.allocate(unstructured_case, testee, cases.RETURN)(), - ref=np.sum(unstructured_case.offset_provider["V2E"].table, axis=1), + ref=np.sum(unstructured_case.offset_provider["V2E"].ndarray, axis=1), ) @@ -76,7 +76,7 @@ def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: return inp(V2E) out = unstructured_case.as_field( - [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].table) + [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].ndarray) ) inp = cases.allocate(unstructured_case, testee, "inp")() cases.verify( @@ -84,5 +84,5 @@ def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: testee, inp, out=out, - ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].table], + ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].ndarray], ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 29966c30ad..7648d34db7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -52,7 +52,7 @@ def testee(edge_f: cases.EField) -> cases.VField: inp = cases.allocate(unstructured_case, testee, "edge_f", strategy=strategy)() out = cases.allocate(unstructured_case, testee, cases.RETURN)() - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray ref = np.max( inp.asnumpy()[v2e_table], axis=1, @@ -69,7 +69,7 @@ def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) return out - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, minover, @@ -106,7 +106,7 @@ def reduction_ke_field( "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ ) def test_neighbor_sum(unstructured_case, fop): - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray edge_f = cases.allocate(unstructured_case, fop, "edge_f")() @@ -157,7 +157,7 @@ def fencil_op(edge_f: EKField) -> VKField: def fencil(edge_f: EKField, out: VKField): fencil_op(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray field = cases.allocate(unstructured_case, fencil, "edge_f", sizes={KDim: 2})() out = cases.allocate(unstructured_case, fencil_op, cases.RETURN, sizes={KDim: 1})() @@ -190,7 +190,7 @@ def reduce_expr(edge_f: cases.EField) -> cases.VField: def fencil(edge_f: cases.EField, out: cases.VField): reduce_expr(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, fencil, @@ -210,7 +210,7 @@ def test_reduction_with_common_expression(unstructured_case): def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, testee, @@ -226,7 +226,7 @@ def test_reduction_expression_with_where(unstructured_case): def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(where(mask, inp(V2E), inp(V2E)), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) @@ -255,7 +255,7 @@ def test_reduction_expression_with_where_and_tuples(unstructured_case): def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(where(mask, (inp(V2E), inp(V2E)), (inp(V2E), inp(V2E)))[1], axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) @@ -284,7 +284,7 @@ def test_reduction_expression_with_where_and_scalar(unstructured_case): def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(inp(V2E) + where(mask, inp(V2E), 1), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 11e28de9e1..66c56c4827 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -90,7 +90,7 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh a = cases.allocate(unstructured_case, testee, "a")() out = cases.allocate(unstructured_case, testee, "out")() - first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].table[:, i] for i in [0, 1]) + first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].ndarray[:, i] for i in [0, 1]) ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] cases.verify( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 3fc4ed9945..5e3a2fcd14 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -248,11 +248,14 @@ def test_can_deref(program_processor, stencil): program_processor, validate = program_processor Node = gtx.Dimension("Node") + NeighDim = gtx.Dimension("Neighbor", kind=gtx.DimensionKind.LOCAL) inp = gtx.as_field([Node], np.ones((1,), dtype=np.int32)) out = gtx.as_field([Node], np.asarray([0], dtype=inp.dtype)) - no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[-1]]), Node, Node, 1) + no_neighbor_tbl = gtx.as_connectivity( + domain={Node: 1, NeighDim: 1}, codomain=Node, data=np.array([[-1]]), skip_value=-1 + ) run_processor( stencil[{Node: range(1)}], program_processor, @@ -264,7 +267,9 @@ def test_can_deref(program_processor, stencil): if validate: assert np.allclose(out.asnumpy(), -1.0) - a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) + a_neighbor_tbl = gtx.as_connectivity( + domain={Node: 1, NeighDim: 1}, codomain=Node, data=np.array([[0]]), skip_value=-1 + ) run_processor( stencil[{Node: range(1)}], program_processor, @@ -277,37 +282,6 @@ def test_can_deref(program_processor, stencil): assert np.allclose(out.asnumpy(), 1.0) -# def test_can_deref_lifted(program_processor): -# program_processor, validate = program_processor - -# Neighbor = offset("Neighbor") -# Node = gtx.Dimension("Node") - -# @fundef -# def _can_deref(inp): -# shifted = shift(Neighbor, 0)(inp) -# return if_(can_deref(shifted), 1, -1) - -# inp = gtx.as_field([Node], np.zeros((1,))) -# out = gtx.as_field([Node], np.asarray([0])) - -# no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[None]]), Node, Node, 1) -# _can_deref[{Node: range(1)}]( -# inp, out=out, offset_provider={"Neighbor": no_neighbor_tbl}, program_processor=program_processor -# ) - -# if validate: -# assert np.allclose(np.asarray(out), -1.0) - -# a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) -# _can_deref[{Node: range(1)}]( -# inp, out=out, offset_provider={"Neighbor": a_neighbor_tbl}, program_processor=program_processor -# ) - -# if validate: -# assert np.allclose(np.asarray(out), 1.0) - - @pytest.mark.parametrize( "input_value, dtype, np_dtype", [ diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 69786b323b..7bde55bfd2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -14,6 +14,7 @@ from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from next_tests.unit_tests.conftest import program_processor, run_processor +from gt4py.next.iterator.embedded import StridedConnectivityField LocA = gtx.Dimension("LocA") @@ -21,8 +22,10 @@ LocB = gtx.Dimension("LocB") # unused LocA2LocAB = offset("O") -LocA2LocAB_offset_provider = gtx.StridedNeighborOffsetProvider( - origin_axis=LocA, neighbor_axis=LocAB, max_neighbors=2, has_skip_values=False +LocA2LocAB_offset_provider = StridedConnectivityField( + domain_dims=(LocA, gtx.Dimension("Dummy", kind=gtx.DimensionKind.LOCAL)), + codomain_dim=LocAB, + max_neighbors=2, ) @@ -41,7 +44,7 @@ def test_strided_offset_provider(program_processor): program_processor, validate = program_processor LocA_size = 2 - max_neighbors = LocA2LocAB_offset_provider.max_neighbors + max_neighbors = LocA2LocAB_offset_provider.__gt_type__().max_neighbors LocAB_size = LocA_size * max_neighbors rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index eb59c77201..6c6ca7e4bc 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -11,7 +11,6 @@ import numpy as np import pytest - pytest.importorskip("atlas4py") from gt4py import next as gtx @@ -22,20 +21,17 @@ exec_alloc_descriptor, ) from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( + E2V, + V2E, + E2VDim, + Edge, + V2EDim, + Vertex, assert_close, nabla_setup, ) -Vertex = gtx.Dimension("Vertex") -Edge = gtx.Dimension("Edge") -V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) -E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) - -V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) - - @gtx.field_operator def compute_zavgS( pp: gtx.Field[[Vertex], float], S_M: gtx.Field[[Edge], float] @@ -67,21 +63,19 @@ def pnabla( def test_ffront_compute_zavgS(exec_alloc_descriptor): - executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - - setup = nabla_setup() + _, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) - S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) + setup = nabla_setup(allocator=allocator) zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=allocator) - e2v = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False - ) - - compute_zavgS.with_backend(exec_alloc_descriptor)( - pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v} + compute_zavgS.with_backend( + None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor + )( + setup.input_field, + setup.S_fields[0], + out=zavgS, + offset_provider={"E2V": setup.edges2node_connectivity}, ) assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) @@ -89,27 +83,23 @@ def test_ffront_compute_zavgS(exec_alloc_descriptor): def test_ffront_nabla(exec_alloc_descriptor): - executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - - setup = nabla_setup() + _, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field, allocator=allocator) - pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) - S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field, allocator=allocator) + setup = nabla_setup(allocator=allocator) pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) - e2v = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False - ) - v2e = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7 - ) - - pnabla.with_backend(exec_alloc_descriptor)( - pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, "V2E": v2e} + pnabla.with_backend(None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor)( + setup.input_field, + setup.S_fields, + setup.sign_field, + setup.vol_field, + out=(pnabla_MXX, pnabla_MYY), + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) # TODO this check is not sensitive enough, need to implement a proper numpy reference! diff --git a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py index 8d7324f438..6a5865134d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py @@ -20,6 +20,18 @@ functionspace, ) +from gt4py import next as gtx +from gt4py.next.iterator import atlas_utils + + +Vertex = gtx.Dimension("Vertex") +Edge = gtx.Dimension("Edge") +V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) +E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) + +V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) +E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) + def assert_close(expected, actual): assert math.isclose(expected, actual), "expected={}, actual={}".format(expected, actual) @@ -33,9 +45,10 @@ def _default_config(): config["angle"] = 20.0 return config - def __init__(self, *, grid=StructuredGrid("O32"), config=None): + def __init__(self, *, allocator, grid=StructuredGrid("O32"), config=None): if config is None: config = self._default_config() + self.allocator = allocator mesh = StructuredMeshGenerator(config).generate(grid) fs_edges = functionspace.EdgeColumns(mesh, halo=1) @@ -55,12 +68,22 @@ def __init__(self, *, grid=StructuredGrid("O32"), config=None): self.edges_per_node = edges_per_node @property - def edges2node_connectivity(self): - return self.mesh.edges.node_connectivity + def edges2node_connectivity(self) -> gtx.Connectivity: + return gtx.as_connectivity( + domain={Edge: self.edges_size, E2VDim: 2}, + codomain=Vertex, + data=atlas_utils.AtlasTable(self.mesh.edges.node_connectivity).asnumpy(), + allocator=self.allocator, + ) @property - def nodes2edge_connectivity(self): - return self.mesh.nodes.edge_connectivity + def nodes2edge_connectivity(self) -> gtx.Connectivity: + return gtx.as_connectivity( + domain={Vertex: self.nodes_size, V2EDim: self.edges_per_node}, + codomain=Edge, + data=atlas_utils.AtlasTable(self.mesh.nodes.edge_connectivity).asnumpy(), + allocator=self.allocator, + ) @property def nodes_size(self): @@ -75,16 +98,16 @@ def _is_pole_edge(e, edge_flags): return Topology.check(edge_flags[e], Topology.POLE) @property - def is_pole_edge_field(self): + def is_pole_edge_field(self) -> gtx.Field: edge_flags = np.array(self.mesh.edges.flags()) pole_edge_field = np.zeros((self.edges_size,), dtype=bool) for e in range(self.edges_size): pole_edge_field[e] = self._is_pole_edge(e, edge_flags) - return pole_edge_field + return gtx.as_field([Edge], pole_edge_field, allocator=self.allocator) @property - def sign_field(self): + def sign_field(self) -> gtx.Field: node2edge_sign = np.zeros((self.nodes_size, self.edges_per_node)) edge_flags = np.array(self.mesh.edges.flags()) @@ -100,10 +123,10 @@ def sign_field(self): node2edge_sign[jnode, jedge] = -1.0 if self._is_pole_edge(iedge, edge_flags): node2edge_sign[jnode, jedge] = 1.0 - return node2edge_sign + return gtx.as_field([Vertex, V2EDim], node2edge_sign, allocator=self.allocator) @property - def S_fields(self): + def S_fields(self) -> tuple[gtx.Field, gtx.Field]: S = np.array(self.mesh.edges.field("dual_normals"), copy=False) S_MXX = np.zeros((self.edges_size)) S_MYY = np.zeros((self.edges_size)) @@ -124,10 +147,12 @@ def S_fields(self): assert math.isclose(min(S_MYY), -2001577.7946404363) assert math.isclose(max(S_MYY), 2001577.7946404363) - return S_MXX, S_MYY + return gtx.as_field([Edge], S_MXX, allocator=self.allocator), gtx.as_field( + [Edge], S_MYY, allocator=self.allocator + ) @property - def vol_field(self): + def vol_field(self) -> gtx.Field: rpi = 2.0 * math.asin(1.0) radius = 6371.22e03 deg2rad = 2.0 * rpi / 360.0 @@ -142,10 +167,10 @@ def vol_field(self): # VOL(min/max): 57510668192.214096 851856184496.32886 assert_close(57510668192.214096, min(vol)) assert_close(851856184496.32886, max(vol)) - return vol + return gtx.as_field([Vertex], vol, allocator=self.allocator) @property - def input_field(self): + def input_field(self) -> gtx.Field: klevel = 0 MXX = 0 MYY = 1 @@ -200,4 +225,5 @@ def input_field(self): assert_close(0.0000000000000000, min(rzs)) assert_close(1965.4980340735883, max(rzs)) - return rzs[:, klevel] + + return gtx.as_field([Vertex], rzs[:, klevel], allocator=self.allocator) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 3db4497910..4487681abf 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -111,25 +111,18 @@ def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): @pytest.mark.requires_atlas def test_compute_zavgS(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) + setup = nabla_setup(allocator=None) zavgS = gtx.as_field([Edge], np.zeros((setup.edges_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - run_processor( compute_zavgS_fencil, program_processor, setup.edges_size, zavgS, - pp, - S_MXX, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields[0], + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: @@ -141,9 +134,9 @@ def test_compute_zavgS(program_processor): program_processor, setup.edges_size, zavgS, - pp, - S_MYY, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields[1], + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: assert_close(-1000788897.3202186, np.min(zavgS.asnumpy())) @@ -158,29 +151,21 @@ def compute_zavgS2_fencil(n_edges, out, pp, S_M): @pytest.mark.requires_atlas def test_compute_zavgS2(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - pp = gtx.as_field([Vertex], setup.input_field) - - S = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) + setup = nabla_setup(allocator=None) zavgS = ( gtx.as_field([Edge], np.zeros((setup.edges_size))), gtx.as_field([Edge], np.zeros((setup.edges_size))), ) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - run_processor( compute_zavgS2_fencil, program_processor, setup.edges_size, zavgS, - pp, - S, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields, + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: @@ -195,34 +180,27 @@ def test_compute_zavgS2(program_processor): def test_nabla(program_processor): program_processor, validate = program_processor - setup = nabla_setup() + setup = nabla_setup(allocator=None) - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + S_MXX, S_MYY = setup.S_fields pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla, program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), - pp, + setup.input_field, S_MXX, S_MYY, - sign, - vol, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.sign_field, + setup.vol_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: @@ -245,33 +223,24 @@ def nabla2(n_nodes, out, pp, S, sign, vol): @pytest.mark.requires_atlas def test_nabla2(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_M = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) - vol = gtx.as_field([Vertex], setup.vol_field) + setup = nabla_setup(allocator=None) pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla2, program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), - pp, - S_M, - sign, - vol, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.input_field, + setup.S_fields, + setup.sign_field, + setup.vol_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: @@ -325,36 +294,29 @@ def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_ def test_nabla_sign(program_processor): program_processor, validate = program_processor - setup = nabla_setup() + setup = nabla_setup(allocator=None) - is_pole_edge = gtx.as_field([Edge], setup.is_pole_edge_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + S_MXX, S_MYY = setup.S_fields pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla_sign, program_processor, setup.nodes_size, pnabla_MXX, pnabla_MYY, - pp, + setup.input_field, S_MXX, S_MYY, - vol, + setup.vol_field, gtx.index_field(Vertex), - is_pole_edge, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.is_pole_edge_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 6fdc6a77a1..ac7ce9e544 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -38,9 +38,13 @@ V2VDim, Vertex, c2e_arr, + c2e_conn, e2v_arr, + e2v_conn, v2e_arr, + v2e_conn, v2v_arr, + v2v_conn, ) from next_tests.unit_tests.conftest import program_processor, run_processor @@ -89,7 +93,7 @@ def test_sum_edges_to_vertices(program_processor, stencil): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -111,7 +115,7 @@ def test_map_neighbors(program_processor): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -134,7 +138,7 @@ def test_map_make_const_list(program_processor): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -157,8 +161,8 @@ def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processo inp, out=out, offset_provider={ - "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), - "C2E": gtx.NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4), + "E2V": e2v_conn, + "C2E": c2e_conn, }, ) if validate: @@ -185,7 +189,7 @@ def test_sparse_input_field(program_processor): non_sparse, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: @@ -208,8 +212,8 @@ def test_sparse_input_field_v2v(program_processor): inp, out=out, offset_provider={ - "V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4), - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "V2V": v2v_conn, + "V2E": v2e_conn, }, ) @@ -235,7 +239,7 @@ def test_slice_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -259,7 +263,7 @@ def test_slice_twice_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -284,7 +288,7 @@ def test_shift_sliced_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -309,7 +313,7 @@ def test_slice_shifted_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -337,7 +341,7 @@ def test_lift(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -360,7 +364,7 @@ def test_shift_sparse_input_field(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -393,8 +397,8 @@ def test_shift_sparse_input_field2(program_processor): out2 = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) offset_provider = { - "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "E2V": e2v_conn, + "V2E": v2e_conn, } domain = {Vertex: range(0, 9)} @@ -448,7 +452,7 @@ def test_sparse_shifted_stencil_reduce(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: diff --git a/tests/next_tests/toy_connectivity.py b/tests/next_tests/toy_connectivity.py index 82c91a5e74..50db24b880 100644 --- a/tests/next_tests/toy_connectivity.py +++ b/tests/next_tests/toy_connectivity.py @@ -49,6 +49,8 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +c2e_conn = gtx.as_connectivity(domain={Cell: 9, C2EDim: 4}, codomain=Edge, data=c2e_arr) + v2v_arr = np.array( [ [1, 3, 2, 6], @@ -64,6 +66,8 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +v2v_conn = gtx.as_connectivity(domain={Vertex: 9, V2VDim: 4}, codomain=Vertex, data=v2v_arr) + e2v_arr = np.array( [ [0, 1], @@ -88,6 +92,7 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +e2v_conn = gtx.as_connectivity(domain={Edge: 18, E2VDim: 2}, codomain=Vertex, data=e2v_arr) # order east, north, west, south (counter-clock wise) v2e_arr = np.array( @@ -104,3 +109,5 @@ ], dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) + +v2e_conn = gtx.as_connectivity(domain={Vertex: 9, V2EDim: 4}, codomain=Edge, data=v2e_arr) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index ca66b45d6d..f1269f1ed8 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -14,11 +14,11 @@ import pytest import gt4py.next as gtx -from gt4py.next import backend +from gt4py.next import backend, common +from gt4py.next.embedded import nd_array_field from gt4py.next.iterator import runtime from gt4py.next.program_processors import program_formatter - import next_tests @@ -97,12 +97,21 @@ def run_processor( @dataclasses.dataclass -class DummyConnectivity: +class DummyConnectivity(common.Connectivity): max_neighbors: int has_skip_values: int - origin_axis: gtx.Dimension = gtx.Dimension("dummy_origin") - neighbor_axis: gtx.Dimension = gtx.Dimension("dummy_neighbor") - index_type: type[int] = int + source_dim: gtx.Dimension = gtx.Dimension("dummy_origin") + codomain: gtx.Dimension = gtx.Dimension("dummy_neighbor") + + +def nd_array_implementation_params(): + for xp in nd_array_field._nd_array_implementations: + if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp: + yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu) + else: + yield pytest.param(xp, id=xp.__name__) + - def mapped_index(_, __) -> int: - return 0 +@pytest.fixture(params=nd_array_implementation_params()) +def nd_array_implementation(request): + yield request.param diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 063e79d92e..9dde5bb40a 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -15,7 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.common import Dimension, Domain, UnitRange, NamedRange, NamedIndex +from gt4py.next.common import Dimension, Domain, NamedIndex, NamedRange, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins @@ -28,19 +28,6 @@ D2 = Dimension("D2") -def nd_array_implementation_params(): - for xp in nd_array_field._nd_array_implementations: - if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp: - yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu) - else: - yield pytest.param(xp, id=xp.__name__) - - -@pytest.fixture(params=nd_array_implementation_params()) -def nd_array_implementation(request): - yield request.param - - @pytest.fixture( params=[ operator.add, diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py index dcc3a306f2..a91dbeb608 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py @@ -31,12 +31,10 @@ # 0 --0-- 1 --1-- 2 e2v_arr = np.array([[0, 1], [1, 2]]) -e2v_conn = gtx.NeighborTableOffsetProvider( - table=e2v_arr, - origin_axis=E, - neighbor_axis=V, - max_neighbors=2, - has_skip_values=False, +e2v_conn = gtx.as_connectivity( + domain={E: 2, E2VDim: 2}, + codomain=V, + data=e2v_arr, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 1f08362f4f..13e8637d1a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -10,18 +10,22 @@ import pytest import gt4py.next as gtx +from gt4py.next import common from gt4py.next.iterator.builtins import deref from gt4py.next.iterator.runtime import CartesianDomain, UnstructuredDomain, _deduce_domain, fundef -from next_tests.unit_tests.conftest import DummyConnectivity - @fundef def foo(inp): return deref(inp) -connectivity = DummyConnectivity(max_neighbors=0, has_skip_values=True) +connectivity = common.ConnectivityType( + domain=[gtx.Dimension("dummy_origin"), gtx.Dimension("dummy_neighbor")], + codomain=gtx.Dimension("dummy_codomain"), + skip_value=common._DEFAULT_SKIP_VALUE, + dtype=None, +) def test_deduce_domain(): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7b6214fb1b..65a5b5888d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -218,11 +218,11 @@ def expression_test_cases(): @pytest.mark.parametrize("test_case", expression_test_cases()) def test_expression_type(test_case): mesh = simple_mesh() - offset_provider = {**mesh.offset_provider, "Ioff": IDim, "Joff": JDim, "Koff": KDim} + offset_provider_type = {**mesh.offset_provider_type, "Ioff": IDim, "Joff": JDim, "Koff": KDim} testee, expected_type = test_case result = itir_type_inference.infer( - testee, offset_provider=offset_provider, allow_undeclared_symbols=True + testee, offset_provider_type=offset_provider_type, allow_undeclared_symbols=True ) assert result.type == expected_type @@ -231,14 +231,16 @@ def test_adhoc_polymorphism(): func = im.lambda_("a")(im.lambda_("b")(im.make_tuple("a", "b"))) testee = im.call(im.call(func)(im.ref("a_", bool_type)))(im.ref("b_", int_type)) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.type == ts.TupleType(types=[bool_type, int_type]) def test_aliased_function(): testee = im.let("f", im.lambda_("x")("x"))(im.call("f")(1)) - result = itir_type_inference.infer(testee, offset_provider={}) + result = itir_type_inference.infer(testee, offset_provider_type={}) assert result.args[0].type == ts.FunctionType( pos_only_args=[int_type], pos_or_kw_args={}, kw_only_args={}, returns=int_type @@ -253,7 +255,7 @@ def test_late_offset_axis(): testee = im.call(func)(im.ensure_offset("V2E")) result = itir_type_inference.infer( - testee, offset_provider=mesh.offset_provider, allow_undeclared_symbols=True + testee, offset_provider_type=mesh.offset_provider_type, allow_undeclared_symbols=True ) assert result.type == it_on_e_of_e_type @@ -265,7 +267,9 @@ def test_cast_first_arg_inference(): testee = im.call("cast_")( im.plus(im.literal_from_value(1), im.literal_from_value(2)), "float64" ) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.args[0].type == int_type assert result.type == float64_type @@ -291,7 +295,7 @@ def test_cartesian_fencil_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[IDim]), @@ -336,7 +340,7 @@ def test_unstructured_fencil_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[Vertex, KDim]), @@ -384,7 +388,7 @@ def test_function_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[IDim]), @@ -429,7 +433,7 @@ def test_fencil_with_nb_field_input(): ], ) - result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) assert result.closures[0].stencil.expr.args[0].type == float64_list_type assert result.closures[0].stencil.type.returns == float64_type @@ -456,7 +460,7 @@ def test_program_tuple_setat_short_target(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) assert ( isinstance(result.body[0].expr.type, ts.TupleType) @@ -487,7 +491,7 @@ def test_program_setat_without_domain(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) assert ( isinstance(result.body[0].expr.type, ts.DeferredType) @@ -512,7 +516,9 @@ def test_if_stmt(): false_branch=[], ) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.cond.type == bool_type assert result.true_branch[0].expr.type == float_i_field @@ -522,7 +528,7 @@ def test_as_fieldop_without_domain(): im.ref("inp", float_i_field) ) result = itir_type_inference.infer( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert result.type == ts.DeferredType(constraint=ts.FieldType) assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index e04856b75f..f4ea2d7fe1 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -21,7 +21,7 @@ @pytest.fixture -def offset_provider(request): +def offset_provider_type(request): return {"I": common.Dimension("I", kind=common.DimensionKind.HORIZONTAL)} @@ -137,7 +137,7 @@ def common_expr(): assert actual == expected -def test_if_can_deref_no_extraction(offset_provider): +def test_if_can_deref_no_extraction(offset_provider_type): # Test that a subexpression only occurring in one branch of an `if_` is not moved outside the # if statement. A case using `can_deref` is used here as it is common. @@ -157,11 +157,11 @@ def test_if_can_deref_no_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected -def test_if_can_deref_eligible_extraction(offset_provider): +def test_if_can_deref_eligible_extraction(offset_provider_type): # Test that a subexpression only occurring in both branches of an `if_` is moved outside the # if statement. A case using `can_deref` is used here as it is common. @@ -178,11 +178,11 @@ def test_if_can_deref_eligible_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected -def test_if_eligible_extraction(offset_provider): +def test_if_eligible_extraction(offset_provider_type): # Test that a subexpression only occurring in the condition of an `if_` is moved outside the # if statement. @@ -191,7 +191,7 @@ def test_if_eligible_extraction(offset_provider): # (λ(_cs_1) → if _cs_1 ∧ _cs_1 then c else d)(a ∧ b) expected = im.let("_cs_1", im.and_("a", "b"))(im.if_(im.and_("_cs_1", "_cs_1"), "c", "d")) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 141091b450..817c06e8f0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -14,11 +14,12 @@ from gt4py import eve from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next import constructors from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import infer_domain from gt4py.next.iterator.ir_utils import domain_utils from gt4py.next.common import Dimension -from gt4py.next import common, NeighborTableOffsetProvider +from gt4py.next import common from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next import utils @@ -29,6 +30,7 @@ KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) Edge = common.Dimension(value="Edge", kind=common.DimensionKind.HORIZONTAL) +E2VDim = common.Dimension(value="E2V", kind=common.DimensionKind.LOCAL) @pytest.fixture @@ -39,11 +41,10 @@ def offset_provider(): @pytest.fixture def unstructured_offset_provider(): return { - "E2V": NeighborTableOffsetProvider( - np.array([[0, 1]], dtype=np.int32), - Edge, - Vertex, - 2, + "E2V": constructors.as_connectivity( + domain={Edge: 1, E2VDim: 2}, + codomain=Vertex, + data=np.array([[0, 1]], dtype=np.int32), ) } diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index b5b9a62009..168e9490e0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -13,6 +13,7 @@ from gt4py.next.iterator.transforms import fuse_as_fieldop from gt4py.next.type_system import type_specifications as ts + IDim = gtx.Dimension("IDim") field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) @@ -30,7 +31,7 @@ def test_trivial(): d, )(im.ref("inp1", field_type), im.ref("inp2", field_type), im.ref("inp3", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -40,7 +41,7 @@ def test_trivial_literal(): testee = im.op_as_fieldop("plus", d)(im.op_as_fieldop("multiplies", d)(1, 2), 3) expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)() actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -65,7 +66,7 @@ def test_tuple_arg(): d, )() actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -85,7 +86,7 @@ def test_symref_used_twice(): d, )("inp1", "inp2") actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -100,7 +101,7 @@ def test_no_inline(): d1, )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type))) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == testee @@ -132,6 +133,6 @@ def test_partial_inline(): d1, )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), "inp1") actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 23f62842c4..9d51dc4f33 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -52,7 +52,7 @@ def test_trivial(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( @@ -87,7 +87,7 @@ def test_trivial_let(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( @@ -128,7 +128,7 @@ def test_top_level_if(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( @@ -186,7 +186,7 @@ def test_nested_if(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py index 7c991fb9a8..77d3323fb4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py @@ -8,16 +8,16 @@ from gt4py import next as gtx from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.prune_casts import PruneCasts from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts def test_prune_casts_simple(): x_ref = im.ref("x", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) y_ref = im.ref("y", ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) testee = im.call("plus")(im.call("cast_")(x_ref, "float64"), im.call("cast_")(y_ref, "float64")) - testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref) actual = PruneCasts.apply(testee) @@ -32,7 +32,7 @@ def test_prune_casts_fieldop(): im.cast_as_fieldop("float64")(x_ref), im.cast_as_fieldop("float64")(y_ref), ) - testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) expected = im.op_as_fieldop("plus")( im.cast_as_fieldop("float64")(x_ref), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index 28bd88b853..0760247996 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -11,11 +11,20 @@ import pytest from gt4py.eve.utils import UIDs +from gt4py.next import common from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags -from next_tests.unit_tests.conftest import DummyConnectivity + +def dummy_connectivity_type(max_neighbors: int, has_skip_values: bool): + return common.NeighborConnectivityType( + domain=[common.Dimension("dummy_origin"), common.Dimension("dummy_neighbor")], + codomain=common.Dimension("dummy_codomain"), + skip_value=common._DEFAULT_SKIP_VALUE if has_skip_values else None, + dtype=None, + max_neighbors=max_neighbors, + ) @pytest.fixture(params=[True, False]) @@ -67,7 +76,7 @@ def reduction_if(): ], ) def test_get_partial_offsets(reduction, request): - offset_provider = {"Dim": SimpleNamespace(max_neighbors=3, has_skip_values=False)} + offset_provider_type = {"Dim": SimpleNamespace(max_neighbors=3, has_skip_values=False)} partial_offsets = _get_partial_offset_tags(request.getfixturevalue(reduction).args) assert set(partial_offsets) == {"Dim"} @@ -108,63 +117,73 @@ def _expected(red, dim, max_neighbors, has_skip_values, shifted_arg=0): def test_basic(basic_reduction, has_skip_values): expected = _expected(basic_reduction, "Dim", 3, has_skip_values) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=3, has_skip_values=has_skip_values)} - actual = UnrollReduce.apply(basic_reduction, offset_provider=offset_provider) + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=has_skip_values) + } + actual = UnrollReduce.apply(basic_reduction, offset_provider_type=offset_provider_type) assert actual == expected def test_reduction_with_shift_on_second_arg(reduction_with_shift_on_second_arg, has_skip_values): expected = _expected(reduction_with_shift_on_second_arg, "Dim", 1, has_skip_values, 1) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=1, has_skip_values=has_skip_values)} - actual = UnrollReduce.apply(reduction_with_shift_on_second_arg, offset_provider=offset_provider) + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=1, has_skip_values=has_skip_values) + } + actual = UnrollReduce.apply( + reduction_with_shift_on_second_arg, offset_provider_type=offset_provider_type + ) assert actual == expected def test_reduction_with_if(reduction_if): expected = _expected(reduction_if, "Dim", 2, False) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=2, has_skip_values=False)} - actual = UnrollReduce.apply(reduction_if, offset_provider=offset_provider) + offset_provider_type = {"Dim": dummy_connectivity_type(max_neighbors=2, has_skip_values=False)} + actual = UnrollReduce.apply(reduction_if, offset_provider_type=offset_provider_type) assert actual == expected def test_reduction_with_irrelevant_full_shift(reduction_with_irrelevant_full_shift): expected = _expected(reduction_with_irrelevant_full_shift, "Dim", 3, False) - offset_provider = { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "IrrelevantDim": DummyConnectivity( + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "IrrelevantDim": dummy_connectivity_type( max_neighbors=1, has_skip_values=True ), # different max_neighbors and skip value to trigger error } actual = UnrollReduce.apply( - reduction_with_irrelevant_full_shift, offset_provider=offset_provider + reduction_with_irrelevant_full_shift, offset_provider_type=offset_provider_type ) assert actual == expected @pytest.mark.parametrize( - "offset_provider", + "offset_provider_type", [ { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=False), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=False), }, { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=3, has_skip_values=True), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=3, has_skip_values=True), }, { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=True), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=True), }, ], ) -def test_reduction_with_incompatible_shifts(reduction_with_incompatible_shifts, offset_provider): - offset_provider = { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=False), +def test_reduction_with_incompatible_shifts( + reduction_with_incompatible_shifts, offset_provider_type +): + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=False), } with pytest.raises(RuntimeError, match="incompatible"): - UnrollReduce.apply(reduction_with_incompatible_shifts, offset_provider=offset_provider) + UnrollReduce.apply( + reduction_with_incompatible_shifts, offset_provider_type=offset_provider_type + ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index 1a86f7b0f8..97591122e5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -21,7 +21,7 @@ def test_funcall_to_op(): ) actual = it2gtfn.GTFN_lowering( - grid_type=gtx.GridType.CARTESIAN, offset_provider={}, column_axis=None + grid_type=gtx.GridType.CARTESIAN, offset_provider_type={}, column_axis=None ).visit(testee) assert expected == actual @@ -32,7 +32,7 @@ def test_unapplied_funcall_to_function_object(): expected = gtfn_ir.SymRef(id="plus") actual = it2gtfn.GTFN_lowering( - grid_type=gtx.GridType.CARTESIAN, offset_provider={}, column_axis=None + grid_type=gtx.GridType.CARTESIAN, offset_provider_type={}, column_axis=None ).visit(testee) assert expected == actual diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index 329b2814d2..62d88d9f0a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -11,6 +11,7 @@ import ctypes import unittest import unittest.mock +from unittest.mock import patch import numpy as np import pytest @@ -20,19 +21,15 @@ from gt4py.next.ffront.fbuiltins import where from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import ( - E2V, - cartesian_case, - unstructured_case, -) +from next_tests.integration_tests.cases import E2V, cartesian_case, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, mesh_descriptor, ) -from unittest.mock import patch from . import pytestmark + dace = pytest.importorskip("dace") @@ -151,14 +148,14 @@ def test_dace_fastcall_with_connectivity(unstructured_case, monkeypatch): # check that test connectivities are allocated on host memory # this is an assumption to test that fast_call cannot be used for gpu tests - assert isinstance(connectivity_E2V.table, np.ndarray) + assert isinstance(connectivity_E2V.ndarray, np.ndarray) @gtx.field_operator def testee(a: cases.VField) -> cases.EField: return a(E2V[0]) (a,), kwfields = cases.get_default_data(unstructured_case, testee) - numpy_ref = lambda a: a[connectivity_E2V.table[:, 0]] + numpy_ref = lambda a: a[connectivity_E2V.ndarray[:, 0]] mock_fast_call, mock_construct_args = make_mocks(monkeypatch) @@ -194,12 +191,11 @@ def verify_testee(offset_provider): # Here we copy the connectivity to gpu memory, and resuse the same cupy array # on multiple program calls, in order to ensure that fast_call is used. offset_provider = { - "E2V": gtx.NeighborTableOffsetProvider( - table=cp.asarray(connectivity_E2V.table), - origin_axis=connectivity_E2V.origin_axis, - neighbor_axis=connectivity_E2V.neighbor_axis, - max_neighbors=connectivity_E2V.max_neighbors, - has_skip_values=connectivity_E2V.has_skip_values, + "E2V": gtx.as_connectivity( + domain=connectivity_E2V.domain, + codomain=connectivity_E2V.codomain, + data=cp.asarray(connectivity_E2V.ndarray), + skip_value=connectivity_E2V.skip_value, ) } diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index e0c0c3fa4e..9c52ea81c3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, constructors from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -50,13 +50,7 @@ "IDim": IDim, } SIMPLE_MESH: MeshDescriptor = simple_mesh() -SIMPLE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( - SIMPLE_MESH.offset_provider | CARTESIAN_OFFSETS -) SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() -SKIP_VALUE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( - SKIP_VALUE_MESH.offset_provider | CARTESIAN_OFFSETS -) SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) FSYMBOLS = dict( __w_size_0=N, @@ -83,20 +77,20 @@ def make_mesh_symbols(mesh: MeshDescriptor): __vertices_size_0=mesh.num_vertices, __vertices_stride_0=1, __connectivity_C2E_size_0=mesh.num_cells, - __connectivity_C2E_size_1=mesh.offset_provider["C2E"].max_neighbors, - __connectivity_C2E_stride_0=mesh.offset_provider["C2E"].max_neighbors, + __connectivity_C2E_size_1=mesh.offset_provider_type["C2E"].max_neighbors, + __connectivity_C2E_stride_0=mesh.offset_provider_type["C2E"].max_neighbors, __connectivity_C2E_stride_1=1, __connectivity_C2V_size_0=mesh.num_cells, - __connectivity_C2V_size_1=mesh.offset_provider["C2V"].max_neighbors, - __connectivity_C2V_stride_0=mesh.offset_provider["C2V"].max_neighbors, + __connectivity_C2V_size_1=mesh.offset_provider_type["C2V"].max_neighbors, + __connectivity_C2V_stride_0=mesh.offset_provider_type["C2V"].max_neighbors, __connectivity_C2V_stride_1=1, __connectivity_E2V_size_0=mesh.num_edges, - __connectivity_E2V_size_1=mesh.offset_provider["E2V"].max_neighbors, - __connectivity_E2V_stride_0=mesh.offset_provider["E2V"].max_neighbors, + __connectivity_E2V_size_1=mesh.offset_provider_type["E2V"].max_neighbors, + __connectivity_E2V_stride_0=mesh.offset_provider_type["E2V"].max_neighbors, __connectivity_E2V_stride_1=1, __connectivity_V2E_size_0=mesh.num_vertices, - __connectivity_V2E_size_1=mesh.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_0=mesh.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_size_1=mesh.offset_provider_type["V2E"].max_neighbors, + __connectivity_V2E_stride_0=mesh.offset_provider_type["V2E"].max_neighbors, __connectivity_V2E_stride_1=1, ) @@ -1018,14 +1012,14 @@ def test_gtir_connectivity_shift(): CELL_OFFSET_FTYPE = ts.FieldType(dims=[Cell], dtype=SIZE_TYPE) EDGE_OFFSET_FTYPE = ts.FieldType(dims=[Edge], dtype=SIZE_TYPE) - connectivity_C2E = SIMPLE_MESH_OFFSET_PROVIDER["C2E"] + connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] assert isinstance(connectivity_C2E, gtx_common.NeighborTable) - connectivity_E2V = SIMPLE_MESH_OFFSET_PROVIDER["E2V"] + connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) ev = np.random.rand(SIMPLE_MESH.num_edges, SIMPLE_MESH.num_vertices) - ref = ev[connectivity_C2E.table[:, C2E_neighbor_idx], :][ - :, connectivity_E2V.table[:, E2V_neighbor_idx] + ref = ev[connectivity_C2E.ndarray[:, C2E_neighbor_idx], :][ + :, connectivity_E2V.ndarray[:, E2V_neighbor_idx] ] for i, stencil in enumerate( @@ -1053,7 +1047,7 @@ def test_gtir_connectivity_shift(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) ce = np.empty([SIMPLE_MESH.num_cells, SIMPLE_MESH.num_edges]) @@ -1062,8 +1056,8 @@ def test_gtir_connectivity_shift(): ev, c2e_offset=np.full(SIMPLE_MESH.num_cells, C2E_neighbor_idx, dtype=np.int32), e2v_offset=np.full(SIMPLE_MESH.num_edges, E2V_neighbor_idx, dtype=np.int32), - connectivity_C2E=connectivity_C2E.table, - connectivity_E2V=connectivity_E2V.table, + connectivity_C2E=connectivity_C2E.ndarray, + connectivity_E2V=connectivity_E2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __ce_field_size_0=SIMPLE_MESH.num_cells, @@ -1114,15 +1108,17 @@ def test_gtir_connectivity_shift_chain(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_E2V = SIMPLE_MESH_OFFSET_PROVIDER["E2V"] + connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) - ref = e[connectivity_V2E.table[connectivity_E2V.table[:, E2V_neighbor_idx], V2E_neighbor_idx]] + ref = e[ + connectivity_V2E.ndarray[connectivity_E2V.ndarray[:, E2V_neighbor_idx], V2E_neighbor_idx] + ] # new empty output field e_out = np.empty_like(e) @@ -1130,8 +1126,8 @@ def test_gtir_connectivity_shift_chain(): sdfg( e, e_out, - connectivity_E2V=connectivity_E2V.table, - connectivity_V2E=connectivity_V2E.table, + connectivity_E2V=connectivity_E2V.ndarray, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __edges_out_size_0=SIMPLE_MESH.num_edges, @@ -1174,30 +1170,30 @@ def test_gtir_neighbors_as_input(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.shape[1]) e = np.random.rand(SIMPLE_MESH.num_edges) v = np.empty(SIMPLE_MESH.num_vertices, dtype=v2e_field.dtype) v_ref = [ functools.reduce(lambda x, y: x + y, v2e_values + e[v2e_neighbors], init_value) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field, strict=True) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field, strict=True) ] sdfg( v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __v2e_field_size_0=SIMPLE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1210,7 +1206,7 @@ def test_gtir_neighbors_as_output(): gtx_common.GridType.UNSTRUCTURED, ranges={ Vertex: (0, "nvertices"), - V2EDim: (0, SIMPLE_MESH_OFFSET_PROVIDER["V2E"].max_neighbors), + V2EDim: (0, SIMPLE_MESH.offset_provider_type["V2E"].max_neighbors), }, ) vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) @@ -1232,9 +1228,9 @@ def test_gtir_neighbors_as_output(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) @@ -1243,7 +1239,7 @@ def test_gtir_neighbors_as_output(): sdfg( e, v2e_field, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __v2e_field_size_0=SIMPLE_MESH.num_vertices, @@ -1251,7 +1247,7 @@ def test_gtir_neighbors_as_output(): __v2e_field_stride_0=connectivity_V2E.max_neighbors, __v2e_field_stride_1=1, ) - assert np.allclose(v2e_field, e[connectivity_V2E.table]) + assert np.allclose(v2e_field, e[connectivity_V2E.ndarray]) def test_gtir_reduce(): @@ -1278,13 +1274,13 @@ def test_gtir_reduce(): ) )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) v_ref = [ functools.reduce(lambda x, y: x + y, e[v2e_neighbors], init_value) - for v2e_neighbors in connectivity_V2E.table + for v2e_neighbors in connectivity_V2E.ndarray ] for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): @@ -1305,7 +1301,7 @@ def test_gtir_reduce(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) # new empty output field v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) @@ -1313,7 +1309,7 @@ def test_gtir_reduce(): sdfg( e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), ) @@ -1344,7 +1340,7 @@ def test_gtir_reduce_with_skip_values(): ) )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SKIP_VALUE_MESH.num_edges) @@ -1354,7 +1350,7 @@ def test_gtir_reduce_with_skip_values(): [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], init_value, ) - for v2e_neighbors in connectivity_V2E.table + for v2e_neighbors in connectivity_V2E.ndarray ] for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): @@ -1375,7 +1371,7 @@ def test_gtir_reduce_with_skip_values(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) # new empty output field v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) @@ -1383,7 +1379,7 @@ def test_gtir_reduce_with_skip_values(): sdfg( e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SKIP_VALUE_MESH), ) @@ -1394,10 +1390,10 @@ def test_gtir_reduce_dot_product(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(*connectivity_V2E.shape) e = np.random.rand(SKIP_VALUE_MESH.num_edges) v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ @@ -1409,7 +1405,7 @@ def test_gtir_reduce_dot_product(): ), init_value, ) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field) ] testee = gtir.Program( @@ -1448,17 +1444,17 @@ def test_gtir_reduce_dot_product(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) sdfg( v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **make_mesh_symbols(SKIP_VALUE_MESH), __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1500,14 +1496,14 @@ def test_gtir_reduce_with_cond_neighbors(): ], ) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(*connectivity_V2E.shape) e = np.random.rand(SKIP_VALUE_MESH.num_edges) for use_sparse in [False, True]: - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ @@ -1525,19 +1521,19 @@ def test_gtir_reduce_with_cond_neighbors(): [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], init_value, ) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field, strict=True) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field, strict=True) ] sdfg( np.bool_(use_sparse), v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SKIP_VALUE_MESH), __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1631,9 +1627,9 @@ def test_gtir_let_lambda_with_connectivity(): C2V_neighbor_idx = 2 cell_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Cell: (0, "ncells")}) - connectivity_C2E = SIMPLE_MESH_OFFSET_PROVIDER["C2E"] + connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] assert isinstance(connectivity_C2E, gtx_common.NeighborTable) - connectivity_C2V = SIMPLE_MESH_OFFSET_PROVIDER["C2V"] + connectivity_C2V = SIMPLE_MESH.offset_provider["C2V"] assert isinstance(connectivity_C2V, gtx_common.NeighborTable) testee = gtir.Program( @@ -1669,22 +1665,22 @@ def test_gtir_let_lambda_with_connectivity(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) e = np.random.rand(SIMPLE_MESH.num_edges) v = np.random.rand(SIMPLE_MESH.num_vertices) c = np.empty(SIMPLE_MESH.num_cells) ref = ( - e[connectivity_C2E.table[:, C2E_neighbor_idx]] - + v[connectivity_C2V.table[:, C2V_neighbor_idx]] + e[connectivity_C2E.ndarray[:, C2E_neighbor_idx]] + + v[connectivity_C2V.ndarray[:, C2V_neighbor_idx]] ) sdfg( cells=c, edges=e, vertices=v, - connectivity_C2E=connectivity_C2E.table, - connectivity_C2V=connectivity_C2V.table, + connectivity_C2E=connectivity_C2E.ndarray, + connectivity_C2V=connectivity_C2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), ) diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index 6e9dfa3d64..0998ab8eab 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -11,10 +11,7 @@ from gt4py import next as gtx from gt4py._core import definitions as core_defs -from gt4py.next import allocators as next_allocators, common, float32 -from gt4py.next.program_processors.runners import roundtrip - -from next_tests.integration_tests import cases +from gt4py.next import allocators as next_allocators, common I = gtx.Dimension("I") @@ -154,3 +151,12 @@ def test_field_wrong_origin(): @pytest.mark.xfail(reason="aligned_index not supported yet") def test_aligned_index(): gtx.as_field([I], np.random.rand(sizes[I]).astype(gtx.float32), aligned_index=[I, 0]) + + +@pytest.mark.parametrize( + "data, skip_value", + [([0, 1, 2], None), ([0, 1, common._DEFAULT_SKIP_VALUE], common._DEFAULT_SKIP_VALUE)], +) +def test_as_connectivity(nd_array_implementation, data, skip_value): + testee = gtx.as_connectivity([I], J, nd_array_implementation.array(data)) + assert testee.skip_value is skip_value From 3fb206e46ceecf07b7ef6c668239d62d79028503 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 26 Nov 2024 10:53:19 +0100 Subject: [PATCH 3/5] feat[next][dace]: Symbolic domain without dace array offsets (#1735) Add support for field operator domain with symbolic shape, with dimension extent in non zero-based range. --- .../runners/dace_common/utility.py | 10 +- .../gtir_builtin_translators.py | 127 ++++++++++----- .../runners/dace_fieldview/gtir_dataflow.py | 100 +++++++----- .../runners/dace_fieldview/gtir_sdfg.py | 148 +++++++++++++----- .../runners/dace_fieldview/utility.py | 11 +- .../dace_tests/test_gtir_to_sdfg.py | 123 +++++++++++++-- 6 files changed, 367 insertions(+), 152 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index 29395a30c1..3e96ef3cec 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Final, Optional, Sequence +from typing import Final, Literal, Optional, Sequence import dace @@ -51,12 +51,16 @@ def connectivity_identifier(name: str) -> str: return f"connectivity_{name}" +def field_symbol_name(field_name: str, axis: int, sym: Literal["size", "stride"]) -> str: + return f"__{field_name}_{sym}_{axis}" + + def field_size_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_size_{axis}" + return field_symbol_name(field_name, axis, "size") def field_stride_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_stride_{axis}" + return field_symbol_name(field_name, axis, "stride") def is_field_symbol(name: str) -> bool: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 69aedf44d6..60dcd8ddc9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, TypeAlias +from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace import dace.subsets as sbs @@ -33,6 +33,34 @@ from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg +def _get_domain_indices( + dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None +) -> sbs.Indices: + """ + Helper function to construct the list of indices for a field domain, applying + an optional offset in each dimension as start index. + + Args: + dims: The field dimensions. + offsets: The range start index in each dimension. + + Returns: + A list of indices for field access in dace arrays. As this list is returned + as `dace.subsets.Indices`, it should be converted to `dace.subsets.Range` before + being used in memlet subset because ranges are better supported throughout DaCe. + """ + index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] + if offsets is None: + return sbs.Indices(index_variables) + else: + return sbs.Indices( + [ + index - offset if offset != 0 else index + for index, offset in zip(index_variables, offsets, strict=True) + ] + ) + + @dataclasses.dataclass(frozen=True) class FieldopData: """ @@ -45,42 +73,59 @@ class FieldopData: Args: dc_node: DaCe access node to the data storage. gt_type: GT4Py type definition, which includes the field domain information. + offset: List of index offsets, in each dimension, when the dimension range + does not start from zero; assume zero offset, if not set. """ dc_node: dace.nodes.AccessNode gt_type: ts.FieldType | ts.ScalarType + offset: Optional[list[dace.symbolic.SymExpr]] + + def make_copy(self, data_node: dace.nodes.AccessNode) -> FieldopData: + """Create a copy of this data descriptor with a different access node.""" + assert data_node != self.dc_node + return FieldopData(data_node, self.gt_type, self.offset) def get_local_view( self, domain: FieldopDomain ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: - """Helper method to access a field in local view, given a field operator domain.""" + """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) ) if isinstance(self.gt_type, ts.FieldType): - indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { - dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) - for dim, _, _ in domain + domain_dims = [dim for dim, _, _ in domain] + domain_indices = _get_domain_indices(domain_dims) + it_indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { + dim: gtir_dataflow.SymbolExpr(index, INDEX_DTYPE) + for dim, index in zip(domain_dims, domain_indices) } + field_domain = [ + (dim, dace.symbolic.SymExpr(0) if self.offset is None else self.offset[i]) + for i, dim in enumerate(self.gt_type.dims) + ] local_dims = [ dim for dim in self.gt_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL ] - if len(local_dims) == 0: return gtir_dataflow.IteratorExpr( - self.dc_node, self.gt_type.dtype, self.gt_type.dims, indices + self.dc_node, self.gt_type.dtype, field_domain, it_indices ) elif len(local_dims) == 1: field_dtype = itir_ts.ListType( element_type=self.gt_type.dtype, offset_type=local_dims[0] ) - field_dims = [ - dim for dim in self.gt_type.dims if dim.kind != gtx_common.DimensionKind.LOCAL + field_domain = [ + (dim, offset) + for dim, offset in field_domain + if dim.kind != gtx_common.DimensionKind.LOCAL ] - return gtir_dataflow.IteratorExpr(self.dc_node, field_dtype, field_dims, indices) + return gtir_dataflow.IteratorExpr( + self.dc_node, field_dtype, field_domain, it_indices + ) else: raise ValueError( @@ -155,9 +200,9 @@ def _parse_fieldop_arg( return arg.get_local_view(domain) -def _get_field_shape( +def _get_field_layout( domain: FieldopDomain, -) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr]]: +) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]: """ Parse the field operator domain and generates the shape of the result field. @@ -174,11 +219,14 @@ def _get_field_shape( domain: The field operator domain. Returns: - A tuple of two lists: the list of field dimensions and the list of dace - array sizes in each dimension. + A tuple of three lists containing: + - the domain dimensions + - the domain offset in each dimension + - the domain size in each dimension """ - domain_dims, _, domain_ubs = zip(*domain) - return list(domain_dims), list(domain_ubs) + domain_dims, domain_lbs, domain_ubs = zip(*domain) + domain_sizes = [(ub - lb) for lb, ub in zip(domain_lbs, domain_ubs)] + return list(domain_dims), list(domain_lbs), domain_sizes def _create_temporary_field( @@ -189,7 +237,7 @@ def _create_temporary_field( dataflow_output: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: """Helper method to allocate a temporary field where to write the output of a field operator.""" - field_dims, field_shape = _get_field_shape(domain) + field_dims, field_offset, field_shape = _get_field_layout(domain) output_desc = dataflow_output.result.dc_node.desc(sdfg) if isinstance(output_desc, dace.data.Array): @@ -197,6 +245,7 @@ def _create_temporary_field( assert isinstance(node_type.dtype.element_type, ts.ScalarType) assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) + field_offset.extend(output_desc.offset) field_shape.extend(output_desc.shape) elif isinstance(output_desc, dace.data.Scalar): assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) @@ -215,7 +264,11 @@ def _create_temporary_field( assert dataflow_output.result.gt_dtype.offset_type is not None field_dims.append(dataflow_output.result.gt_dtype.offset_type) - return FieldopData(field_node, ts.FieldType(field_dims, field_dtype)) + return FieldopData( + field_node, + ts.FieldType(field_dims, field_dtype), + offset=(field_offset if set(field_offset) != {0} else None), + ) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -285,7 +338,8 @@ def translate_as_fieldop( # parse the domain of the field operator domain = extract_domain(domain_expr) - domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) + domain_dims, domain_offsets, _ = zip(*domain) + domain_indices = _get_domain_indices(domain_dims, domain_offsets) # visit the list of arguments to be passed to the lambda expression stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] @@ -350,10 +404,8 @@ def translate_broadcast_scalar( assert cpm.is_ref_to(stencil_expr, "deref") domain = extract_domain(domain_expr) - field_dims, field_shape = _get_field_shape(domain) - field_subset = sbs.Range.from_string( - ",".join(dace_gtir_utils.get_map_variable(dim) for dim in field_dims) - ) + output_dims, output_offset, output_shape = _get_field_layout(domain) + output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset)) assert len(node.args) == 1 scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) @@ -369,26 +421,15 @@ def translate_broadcast_scalar( assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) if len(node.args[0].type.dims) == 0: # zero-dimensional field input_subset = "0" - elif all( - isinstance(scalar_expr.indices[dim], gtir_dataflow.SymbolExpr) - for dim in scalar_expr.dimensions - if dim not in field_dims - ): - input_subset = ",".join( - dace_gtir_utils.get_map_variable(dim) - if dim in field_dims - else scalar_expr.indices[dim].value # type: ignore[union-attr] # catched by exception above - for dim in scalar_expr.dimensions - ) else: - raise ValueError(f"Cannot deref field {scalar_expr.field} in broadcast expression.") + input_subset = scalar_expr.get_memlet_subset(sdfg) input_node = scalar_expr.field gt_dtype = node.args[0].type.dtype else: raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") - output, _ = sdfg.add_temp_transient(field_shape, input_node.desc(sdfg).dtype) + output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype) output_node = state.add_access(output) sdfg_builder.add_mapped_tasklet( @@ -400,13 +441,13 @@ def translate_broadcast_scalar( }, inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, code="__val = __inp", - outputs={"__val": dace.Memlet(data=output_node.data, subset=field_subset)}, + outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)}, input_nodes={input_node.data: input_node}, output_nodes={output_node.data: output_node}, external_edges=True, ) - return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype)) + return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset) def translate_if( @@ -467,7 +508,7 @@ def construct_output(inner_data: FieldopData) -> FieldopData: outer, _ = sdfg.add_temp_transient_like(inner_desc) outer_node = state.add_access(outer) - return FieldopData(outer_node, inner_data.gt_type) + return inner_data.make_copy(outer_node) result_temps = gtx_utils.tree_map(construct_output)(true_br_args) @@ -513,7 +554,7 @@ def _get_data_nodes( ) -> FieldopResult: if isinstance(data_type, ts.FieldType): data_node = state.add_access(data_name) - return FieldopData(data_node, data_type) + return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.ScalarType): if data_name in sdfg.symbols: @@ -522,7 +563,7 @@ def _get_data_nodes( ) else: data_node = state.add_access(data_name) - return FieldopData(data_node, data_type) + return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.TupleType): tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type) @@ -579,7 +620,7 @@ def translate_literal( data_type = node.type data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) - return FieldopData(data_node, data_type) + return FieldopData(data_node, data_type, offset=None) def translate_make_tuple( @@ -708,7 +749,7 @@ def translate_scalar_expr( dace.Memlet(data=temp_name, subset="0"), ) - return FieldopData(temp_node, node.type) + return FieldopData(temp_node, node.type, offset=None) def translate_symbol_ref( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 74142dec66..cfba4d61e5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -90,17 +90,42 @@ class IteratorExpr: Args: field: Access node to the field this iterator operates on. gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. - dimensions: Field domain represented as a sorted list of dimensions, needed - to order the map index variables and dereference an element in the field. + field_domain: Field domain represented as a sorted list of dimensions and offset values, + used to find the position of a map index variable in the memlet subset. The offset + value is either the start index of dimension range or the compile-time value of + a shift expression, or a composition of both, and it must be subtracted to the index + variable when constructing the memlet subset range. indices: Maps each dimension to an index value, which could be either a symbolic value or the result of a tasklet computation like neighbors connectivity or dynamic offset. """ field: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - dimensions: list[gtx_common.Dimension] + field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] + def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: + if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): + raise ValueError(f"Cannot deref iterator {self}.") + + field_desc = self.field.desc(sdfg) + if isinstance(self.gt_dtype, itir_ts.ListType): + assert len(field_desc.shape) == len(self.field_domain) + 1 + assert self.gt_dtype.offset_type is not None + field_domain = [*self.field_domain, (self.gt_dtype.offset_type, 0)] + else: + assert len(field_desc.shape) == len(self.field_domain) + field_domain = self.field_domain + + return sbs.Range.from_string( + ",".join( + str(self.indices[dim].value - offset) # type: ignore[union-attr] + if dim in self.indices + else f"0:{size}" + for (dim, offset), size in zip(field_domain, field_desc.shape, strict=True) + ) + ) + class DataflowInputEdge(Protocol): """ @@ -271,8 +296,17 @@ def _add_input_data_edge( src_subset: sbs.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, + src_offset: Optional[list[dace.symbolic.SymExpr]] = None, ) -> None: - edge = MemletInputEdge(self.state, src, src_subset, dst_node, dst_conn) + input_subset = ( + src_subset + if src_offset is None + else sbs.Range( + (start - off, stop - off, step) + for (start, stop, step), off in zip(src_subset, src_offset, strict=True) + ) + ) + edge = MemletInputEdge(self.state, src, input_subset, dst_node, dst_conn) self.input_edges.append(edge) def _add_edge( @@ -440,34 +474,21 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: field_desc = arg_expr.field.desc(self.sdfg) if isinstance(field_desc, dace.data.Scalar): # deref a zero-dimensional field - assert len(arg_expr.dimensions) == 0 + assert len(arg_expr.field_domain) == 0 assert isinstance(node.type, ts.ScalarType) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0") # default case: deref a field with one or more dimensions if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): - # when all indices are symblic expressions, we can perform direct field access through a memlet - if isinstance(arg_expr.gt_dtype, itir_ts.ListType): - assert len(field_desc.shape) == len(arg_expr.dimensions) + 1 - assert arg_expr.gt_dtype.offset_type is not None - field_dims = [*arg_expr.dimensions, arg_expr.gt_dtype.offset_type] - else: - assert len(field_desc.shape) == len(arg_expr.dimensions) - field_dims = arg_expr.dimensions - - field_subset = sbs.Range( - (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr] - if dim in arg_expr.indices - else (0, size - 1, 1) - for dim, size in zip(field_dims, field_desc.shape) - ) + # when all indices are symbolic expressions, we can perform direct field access through a memlet + field_subset = arg_expr.get_memlet_subset(self.sdfg) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, field_subset) # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, # either indirection through connectivity table or dynamic cartesian offset. - assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) - assert len(field_desc.shape) == len(arg_expr.dimensions) - field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] + assert all(dim in arg_expr.indices for dim, _ in arg_expr.field_domain) + assert len(field_desc.shape) == len(arg_expr.field_domain) + field_indices = [(dim, arg_expr.indices[dim]) for dim, _ in arg_expr.field_domain] index_connectors = [ IndexConnectorFmt.format(dim=dim.value) for dim, index in field_indices @@ -494,6 +515,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: sbs.Range.from_array(field_desc), deref_node, "field", + src_offset=[offset for (_, offset) in arg_expr.field_domain], ) for dim, index_expr in field_indices: @@ -532,7 +554,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) - assert offset_provider.codomain in it.dimensions + assert any(dim == offset_provider.codomain for dim, _ in it.field_domain) assert offset_provider.source_dim in it.indices origin_index = it.indices[offset_provider.source_dim] assert isinstance(origin_index, SymbolExpr) @@ -560,10 +582,12 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=node.type, subset=sbs.Range.from_string( ",".join( - it.indices[dim].value # type: ignore[union-attr] + str(it.indices[dim].value - offset) # type: ignore[union-attr] if dim != offset_provider.codomain else f"0:{size}" - for dim, size in zip(it.dimensions, field_desc.shape, strict=True) + for (dim, offset), size in zip( + it.field_domain, field_desc.shape, strict=True + ) ) ), ) @@ -971,14 +995,13 @@ def _make_cartesian_shift( self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: DataExpr ) -> IteratorExpr: """Implements cartesian shift along one dimension.""" - assert offset_dim in it.dimensions + assert any(dim == offset_dim for dim, _ in it.field_domain) new_index: SymbolExpr | ValueExpr - assert offset_dim in it.indices index_expr = it.indices[offset_dim] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): # purely symbolic expression which can be interpreted at compile time new_index = SymbolExpr( - dace.symbolic.pystr_to_symbolic(index_expr.value) + offset_expr.value, + index_expr.value + offset_expr.value, index_expr.dc_dtype, ) else: @@ -1032,15 +1055,10 @@ def _make_cartesian_shift( ) # a new iterator with a shifted index along one dimension - return IteratorExpr( - field=it.field, - gt_dtype=it.gt_dtype, - dimensions=it.dimensions, - indices={ - dim: (new_index if dim == offset_dim else index) - for dim, index in it.indices.items() - }, - ) + shifted_indices = { + dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items() + } + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _make_dynamic_neighbor_offset( self, @@ -1094,7 +1112,7 @@ def _make_unstructured_shift( offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - assert connectivity.codomain in it.dimensions + assert any(dim == connectivity.codomain for dim, _ in it.field_domain) neighbor_dim = connectivity.codomain assert neighbor_dim not in it.indices @@ -1117,9 +1135,7 @@ def _make_unstructured_shift( offset_expr, offset_table_node, origin_index ) - return IteratorExpr( - field=it.field, gt_dtype=it.gt_dtype, dimensions=it.dimensions, indices=shifted_indices - ) + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 52284edfac..f15287e64c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -16,6 +16,7 @@ import abc import dataclasses +import functools import itertools import operator from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union @@ -98,9 +99,16 @@ def add_mapped_tasklet( class SDFGBuilder(DataflowBuilder, Protocol): """Visitor interface available to GTIR-primitive translators.""" + @abc.abstractmethod + def make_field( + self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + ) -> gtir_builtin_translators.FieldopData: + """Retrieve the field data descriptor including the domain offset information.""" + ... + @abc.abstractmethod def get_symbol_type(self, symbol_name: str) -> ts.DataType: - """Retrieve the GT4Py type of a symbol used in the program.""" + """Retrieve the GT4Py type of a symbol used in the SDFG.""" ... @abc.abstractmethod @@ -141,6 +149,15 @@ def _collect_symbols_in_domain_expressions( ) +def _get_tuple_type(data: tuple[gtir_builtin_translators.FieldopResult, ...]) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. + """ + return ts.TupleType( + types=[_get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] + ) + + @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. @@ -157,6 +174,9 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): offset_provider_type: gtx_common.OffsetProviderType global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) + field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( + default_factory=lambda: {} + ) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") ) @@ -167,6 +187,15 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: return self.offset_provider_type[offset] + def make_field( + self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + ) -> gtir_builtin_translators.FieldopData: + if isinstance(data_type, ts.FieldType): + domain_offset = self.field_offsets.get(data_node.data, None) + else: + domain_offset = None + return gtir_builtin_translators.FieldopData(data_node, data_type, domain_offset) + def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] @@ -248,12 +277,10 @@ def _add_storage( """ if isinstance(gt_type, ts.TupleType): tuple_fields = [] - for tname, tsymbol_type in dace_gtir_utils.get_tuple_fields( - name, gt_type, flatten=True - ): + for tname, ttype in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): tuple_fields.extend( self._add_storage( - sdfg, symbolic_arguments, tname, tsymbol_type, transient, tuple_name=name + sdfg, symbolic_arguments, tname, ttype, transient, tuple_name=name ) ) return tuple_fields @@ -275,7 +302,6 @@ def _add_storage( tuple_name, gt_type.dims ) sdfg.add_array(name, sym_shape, dc_dtype, strides=sym_strides, transient=transient) - return [(name, gt_type)] elif isinstance(gt_type, ts.ScalarType): @@ -344,7 +370,7 @@ def make_temps( head_state.add_nedge( field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) ) - return gtir_builtin_translators.FieldopData(temp_node, field.gt_type) + return field.make_copy(temp_node) temp_result = gtx_utils.tree_map(make_temps)(result) return list(gtx_utils.flatten_nested_tuple((temp_result,))) @@ -405,6 +431,10 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: if node.function_definitions: raise NotImplementedError("Functions expected to be inlined as lambda calls.") + # Since program field arguments are passed to the SDFG as full-shape arrays, + # there is no offset that needs to be compensated. + assert len(self.field_offsets) == 0 + sdfg = dace.SDFG(node.id) sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) @@ -459,7 +489,7 @@ def visit_SetAt( The SDFG head state, eventually updated if the target write requires a new state. """ - temp_fields = self._visit_expression(stmt.expr, sdfg, state) + source_fields = self._visit_expression(stmt.expr, sdfg, state) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field @@ -482,17 +512,26 @@ def visit_SetAt( } target_state: Optional[dace.SDFGState] = None - for temp, target in zip(temp_fields, target_fields, strict=True): + for source, target in zip(source_fields, target_fields, strict=True): target_desc = sdfg.arrays[target.dc_node.data] assert not target_desc.transient if isinstance(target.gt_type, ts.FieldType): - subset = ",".join( + target_subset = ",".join( f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_type.dims ) + source_subset = ( + target_subset + if source.offset is None + else ",".join( + f"{domain[dim][0] - offset}:{domain[dim][1] - offset}" + for dim, offset in zip(target.gt_type.dims, source.offset, strict=True) + ) + ) else: assert len(domain) == 0 - subset = "0" + target_subset = "0" + source_subset = "0" if target.dc_node.data in state_input_data: # if inout argument, write the result in separate next state @@ -501,17 +540,21 @@ def visit_SetAt( target_state = sdfg.add_state_after(state, f"post_{state.label}") # create new access nodes in the target state target_state.add_nedge( - target_state.add_access(temp.dc_node.data), + target_state.add_access(source.dc_node.data), target_state.add_access(target.dc_node.data), - dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), ) # remove isolated access node state.remove_node(target.dc_node) else: state.add_nedge( - temp.dc_node, + source.dc_node, target.dc_node, - dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), ) return target_state or state @@ -574,17 +617,65 @@ def visit_Lambda( (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) ] + def flatten_tuples( + name: str, + arg: gtir_builtin_translators.FieldopResult, + ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: + if isinstance(arg, tuple): + tuple_type = _get_tuple_type(arg) + tuple_field_names = [ + arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) + ] + tuple_args = zip(tuple_field_names, arg, strict=True) + return list( + itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args]) + ) + else: + return [(name, arg)] + + lambda_arg_nodes = dict( + itertools.chain(*[flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) + ) + # inherit symbols from parent scope but eventually override with local symbols lambda_symbols = { sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type + pname: _get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type for pname, arg in lambda_args_mapping } + def get_field_domain_offset( + p_name: str, p_type: ts.DataType + ) -> dict[str, Optional[list[dace.symbolic.SymExpr]]]: + if isinstance(p_type, ts.FieldType): + if p_name in lambda_arg_nodes: + arg = lambda_arg_nodes[p_name] + assert isinstance(arg, gtir_builtin_translators.FieldopData) + return {p_name: arg.offset} + elif field_domain_offset := self.field_offsets.get(p_name, None): + return {p_name: field_domain_offset} + elif isinstance(p_type, ts.TupleType): + p_fields = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) + return functools.reduce( + lambda field_offsets, field: ( + field_offsets | get_field_domain_offset(field[0], field[1]) + ), + p_fields, + {}, + ) + return {} + + # populate mapping from field name to domain offset + lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} + for p_name, p_type in lambda_symbols.items(): + lambda_field_offsets |= get_field_domain_offset(p_name, p_type) + # lower let-statement lambda node as a nested SDFG - lambda_translator = GTIRToSDFG(self.offset_provider_type, lambda_symbols) + lambda_translator = GTIRToSDFG( + self.offset_provider_type, lambda_symbols, lambda_field_offsets + ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nstate = nsdfg.add_state("lambda") @@ -603,30 +694,11 @@ def visit_Lambda( head_state=nstate, ) - def _flatten_tuples( - name: str, - arg: gtir_builtin_translators.FieldopResult, - ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: - if isinstance(arg, tuple): - tuple_type = dace_gtir_utils.get_tuple_type(arg) - tuple_field_names = [ - arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) - ] - tuple_args = zip(tuple_field_names, arg, strict=True) - return list( - itertools.chain(*[_flatten_tuples(fname, farg) for fname, farg in tuple_args]) - ) - else: - return [(name, arg)] - # Process lambda inputs # # All input arguments are passed as parameters to the nested SDFG, therefore # we they are stored as non-transient array and scalar objects. # - lambda_arg_nodes = dict( - itertools.chain(*[_flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) - ) connectivity_arrays = { dace_utils.connectivity_identifier(offset) for offset in dace_utils.filter_connectivity_types(self.offset_provider_type) @@ -739,7 +811,7 @@ def construct_output_for_nested_sdfg( head_state.add_edge( nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) ) - outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) + outer_data = inner_data.make_copy(outer_node) elif inner_data.dc_node.data in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. # Non-transient nodes are just input nodes that are immediately returned @@ -748,7 +820,7 @@ def construct_output_for_nested_sdfg( outer_data = lambda_arg_nodes[inner_data.dc_node.data] else: outer_node = head_state.add_access(inner_data.dc_node.data) - outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) + outer_data = inner_data.make_copy(outer_node) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index caec6cd87e..118f0449c8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import itertools -from typing import Any, Dict, TypeVar +from typing import Dict, TypeVar import dace @@ -58,15 +58,6 @@ def get_tuple_fields( return fields -def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: - """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. - """ - return ts.TupleType( - types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] - ) - - def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: """ Ensure that all symbols used in the program IR are valid strings (e.g. no unicode-strings). diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 9c52ea81c3..f5191fbaaa 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -47,7 +47,7 @@ VFTYPE = ts.FieldType(dims=[Vertex], dtype=FLOAT_TYPE) V2E_FTYPE = ts.FieldType(dims=[Vertex, V2EDim], dtype=EFTYPE.dtype) CARTESIAN_OFFSETS = { - "IDim": IDim, + IDim.value: IDim, } SIMPLE_MESH: MeshDescriptor = simple_mesh() SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() @@ -735,13 +735,13 @@ def test_gtir_cartesian_shift_left(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -749,13 +749,15 @@ def test_gtir_cartesian_shift_left(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA)), + im.lambda_("a", "off")( + im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -764,14 +766,14 @@ def test_gtir_cartesian_shift_left(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -828,13 +830,13 @@ def test_gtir_cartesian_shift_right(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", -OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, -OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", -OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -842,13 +844,15 @@ def test_gtir_cartesian_shift_right(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA)), + im.lambda_("a", "off")( + im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -857,14 +861,14 @@ def test_gtir_cartesian_shift_right(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -1539,6 +1543,91 @@ def test_gtir_reduce_with_cond_neighbors(): assert np.allclose(v, v_ref) +def test_gtir_symbolic_domain(): + MARGIN = 2 + assert MARGIN < N + OFFSET = 1000 * 1000 * 1000 + domain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + left_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.minus(MARGIN, OFFSET), im.minus(im.minus("size", MARGIN), OFFSET))}, + ) + right_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.plus(MARGIN, OFFSET), im.plus(im.plus("size", MARGIN), OFFSET))}, + ) + shift_left_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))) + shift_right_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))) + testee = gtir.Program( + id="symbolic_domain", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let( + "xᐞ1", + im.op_as_fieldop("multiplies", left_domain)( + 4.0, + im.as_fieldop( + shift_left_stencil, + left_domain, + )("x"), + ), + )( + im.let( + "xᐞ2", + im.op_as_fieldop("multiplies", right_domain)( + 3.0, + im.as_fieldop( + shift_right_stencil, + right_domain, + )("x"), + ), + )( + im.let( + "xᐞ3", + im.as_fieldop( + shift_right_stencil, + domain, + )("xᐞ1"), + )( + im.let( + "xᐞ4", + im.as_fieldop( + shift_left_stencil, + domain, + )("xᐞ2"), + )( + im.let("xᐞ5", im.op_as_fieldop("plus", domain)("xᐞ3", "xᐞ4"))( + im.op_as_fieldop("plus", domain)("xᐞ5", "x") + ) + ) + ) + ) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + ref = np.concatenate((b[0:MARGIN], a[MARGIN : N - MARGIN] * 8, b[N - MARGIN : N])) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + sdfg(a, b, **FSYMBOLS) + assert np.allclose(b, ref) + + def test_gtir_let_lambda(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) subdomain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) @@ -1722,7 +1811,7 @@ def test_gtir_let_lambda_with_cond(): def test_gtir_let_lambda_with_tuple1(): - domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) testee = gtir.Program( id="let_lambda_with_tuple1", function_definitions=[], @@ -1753,10 +1842,12 @@ def test_gtir_let_lambda_with_tuple1(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a)) + a_ref = np.concatenate((z_fields[0][:1], a[1 : N - 1], z_fields[0][N - 1 :])) + b_ref = np.concatenate((z_fields[1][:1], b[1 : N - 1], z_fields[1][N - 1 :])) sdfg(a, b, *z_fields, **FSYMBOLS) - assert np.allclose(z_fields[0], a) - assert np.allclose(z_fields[1], b) + assert np.allclose(z_fields[0], a_ref) + assert np.allclose(z_fields[1], b_ref) def test_gtir_let_lambda_with_tuple2(): From f6c219bd989e3c5325da1173bade4bff2ac9e650 Mon Sep 17 00:00:00 2001 From: SF-N Date: Tue, 26 Nov 2024 15:59:58 +0100 Subject: [PATCH 4/5] bug[next]: Fix SetAt type inference for ts.DeferredType (#1747) Fix to correctly handle tuples of ts.DeferredType. --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/type_system/inference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 987eb0f308..249019769b 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -509,7 +509,10 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: # the target can have fewer elements than the expr in which case the output from the # expression is simply discarded. expr_type = functools.reduce( - lambda tuple_type, i: tuple_type.types[i], # type: ignore[attr-defined] # format ensured by primitive_constituents + lambda tuple_type, i: tuple_type.types[i] # type: ignore[attr-defined] # format ensured by primitive_constituents + # `ts.DeferredType` only occurs for scans returning a tuple + if not isinstance(tuple_type, ts.DeferredType) + else ts.DeferredType(constraint=None), path, node.expr.type, ) From f6c0498dbffd85a80a32281e5a53bfb35e00e745 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 27 Nov 2024 09:55:46 +0100 Subject: [PATCH 5/5] feat[next][dace]: Lowering to SDFG of index builtin (#1751) Implements the lowering to SDFG of the GTIR index builtin. --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 14 ++++ .../gtir_builtin_translators.py | 83 ++++++++++++++++--- .../runners/dace_fieldview/gtir_sdfg.py | 2 + tests/next_tests/definitions.py | 1 - .../dace_tests/test_gtir_to_sdfg.py | 50 ++++++++++- 5 files changed, 134 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 2864c7f727..a4e111e785 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -519,6 +519,20 @@ def _impl(it: itir.Expr) -> itir.FunCall: return _impl +def index(dim: common.Dimension) -> itir.FunCall: + """ + Create a call to the `index` builtin, shorthand for `call("index")(axis)`, + after converting the given dimension to `itir.AxisLiteral`. + + Args: + dim: the dimension corresponding to the index axis. + + Returns: + A function that constructs a Field of indices in the given dimension. + """ + return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind)) + + def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 60dcd8ddc9..94ab3a6f76 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,7 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -277,20 +277,31 @@ def extract_domain(node: gtir.Node) -> FieldopDomain: the corresponding lower and upper bounds. The returned lower bound is inclusive, the upper bound is exclusive: [lower_bound, upper_bound[ """ - assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) domain = [] - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - axis = named_range.args[0] - assert isinstance(axis, gtir.AxisLiteral) - lower_bound, upper_bound = ( - dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg)) - for arg in named_range.args[1:3] - ) - dim = gtx_common.Dimension(axis.value, axis.kind) - domain.append((dim, lower_bound, upper_bound)) + + def parse_range_boundary(expr: gtir.Expr) -> str: + return dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(expr)) + + if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, gtir.AxisLiteral) + lower_bound, upper_bound = (parse_range_boundary(arg) for arg in named_range.args[1:3]) + dim = gtx_common.Dimension(axis.value, axis.kind) + domain.append((dim, lower_bound, upper_bound)) + + elif isinstance(node, domain_utils.SymbolicDomain): + assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"} + for dim, drange in node.ranges.items(): + domain.append( + (dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop)) + ) + + else: + raise ValueError(f"Invalid domain {node}.") return domain @@ -545,6 +556,51 @@ def construct_output(inner_data: FieldopData) -> FieldopData: return result_temps +def translate_index( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """ + Lowers the `index` builtin function to a mapped tasklet that writes the dimension + index values to a transient array. The extent of the index range is taken from + the domain information that should be present in the node annex. + """ + assert "domain" in node.annex + domain = extract_domain(node.annex.domain) + assert len(domain) == 1 + dim, lower_bound, upper_bound = domain[0] + dim_index = dace_gtir_utils.get_map_variable(dim) + + field_dims, field_offset, field_shape = _get_field_layout(domain) + field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) + + output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) + output_node = state.add_access(output) + + sdfg_builder.add_mapped_tasklet( + "index", + state, + map_ranges={ + dim_index: f"{lower_bound}:{upper_bound}", + }, + inputs={}, + code=f"__val = {dim_index}", + outputs={ + "__val": dace.Memlet( + data=output_node.data, + subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), + ) + }, + input_nodes={}, + output_nodes={output_node.data: output_node}, + external_edges=True, + ) + + return FieldopData(output_node, field_type, field_offset) + + def _get_data_nodes( sdfg: dace.SDFG, state: dace.SDFGState, @@ -777,6 +833,7 @@ def translate_symbol_ref( translate_as_fieldop, translate_broadcast_scalar, translate_if, + translate_index, translate_literal, translate_make_tuple, translate_tuple_get, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index f15287e64c..6b5e164458 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -568,6 +568,8 @@ def visit_FunCall( # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node, "if_"): return gtir_builtin_translators.translate_if(node, sdfg, head_state, self) + elif cpm.is_call_to(node, "index"): + return gtir_builtin_translators.translate_index(node, sdfg, head_state, self) elif cpm.is_call_to(node, "make_tuple"): return gtir_builtin_translators.translate_make_tuple(node, sdfg, head_state, self) elif cpm.is_call_to(node, "tuple_get"): diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 01fd18897d..349d3e9f70 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -154,7 +154,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ - (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index f5191fbaaa..c7466b853f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -12,15 +12,15 @@ Note: this test module covers the fieldview flavour of ITIR. """ -import copy import functools import numpy as np import pytest -from gt4py.next import common as gtx_common, constructors +from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import infer_domain from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -1973,3 +1973,49 @@ def test_gtir_if_values(): sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, np.where(a < b, a, b)) + + +def test_gtir_index(): + MARGIN = 2 + assert MARGIN < N + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + subdomain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + + testee = gtir.Program( + id="gtir_cast", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("i", im.index(IDim))( + im.op_as_fieldop("plus", domain)( + "i", + im.as_fieldop( + im.lambda_("a")(im.deref(im.shift(IDim.value, 1)("a"))), subdomain + )("i"), + ) + ), + domain=subdomain, + target=gtir.SymRef(id="x"), + ) + ], + ) + + v = np.empty(N, dtype=np.int32) + + # we need to run domain inference in order to add the domain annex information to the index node. + testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + ref = np.concatenate( + (v[:MARGIN], np.arange(MARGIN, N - MARGIN, dtype=np.int32), v[N - MARGIN :]) + ) + + sdfg(v, **FSYMBOLS) + np.allclose(v, ref)