diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d70f335bef..b1092fafd0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: ## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '23.7.0' # version from constraints.txt + rev: '23.9.1' # version from constraints.txt ##[[[end]]] hooks: - id: black @@ -97,7 +97,7 @@ repos: ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) ##]]] - darglint==1.8.1 - - flake8-bugbear==23.7.10 + - flake8-bugbear==23.9.16 - flake8-builtins==2.1.0 - flake8-debugger==4.1.2 - flake8-docstrings==1.7.0 @@ -146,9 +146,9 @@ repos: ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"#========= FROM constraints.txt: v{version} =========") ##]]] - #========= FROM constraints.txt: v1.5.0 ========= + #========= FROM constraints.txt: v1.5.1 ========= ##[[[end]]] - rev: v1.5.0 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) + rev: v1.5.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) hooks: - id: mypy additional_dependencies: # versions from constraints.txt @@ -162,26 +162,26 @@ repos: ##]]] - astunparse==1.6.3 - attrs==23.1.0 - - black==23.7.0 + - black==23.9.1 - boltons==23.0.0 - cached-property==1.5.2 - - click==8.1.6 - - cmake==3.27.2 + - click==8.1.7 + - cmake==3.27.5 - cytoolz==0.12.2 - - deepdiff==6.3.1 - - devtools==0.11.0 + - deepdiff==6.5.0 + - devtools==0.12.2 - frozendict==2.3.8 - gridtools-cpp==2.3.1 - importlib-resources==6.0.1 - jinja2==3.1.2 - lark==1.1.7 - mako==1.2.4 - - nanobind==1.5.0 + - nanobind==1.5.2 - ninja==1.11.1 - numpy==1.24.4 - packaging==23.1 - pybind11==2.11.1 - - setuptools==68.1.0 + - setuptools==68.2.2 - tabulate==0.9.0 - typing-extensions==4.5.0 - xxhash==3.0.0 diff --git a/constraints.txt b/constraints.txt index 35e3d9e330..b334851af1 100644 --- a/constraints.txt +++ b/constraints.txt @@ -6,14 +6,14 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.2.1 # via devtools +asttokens==2.4.0 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing babel==2.12.1 # via sphinx -black==23.7.0 # via gt4py (pyproject.toml) +black==23.9.1 # via gt4py (pyproject.toml) blinker==1.6.2 # via flask boltons==23.0.0 # via gt4py (pyproject.toml) -build==0.10.0 # via pip-tools +build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) cachetools==5.3.1 # via tox certifi==2023.7.22 # via requests @@ -22,17 +22,17 @@ cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox charset-normalizer==3.2.0 # via requests clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.6 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.2 # via gt4py (pyproject.toml) +click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools +cmake==3.27.5 # via gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage==7.3.0 # via -r requirements-dev.in, pytest-cov +coverage==7.3.1 # via -r requirements-dev.in, pytest-cov cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) dace==0.14.4 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in 
-deepdiff==6.3.1 # via gt4py (pyproject.toml) -devtools==0.11.0 # via gt4py (pyproject.toml) +deepdiff==6.5.0 # via gt4py (pyproject.toml) +devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace distlib==0.3.7 # via virtualenv docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme @@ -41,11 +41,11 @@ exceptiongroup==1.1.3 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==1.2.0 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.3.0 # via factory-boy +faker==19.6.1 # via factory-boy fastjsonschema==2.18.0 # via nbformat -filelock==3.12.2 # via tox, virtualenv +filelock==3.12.4 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.7.10 # via -r requirements-dev.in +flake8-bugbear==23.9.16 # via -r requirements-dev.in flake8-builtins==2.1.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in @@ -53,14 +53,14 @@ flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.2 # via dace +flask==2.3.3 # via dace frozendict==2.3.8 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.82.4 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.26 # via pre-commit +hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.29 # via pre-commit idna==3.4 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via flask, sphinx +importlib-metadata==6.8.0 # via build, flask, sphinx importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest @@ -70,7 +70,7 @@ jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx jsonschema==4.19.0 # via nbformat jsonschema-specifications==2023.7.1 # via jsonschema jupyter-core==5.3.1 # via nbformat -jupytext==1.15.0 # via -r requirements-dev.in +jupytext==1.15.2 # via -r requirements-dev.in lark==1.1.7 # via gt4py (pyproject.toml) mako==1.2.4 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins @@ -79,9 +79,9 @@ mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.0 # via -r requirements-dev.in +mypy==1.5.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.0 # via gt4py (pyproject.toml) +nanobind==1.5.2 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace ninja==1.11.1 # via gt4py (pyproject.toml) @@ -94,36 +94,36 @@ pip-tools==7.3.0 # via -r requirements-dev.in pipdeptree==2.13.0 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv -pluggy==1.2.0 # via pytest, tox +pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.3.3 # via -r requirements-dev.in +pre-commit==3.4.0 # via -r requirements-dev.in psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) pycodestyle==2.11.0 # via flake8, 
flake8-debugger pycparser==2.21 # via cffi pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, flake8-rst-docstrings, sphinx -pyproject-api==1.5.3 # via tox +pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.0 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.5.1 # via -r requirements-dev.in pytest-xdist==3.3.1 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker -pytz==2023.3 # via babel +pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit referencing==0.30.2 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.9.2 # via jsonschema, referencing -ruff==0.0.284 # via -r requirements-dev.in +rpds-py==0.10.3 # via jsonschema, referencing +ruff==0.0.290 # via -r requirements-dev.in six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis -sphinx==6.2.1 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.2.2 # via -r requirements-dev.in +sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery +sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -136,8 +136,8 @@ tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox toolz==0.12.0 # via cytoolz -tox==4.9.0 # via -r requirements-dev.in -traitlets==5.9.0 # via jupyter-core, nbformat +tox==4.11.3 # via -r requirements-dev.in +traitlets==5.10.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all @@ -182,14 +182,14 @@ types-kazoo==0.1.3 # via types-all types-markdown==3.4.2.10 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.1 # via types-all +types-mock==5.1.0.2 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all types-paramiko==3.3.0.0 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.2 # via types-all +types-pillow==10.0.0.3 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all types-protobuf==4.24.0.1 # via types-all @@ -205,17 +205,17 @@ types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.0.1 # via types-all, types-tzlocal +types-pytz==2023.3.1.0 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all 
types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.4 # via types-all +types-redis==4.6.0.6 # via types-all types-requests==2.31.0.2 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.1.0.0 # via types-cffi +types-setuptools==68.2.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all -types-singledispatch==4.0.0.2 # via types-all +types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all types-tabulate==0.9.0.3 # via types-all types-termcolor==1.1.6.2 # via types-all @@ -230,13 +230,13 @@ types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy urllib3==2.0.4 # via requests -virtualenv==20.24.3 # via pre-commit, tox +virtualenv==20.24.5 # via pre-commit, tox websockets==11.0.3 # via dace werkzeug==2.3.7 # via flask -wheel==0.41.1 # via astunparse, pip-tools +wheel==0.41.2 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.16.2 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==23.2.1 # via pip-tools -setuptools==68.1.0 # via gt4py (pyproject.toml), nodeenv, pip-tools +setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools diff --git a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md new file mode 100644 index 0000000000..6c6a043560 --- /dev/null +++ b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md @@ -0,0 +1,80 @@ +--- +tags: [] +--- + +# Test-Exclusion Matrices + +- **Status**: valid +- **Authors**: Edoardo Paone (@edopao), Enrique G. Paredes (@egparedes) +- **Created**: 2023-09-21 +- **Updated**: 2023-09-21 + +In the context of Field View testing, lacking support for specific ITIR features while a certain backend +is being developed, we decided to use `pytest` fixtures to exclude unsupported tests. + +## Context + +It should be possible to run Field View tests on different backends. However, specific tests could be unsupported +on a certain backend, or the backend implementation could be only partially ready. +Therefore, we need a mechanism to specify the features required by each test and selectively enable +the supported backends, while keeping the test code clean. + +## Decision + +It was decided to apply fixtures and markers from the `pytest` module. The fixture is the same one used to execute the test +on different backends (`fieldview_backend` and `program_processor`), but it is extended with a check on the available feature markers. +If a test is annotated with a feature marker, the fixture will check if this feature is supported on the selected backend. +If no marker is specified, the test is supposed to run on all backends. + +In the example below, `test_offset_field` requires the backend to support dynamic offsets in the translation from ITIR: + +```python +@pytest.mark.uses_dynamic_offsets +def test_offset_field(cartesian_case): +``` + +In order to selectively enable the backends, the dictionary `next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX` +lists for each backend the features that are not supported. +The fixture will check if the annotated feature is present in the exclusion matrix for the selected backend. +If so, the exclusion matrix will also specify the action `pytest` should take (e.g. `SKIP` or `XFAIL`).
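+
+As an illustration only, a fixture-side check could look roughly as sketched below. This is not the actual gt4py fixture implementation: the hard-coded backend name, the single-entry matrix, and the `SKIP`/`XFAIL` aliases for `pytest.skip`/`pytest.xfail` are assumptions made for this sketch (the real matrix layout and example entry are described further below).
+
+```python
+import pytest
+
+# Assumed aliases: the matrix stores the pytest action to call for an excluded feature.
+SKIP = pytest.skip
+XFAIL = pytest.xfail
+
+# Hypothetical single-entry exclusion matrix, mirroring the example further below.
+BACKEND_SKIP_TEST_MATRIX = {
+    "otf_compile_executor.run_gtfn_with_temporaries": [
+        ("uses_dynamic_offsets", XFAIL, "'{marker}' tests not supported by '{backend}' backend"),
+    ]
+}
+
+
+@pytest.fixture
+def fieldview_backend(request):
+    backend = "otf_compile_executor.run_gtfn_with_temporaries"  # normally selected via fixture parametrization
+    for marker, action, message in BACKEND_SKIP_TEST_MATRIX.get(backend, []):
+        if request.node.get_closest_marker(marker):
+            # Skip or xfail the test before it runs on the unsupported backend.
+            action(message.format(marker=marker, backend=backend))
+    return backend
+```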
+ +The test-exclusion matrix is a dictionary, where `key` is the backend name and each entry is a tuple with the following fields: + +`(<marker>, <skip_action>, <skip_message>)` + +The backend string, used both as the dictionary key and as a format argument in the skip message, is retrieved +by calling `next_tests.get_processor_id()`, which returns the so-called processor name. +The following backend processors are defined: + +```python +DACE = "dace_iterator.run_dace_iterator" +GTFN_CPU = "otf_compile_executor.run_gtfn" +GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative" +GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries" +``` + +Following the previous example, the GTFN backend with temporaries does not yet support dynamic offsets in ITIR: + +```python +BACKEND_SKIP_TEST_MATRIX = { + GTFN_CPU_WITH_TEMPORARIES: [ + ("uses_dynamic_offsets", pytest.XFAIL, "'{marker}' tests not supported by '{backend}' backend"), + ] +} +``` + +## Consequences + +Positive outcomes of this decision: + +- The solution provides a central place to specify test exclusion. +- The test code remains free of if-statements for backend exclusion. +- The exclusion matrix gives an overview of the feature-readiness of different backends. + +Negative outcomes: + +- There is not (yet) any code-style check to enforce this solution, so code reviewers should be aware of this ADR. + +## References + +- [pytest - Using markers to pass data to fixtures](https://docs.pytest.org/en/6.2.x/fixture.html#using-markers-to-pass-data-to-fixtures) diff --git a/docs/development/ADRs/Index.md b/docs/development/ADRs/Index.md index 1bbfd62d81..09d2273ee9 100644 --- a/docs/development/ADRs/Index.md +++ b/docs/development/ADRs/Index.md @@ -51,9 +51,9 @@ _None_ - [0011 - On The Fly Compilation](0011-On_The_Fly_Compilation.md) - [0012 - GridTools C++ OTF](0011-_GridTools_Cpp_OTF.md) -### Miscellanea +### Testing -_None_ +- [0015 - Exclusion Matrices](0015-Test_Exclusion_Matrices.md) ### Superseded diff --git a/pyproject.toml b/pyproject.toml index e915622857..e2d2a7dfe9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -326,9 +326,25 @@ module = 'gt4py.next.iterator.runtime' [tool.pytest.ini_options] markers = [ - 'requires_atlas', # mark tests that require 'atlas4py' bindings package - 'requires_dace', # mark tests that require 'dace' package - 'requires_gpu:' # mark tests that require a NVidia GPU (assume 'cupy' and 'cudatoolkit' are installed) + 'requires_atlas: tests that require `atlas4py` bindings package', + 'requires_dace: tests that require `dace` package', + 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', + 'uses_applied_shifts: tests that require backend support for applied-shifts', + 'uses_can_deref: tests that require backend support for can_deref', + 'uses_constant_fields: tests that require backend support for constant fields', + 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', + 'uses_if_stmts: tests that require backend support for if-statements', + 'uses_index_fields: tests that require backend support for index fields', + 'uses_lift_expressions: tests that require backend support for lift expressions', + 'uses_negative_modulo: tests that require backend support for modulo on negative numbers', + 'uses_origin: tests that require backend support for domain origin', + 'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions', + 'uses_scan_in_field_operator: tests that require backend support for scan in
field operator', + 'uses_sparse_fields: tests that require backend support for sparse fields', + 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', + 'uses_tuple_args: tests that require backend support for tuple arguments', + 'uses_tuple_returns: tests that require backend support for tuple results', + 'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields' ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] testpaths = 'tests' diff --git a/requirements-dev.txt b/requirements-dev.txt index a167b2979a..d6dcc12d21 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,14 +6,14 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.2.1 # via devtools +asttokens==2.4.0 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing babel==2.12.1 # via sphinx -black==23.7.0 # via gt4py (pyproject.toml) +black==23.9.1 # via gt4py (pyproject.toml) blinker==1.6.2 # via flask boltons==23.0.0 # via gt4py (pyproject.toml) -build==0.10.0 # via pip-tools +build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) cachetools==5.3.1 # via tox certifi==2023.7.22 # via requests @@ -22,17 +22,17 @@ cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox charset-normalizer==3.2.0 # via requests clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.6 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.2 # via gt4py (pyproject.toml) +click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools +cmake==3.27.5 # via gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage[toml]==7.3.0 # via -r requirements-dev.in, pytest-cov +coverage[toml]==7.3.1 # via -r requirements-dev.in, pytest-cov cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) dace==0.14.4 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.3.1 # via gt4py (pyproject.toml) -devtools==0.11.0 # via gt4py (pyproject.toml) +deepdiff==6.5.0 # via gt4py (pyproject.toml) +devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace distlib==0.3.7 # via virtualenv docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme @@ -41,11 +41,11 @@ exceptiongroup==1.1.3 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==1.2.0 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.3.0 # via factory-boy +faker==19.6.1 # via factory-boy fastjsonschema==2.18.0 # via nbformat -filelock==3.12.2 # via tox, virtualenv +filelock==3.12.4 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.7.10 # via -r requirements-dev.in +flake8-bugbear==23.9.16 # via -r requirements-dev.in flake8-builtins==2.1.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in @@ -53,14 +53,14 @@ flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in 
flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.2 # via dace +flask==2.3.3 # via dace frozendict==2.3.8 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.82.4 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.26 # via pre-commit +hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.29 # via pre-commit idna==3.4 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via flask, sphinx +importlib-metadata==6.8.0 # via build, flask, sphinx importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest @@ -70,7 +70,7 @@ jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx jsonschema==4.19.0 # via nbformat jsonschema-specifications==2023.7.1 # via jsonschema jupyter-core==5.3.1 # via nbformat -jupytext==1.15.0 # via -r requirements-dev.in +jupytext==1.15.2 # via -r requirements-dev.in lark==1.1.7 # via gt4py (pyproject.toml) mako==1.2.4 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins @@ -79,9 +79,9 @@ mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.0 # via -r requirements-dev.in +mypy==1.5.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.0 # via gt4py (pyproject.toml) +nanobind==1.5.2 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace ninja==1.11.1 # via gt4py (pyproject.toml) @@ -94,36 +94,36 @@ pip-tools==7.3.0 # via -r requirements-dev.in pipdeptree==2.13.0 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv -pluggy==1.2.0 # via pytest, tox +pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.3.3 # via -r requirements-dev.in +pre-commit==3.4.0 # via -r requirements-dev.in psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) pycodestyle==2.11.0 # via flake8, flake8-debugger pycparser==2.21 # via cffi pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, flake8-rst-docstrings, sphinx -pyproject-api==1.5.3 # via tox +pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.0 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.5.1 # via -r requirements-dev.in pytest-xdist[psutil]==3.3.1 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker -pytz==2023.3 # via babel +pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit referencing==0.30.2 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.9.2 # via jsonschema, referencing -ruff==0.0.284 # via -r requirements-dev.in +rpds-py==0.10.3 # via jsonschema, referencing +ruff==0.0.290 # via -r requirements-dev.in six==1.16.0 # via asttokens, 
astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis -sphinx==6.2.1 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.2.2 # via -r requirements-dev.in +sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery +sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -136,8 +136,8 @@ tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox toolz==0.12.0 # via cytoolz -tox==4.9.0 # via -r requirements-dev.in -traitlets==5.9.0 # via jupyter-core, nbformat +tox==4.11.3 # via -r requirements-dev.in +traitlets==5.10.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all @@ -182,14 +182,14 @@ types-kazoo==0.1.3 # via types-all types-markdown==3.4.2.10 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.1 # via types-all +types-mock==5.1.0.2 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all types-paramiko==3.3.0.0 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.2 # via types-all +types-pillow==10.0.0.3 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all types-protobuf==4.24.0.1 # via types-all @@ -205,17 +205,17 @@ types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.0.1 # via types-all, types-tzlocal +types-pytz==2023.3.1.0 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.4 # via types-all +types-redis==4.6.0.6 # via types-all types-requests==2.31.0.2 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.1.0.0 # via types-cffi +types-setuptools==68.2.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all -types-singledispatch==4.0.0.2 # via types-all +types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all types-tabulate==0.9.0.3 # via types-all types-termcolor==1.1.6.2 # via types-all @@ -230,13 +230,13 @@ types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy urllib3==2.0.4 # via requests -virtualenv==20.24.3 # via pre-commit, tox +virtualenv==20.24.5 # via pre-commit, tox websockets==11.0.3 # via dace werkzeug==2.3.7 # via flask -wheel==0.41.1 # via astunparse, pip-tools +wheel==0.41.2 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.16.2 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==23.2.1 # via pip-tools -setuptools==68.1.0 # via gt4py (pyproject.toml), nodeenv, pip-tools 
+setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 34829317d6..3b8373ade1 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -552,11 +552,14 @@ def is_value_hashable_typing( return type_annotation is None -def is_protocol(type_: Type) -> bool: +def _is_protocol(type_: type, /) -> bool: """Check if a type is a Protocol definition.""" return getattr(type_, "_is_protocol", False) +is_protocol = getattr(_typing_extensions, "is_protocol", _is_protocol) + + def get_partial_type_hints( obj: Union[ object, diff --git a/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py new file mode 100644 index 0000000000..c5857999ee --- /dev/null +++ b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py @@ -0,0 +1,105 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass +from typing import Any, cast + +import gt4py.next.ffront.field_operator_ast as foast +from gt4py.eve import NodeTranslator, traits +from gt4py.eve.concepts import SourceLocation, SymbolName, SymbolRef +from gt4py.next.ffront import dialect_ast_enums +from gt4py.next.ffront.fbuiltins import TYPE_BUILTIN_NAMES +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system.type_translation import from_type_hint + + +@dataclass +class TypeAliasReplacement(NodeTranslator, traits.VisitorWithSymbolTableTrait): + """ + Replace Type Aliases with their actual type. + + After this pass, the type aliases used for explicit construction of literal + values and for casting field values are replaced by their actual types. 
+ """ + + closure_vars: dict[str, Any] + + @classmethod + def apply( + cls, node: foast.FunctionDefinition | foast.FieldOperator, closure_vars: dict[str, Any] + ) -> tuple[foast.FunctionDefinition, dict[str, Any]]: + foast_node = cls(closure_vars=closure_vars).visit(node) + new_closure_vars = closure_vars.copy() + for key, value in closure_vars.items(): + if isinstance(value, type) and key not in TYPE_BUILTIN_NAMES: + new_closure_vars[value.__name__] = closure_vars[key] + return foast_node, new_closure_vars + + def is_type_alias(self, node_id: SymbolName | SymbolRef) -> bool: + return ( + node_id in self.closure_vars + and isinstance(self.closure_vars[node_id], type) + and node_id not in TYPE_BUILTIN_NAMES + ) + + def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: + if self.is_type_alias(node.id): + return foast.Name( + id=self.closure_vars[node.id].__name__, location=node.location, type=node.type + ) + return node + + def _update_closure_var_symbols( + self, closure_vars: list[foast.Symbol], location: SourceLocation + ) -> list[foast.Symbol]: + new_closure_vars: list[foast.Symbol] = [] + existing_type_names: set[str] = set() + + for var in closure_vars: + if self.is_type_alias(var.id): + actual_type_name = self.closure_vars[var.id].__name__ + # Avoid multiple definitions of a type in closure_vars + if actual_type_name not in existing_type_names: + new_closure_vars.append( + foast.Symbol( + id=actual_type_name, + type=ts.FunctionType( + pos_or_kw_args={}, + kw_only_args={}, + pos_only_args=[ts.DeferredType(constraint=ts.ScalarType)], + returns=cast( + ts.DataType, from_type_hint(self.closure_vars[var.id]) + ), + ), + namespace=dialect_ast_enums.Namespace.CLOSURE, + location=location, + ) + ) + existing_type_names.add(actual_type_name) + elif var.id not in existing_type_names: + new_closure_vars.append(var) + existing_type_names.add(var.id) + + return new_closure_vars + + def visit_FunctionDefinition( + self, node: foast.FunctionDefinition, **kwargs + ) -> foast.FunctionDefinition: + return foast.FunctionDefinition( + id=node.id, + params=node.params, + body=self.visit(node.body, **kwargs), + closure_vars=self._update_closure_var_symbols(node.closure_vars, node.location), + location=node.location, + ) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 082939c938..c7c4c3a23f 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -33,6 +33,7 @@ from gt4py.next.ffront.foast_passes.closure_var_type_deduction import ClosureVarTypeDeduction from gt4py.next.ffront.foast_passes.dead_closure_var_elimination import DeadClosureVarElimination from gt4py.next.ffront.foast_passes.iterable_unpack import UnpackedAssignPass +from gt4py.next.ffront.foast_passes.type_alias_replacement import TypeAliasReplacement from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -91,6 +92,7 @@ def _postprocess_dialect_ast( closure_vars: dict[str, Any], annotations: dict[str, Any], ) -> foast.FunctionDefinition: + foast_node, closure_vars = TypeAliasReplacement.apply(foast_node, closure_vars) foast_node = ClosureVarFolding.apply(foast_node, closure_vars) foast_node = DeadClosureVarElimination.apply(foast_node) foast_node = ClosureVarTypeDeduction.apply(foast_node, closure_vars) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py 
b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index f78d90095c..1c1bed9c5e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -11,29 +11,44 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any, Mapping, Sequence +import hashlib +from typing import Any, Mapping, Optional, Sequence import dace import numpy as np +from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.iterator.ir as itir -from gt4py.next import common -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider +from gt4py.next.common import Dimension, Domain, UnitRange, is_field +from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.otf.compilation import cache from gt4py.next.program_processors.processor_interface import program_executor -from gt4py.next.type_system import type_translation +from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG -from .utility import connectivity_identifier, filter_neighbor_tables +from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims + + +def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: + sorted_dims = get_sorted_dims(domain.dims) + return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] + + +""" Default build configuration in DaCe backend """ +_build_type = "Release" +# removing -ffast-math from DaCe default compiler args in order to support isfinite/isinf/isnan built-ins +_cpu_args = ( + "-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -Wno-unused-parameter -Wno-unused-label" +) def convert_arg(arg: Any): - if common.is_field(arg): - sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) + if is_field(arg): + sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) - dim_indices = [dim[0] for dim in sorted_dims] + dim_indices = [dim_index for dim_index, _ in sorted_dims] assert isinstance(arg.ndarray, np.ndarray) return np.moveaxis(arg.ndarray, range(ndim), dim_indices) return arg @@ -69,6 +84,17 @@ def get_shape_args( } +def get_offset_args( + arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] +) -> Mapping[str, int]: + return { + str(sym): -drange.start + for param, arg in zip(params, args) + if is_field(arg) + for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) + } + + def get_stride_args( arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] ) -> Mapping[str, int]: @@ -85,17 +111,89 @@ def get_stride_args( return stride_args +_build_cache_cpu: dict[str, CompiledSDFG] = {} +_build_cache_gpu: dict[str, CompiledSDFG] = {} + + +def get_cache_id( + program: itir.FencilDefinition, + arg_types: Sequence[ts.TypeSpec], + column_axis: Optional[Dimension], + offset_provider: Mapping[str, Any], +) -> str: + max_neighbors = [ + (k, v.max_neighbors) + for k, v in offset_provider.items() + if isinstance(v, (NeighborTableOffsetProvider, StridedNeighborOffsetProvider)) + ] + cache_id_args = [ + str(arg) + for arg in ( + program, + *arg_types, + column_axis, + *max_neighbors, + ) + ] + m = hashlib.sha256() + for s in cache_id_args: + 
m.update(s.encode()) + return m.hexdigest() + + @program_executor def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: + # build parameters + auto_optimize = kwargs.get("auto_optimize", False) + build_type = kwargs.get("build_type", "RelWithDebInfo") + run_on_gpu = kwargs.get("run_on_gpu", False) + build_cache = kwargs.get("build_cache", None) + # ITIR parameters column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] - neighbor_tables = filter_neighbor_tables(offset_provider) - program = preprocess_program(program, offset_provider) arg_types = [type_translation.from_value(arg) for arg in args] - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) - sdfg: dace.SDFG = sdfg_genenerator.visit(program) - sdfg.simplify() + neighbor_tables = filter_neighbor_tables(offset_provider) + + cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) + if build_cache is not None and cache_id in build_cache: + # retrieve SDFG program from build cache + sdfg_program = build_cache[cache_id] + sdfg = sdfg_program.sdfg + else: + # visit ITIR and generate SDFG + program = preprocess_program(program, offset_provider) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) + sdfg = sdfg_genenerator.visit(program) + sdfg.simplify() + + # set array storage for GPU execution + if run_on_gpu: + device = dace.DeviceType.GPU + sdfg._name = f"{sdfg.name}_gpu" + for _, _, array in sdfg.arrays_recursive(): + if not array.transient: + array.storage = dace.dtypes.StorageType.GPU_Global + else: + device = dace.DeviceType.CPU + + # run DaCe auto-optimization heuristics + if auto_optimize: + # TODO Investigate how symbol definitions improve autoopt transformations, + # in which case the cache table should take the symbols map into account. 
+ symbols: dict[str, int] = {} + sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols) + + # compile SDFG and retrieve SDFG program + sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "build_type", value=build_type) + dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args) + sdfg_program = sdfg.compile(validate=False) + + # store SDFG program in build cache + if build_cache is not None: + build_cache[cache_id] = sdfg_program dace_args = get_args(program.params, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} @@ -103,9 +201,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = get_stride_args(sdfg.arrays, dace_field_args) - dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args) - - sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" + dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) + dace_offsets = get_offset_args(sdfg.arrays, program.params, args) all_args = { **dace_args, @@ -113,16 +210,40 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: **dace_shapes, **dace_conn_shapes, **dace_strides, - **dace_conn_stirdes, + **dace_conn_strides, + **dace_offsets, } expected_args = { key: value for key, value in all_args.items() if key in sdfg.signature_arglist(with_types=False) } + with dace.config.temporary_config(): dace.config.Config.set("compiler", "allow_view_arguments", value=True) - dace.config.Config.set("compiler", "build_type", value="Debug") - dace.config.Config.set("compiler", "cpu", "args", value="-O0") dace.config.Config.set("frontend", "check_args", value=True) - sdfg(**expected_args) + sdfg_program(**expected_args) + + +@program_executor +def run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_dace_iterator( + program, + *args, + **kwargs, + build_cache=_build_cache_cpu, + build_type=_build_type, + run_on_gpu=False, + ) + + +@program_executor +def run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_dace_iterator( + program, + *args, + **kwargs, + build_cache=_build_cache_gpu, + build_type=_build_type, + run_on_gpu=True, + ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 4f93777215..580486aa4a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -32,11 +32,14 @@ is_scan, ) from .utility import ( + add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, create_memlet_at, create_memlet_full, filter_neighbor_tables, + flatten_list, + get_sorted_dims, map_nested_sdfg_symbols, unique_var_name, ) @@ -78,9 +81,10 @@ def get_scan_dim( - scan_dim_dtype: data type along the scan dimension """ output_type = cast(ts.FieldType, storage_types[output.id]) + sorted_dims = [dim for _, dim in get_sorted_dims(output_type.dims)] return ( column_axis.value, - output_type.dims.index(column_axis), + sorted_dims.index(column_axis), output_type.dtype, ) @@ -104,18 +108,30 @@ def __init__( self.offset_provider = offset_provider self.storage_types = {} - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec): + def add_storage(self, 
sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + offset = ( + [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + if has_offset + else None + ) dtype = as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) + sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) else: raise NotImplementedError() self.storage_types[name] = type_ + def get_output_nodes( + self, closure: itir.StencilClosure, context: Context + ) -> dict[str, dace.nodes.AccessNode]: + translator = PythonTaskletCodegen(self.offset_provider, context, self.node_types) + output_nodes = flatten_list(translator.visit(closure.output)) + return {node.value.data: node.value for node in output_nodes} + def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) last_state = program_sdfg.add_state("program_entry") @@ -133,54 +149,33 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): scalar_kind = type_translation.get_scalar_kind(table.table.dtype) local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL) type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind)) - self.add_storage(program_sdfg, connectivity_identifier(offset), type_) + self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False) # Create a nested SDFG for all stencil closures. for closure in node.closures: - assert isinstance(closure.output, itir.SymRef) - - # filter out arguments with scalar type, because they are passed as symbols - input_names = [ - str(inp.id) - for inp in closure.inputs - if isinstance(self.storage_types[inp.id], ts.FieldType) - ] - connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] - output_names = [str(closure.output.id)] - # Translate the closure and its stencil's body to an SDFG. - closure_sdfg = self.visit(closure, array_table=program_sdfg.arrays) + closure_sdfg, input_names, output_names = self.visit( + closure, array_table=program_sdfg.arrays + ) # Create a new state for the closure. 
last_state = program_sdfg.add_state_after(last_state) # Create memlets to transfer the program parameters - input_memlets = [ - create_memlet_full(name, program_sdfg.arrays[name]) for name in input_names - ] - connectivity_memlets = [ - create_memlet_full(name, program_sdfg.arrays[name]) for name in connectivity_names - ] - output_memlets = [ - create_memlet_full(name, program_sdfg.arrays[name]) for name in output_names - ] - - input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} - connectivity_mapping = { - param: arg for param, arg in zip(connectivity_names, connectivity_memlets) + input_mapping = { + name: create_memlet_full(name, program_sdfg.arrays[name]) for name in input_names } output_mapping = { - param: arg_memlet for param, arg_memlet in zip(output_names, output_memlets) + name: create_memlet_full(name, program_sdfg.arrays[name]) for name in output_names } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, array_mapping) + symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, input_mapping) # Insert the closure's SDFG as a nested SDFG of the program. nsdfg_node = last_state.add_nested_sdfg( sdfg=closure_sdfg, parent=program_sdfg, - inputs=set(input_names) | set(connectivity_names), + inputs=set(input_names), outputs=set(output_names), symbol_mapping=symbol_mapping, ) @@ -190,49 +185,78 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): access_node = last_state.add_access(inner_name) last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) - for inner_name, memlet in connectivity_mapping.items(): - access_node = last_state.add_access(inner_name) - last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) - for inner_name, memlet in output_mapping.items(): access_node = last_state.add_access(inner_name) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) + program_sdfg.validate() return program_sdfg def visit_StencilClosure( self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] - ) -> dace.SDFG: + ) -> tuple[dace.SDFG, list[str], list[str]]: assert ItirToSDFG._check_no_lifts(node) assert ItirToSDFG._check_shift_offsets_are_literals(node) - assert isinstance(node.output, itir.SymRef) - - neighbor_tables = filter_neighbor_tables(self.offset_provider) - input_names = [str(inp.id) for inp in node.inputs] - conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] - output_name = str(node.output.id) # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") - # Add DaCe arrays for inputs, output and connectivities to closure SDFG. 
- for name in [*input_names, *conn_names, output_name]: - assert name not in closure_sdfg.arrays or (name in input_names and name == output_name) + program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} + closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) + neighbor_tables = filter_neighbor_tables(self.offset_provider) + + input_names = [str(inp.id) for inp in node.inputs] + conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + + output_nodes = self.get_output_nodes(node, closure_ctx) + output_names = [k for k, _ in output_nodes.items()] + + # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG. + input_transients_mapping = {} + for name in [*input_names, *conn_names, *output_names]: if name in closure_sdfg.arrays: - # in/out parameter, container already added for in parameter - continue - if isinstance(self.storage_types[name], ts.FieldType): + assert name in input_names and name in output_names + # In case of closures with in/out fields, there is risk of race condition + # between read/write access nodes in the (asynchronous) map tasklet. + transient_name = unique_var_name() + closure_sdfg.add_array( + transient_name, + shape=array_table[name].shape, + strides=array_table[name].strides, + dtype=array_table[name].dtype, + transient=True, + ) + closure_init_state.add_nedge( + closure_init_state.add_access(name), + closure_init_state.add_access(transient_name), + create_memlet_full(name, closure_sdfg.arrays[name]), + ) + input_transients_mapping[name] = transient_name + elif isinstance(self.storage_types[name], ts.FieldType): closure_sdfg.add_array( name, shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, ) + else: + assert isinstance(self.storage_types[name], ts.ScalarType) - # Get output domain of the closure - program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} + input_field_names = [ + input_name + for input_name in input_names + if isinstance(self.storage_types[input_name], ts.FieldType) + ] + + # Closure outputs should all be fields + assert all( + isinstance(self.storage_types[output_name], ts.FieldType) + for output_name in output_names + ) + + # Update symbol table and get output domain of the closure for name, type_ in self.storage_types.items(): if isinstance(type_, ts.ScalarType): if name in input_names: @@ -245,94 +269,86 @@ def visit_StencilClosure( ) access = closure_init_state.add_access(out_name) value = ValueExpr(access, dtype) - memlet = create_memlet_at(out_name, ("0",)) + memlet = dace.Memlet.simple(out_name, "0") closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) program_arg_syms[name] = value else: program_arg_syms[name] = SymbolExpr(name, as_dace_type(type_)) - domain_ctx = Context(closure_sdfg, closure_state, program_arg_syms) - closure_domain = self._visit_domain(node.domain, domain_ctx) + closure_domain = self._visit_domain(node.domain, closure_ctx) # Map SDFG tasklet arguments to parameters input_access_names = [ - input_name - if isinstance(self.storage_types[input_name], ts.FieldType) + input_transients_mapping[input_name] + if input_name in input_transients_mapping + else input_name + if input_name in input_field_names else cast(ValueExpr, program_arg_syms[input_name]).value.data for input_name in input_names ] input_memlets = [ create_memlet_full(name, closure_sdfg.arrays[name]) for name in input_access_names ] - conn_memlet = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in 
conn_names] + conn_memlets = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names] - transient_to_arg_name_mapping = {} # create and write to transient that is then copied back to actual output array to avoid aliasing of # same memory in nested SDFG with different names - nsdfg_output_name = unique_var_name() - output_descriptor = closure_sdfg.arrays[output_name] - transient_to_arg_name_mapping[nsdfg_output_name] = output_name + output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names} # scan operator should always be the first function call in a closure if is_scan(node.stencil): - nsdfg, map_domain, scan_dim_index = self._visit_scan_stencil_closure( - node, closure_sdfg.arrays, closure_domain, nsdfg_output_name + assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs" + transient_name, output_name = next(iter(output_connectors_mapping.items())) + + nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( + node, closure_sdfg.arrays, closure_domain, transient_name ) - results = [nsdfg_output_name] + results = [transient_name] _, (scan_lb, scan_ub) = closure_domain[scan_dim_index] output_subset = f"{scan_lb.value}:{scan_ub.value}" - closure_sdfg.add_array( - nsdfg_output_name, - dtype=output_descriptor.dtype, - shape=(array_table[output_name].shape[scan_dim_index],), - strides=(array_table[output_name].strides[scan_dim_index],), - transient=True, - ) - - output_memlet = create_memlet_at( - output_name, - tuple( - f"i_{dim}" - if f"i_{dim}" in map_domain - else f"0:{output_descriptor.shape[scan_dim_index]}" - for dim, _ in closure_domain - ), - ) + output_memlets = [ + create_memlet_at( + output_name, + tuple( + f"i_{dim}" + if f"i_{dim}" in map_ranges + else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}" + for dim, _ in closure_domain + ), + ) + ] else: - nsdfg, map_domain, results = self._visit_parallel_stencil_closure( + nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( node, closure_sdfg.arrays, closure_domain ) - assert len(results) == 1 output_subset = "0" - closure_sdfg.add_scalar( - nsdfg_output_name, - dtype=output_descriptor.dtype, - transient=True, - ) - - output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_domain.keys())) + output_memlets = [ + create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys())) + for output_name in output_connectors_mapping.values() + ] input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} - output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, [output_memlet])} - conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlet)} + output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, output_memlets)} + conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlets)} array_mapping = {**input_mapping, **conn_mapping} symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, array_mapping) - nsdfg_node, map_entry, map_exit = self._add_mapped_nested_sdfg( + nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( closure_state, sdfg=nsdfg, - map_ranges=map_domain or {"__dummy": "0"}, + map_ranges=map_ranges or {"__dummy": "0"}, inputs=array_mapping, outputs=output_mapping, symbol_mapping=symbol_mapping, + output_nodes=output_nodes, ) access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} for edge in closure_state.in_edges(map_exit): memlet = edge.data - if memlet.data not in 
transient_to_arg_name_mapping: + if memlet.data not in output_connectors_mapping: continue transient_access = closure_state.add_access(memlet.data) closure_state.add_edge( @@ -340,28 +356,16 @@ def visit_StencilClosure( edge.src_conn, transient_access, None, - dace.Memlet(data=memlet.data, subset=output_subset), + dace.Memlet.simple(memlet.data, output_subset), ) - inner_memlet = dace.Memlet( - data=memlet.data, subset=output_subset, other_subset=memlet.subset + inner_memlet = dace.Memlet.simple( + memlet.data, output_subset, other_subset_str=memlet.subset ) closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) closure_state.remove_edge(edge) - access_nodes[memlet.data].data = transient_to_arg_name_mapping[memlet.data] - - for _, (lb, ub) in closure_domain: - for b in lb, ub: - if isinstance(b, SymbolExpr): - continue - map_entry.add_in_connector(b.value.data) - closure_state.add_edge( - b.value, - None, - map_entry, - b.value.data, - create_memlet_at(b.value.data, ("0",)), - ) - return closure_sdfg + access_nodes[memlet.data].data = output_connectors_mapping[memlet.data] + + return closure_sdfg, input_field_names + conn_names, output_names def _visit_scan_stencil_closure( self, @@ -389,12 +393,12 @@ def _visit_scan_stencil_closure( connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_domain = {} + map_ranges = {} for dim, (lb, ub) in closure_domain: lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value if not dim == scan_dim: - map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" + map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" else: scan_lb_str = lb_str scan_ub_str = ub_str @@ -480,29 +484,28 @@ def _visit_scan_stencil_closure( "__result", carry_node1, None, - dace.Memlet(data=f"{scan_carry_name}", subset="0"), + dace.Memlet.simple(scan_carry_name, "0"), ) carry_node2 = lambda_state.add_access(scan_carry_name) lambda_state.add_memlet_path( carry_node2, scan_inner_node, - memlet=dace.Memlet(data=f"{scan_carry_name}", subset="0"), + memlet=dace.Memlet.simple(scan_carry_name, "0"), src_conn=None, dst_conn=lambda_carry_name, ) # connect access nodes to lambda inputs for (inner_name, _), data_name in zip(lambda_inputs[1:], input_names): - data_subset = ( - ", ".join([f"i_{dim}" for dim, _ in closure_domain]) - if isinstance(self.storage_types[data_name], ts.FieldType) - else "0" - ) + if isinstance(self.storage_types[data_name], ts.FieldType): + memlet = create_memlet_at(data_name, tuple(f"i_{dim}" for dim, _ in closure_domain)) + else: + memlet = dace.Memlet.simple(data_name, "0") lambda_state.add_memlet_path( lambda_state.add_access(data_name), scan_inner_node, - memlet=dace.Memlet(data=f"{data_name}", subset=data_subset), + memlet=memlet, src_conn=None, dst_conn=inner_name, ) @@ -526,12 +529,13 @@ def _visit_scan_stencil_closure( data_name, shape=(array_table[node.output.id].shape[scan_dim_index],), strides=(array_table[node.output.id].strides[scan_dim_index],), + offset=(array_table[node.output.id].offset[scan_dim_index],), dtype=array_table[node.output.id].dtype, ) lambda_state.add_memlet_path( scan_inner_node, lambda_state.add_access(data_name), - memlet=dace.Memlet(data=data_name, subset=f"i_{scan_dim}"), + memlet=dace.Memlet.simple(data_name, f"i_{scan_dim}"), src_conn=lambda_connector.value.label, dst_conn=None, ) @@ -543,10 +547,10 @@ def 
_visit_scan_stencil_closure( lambda_update_state.add_memlet_path( result_node, carry_node3, - memlet=dace.Memlet(data=f"{output_names[0]}", subset=f"i_{scan_dim}", other_subset="0"), + memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), ) - return scan_sdfg, map_domain, scan_dim_index + return scan_sdfg, map_ranges, scan_dim_index def _visit_parallel_stencil_closure( self, @@ -561,11 +565,11 @@ def _visit_parallel_stencil_closure( conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_domain = {} + map_ranges = {} for dim, (lb, ub) in closure_domain: lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" + map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" # Create an SDFG for the tasklet that computes a single item of the output domain. index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} @@ -582,77 +586,7 @@ def _visit_parallel_stencil_closure( self.node_types, ) - return context.body, map_domain, [r.value.data for r in results] - - def _add_mapped_nested_sdfg( - self, - state: dace.SDFGState, - map_ranges: dict[str, str | dace.subsets.Subset] - | list[tuple[str, str | dace.subsets.Subset]], - inputs: dict[str, dace.Memlet], - outputs: dict[str, dace.Memlet], - sdfg: dace.SDFG, - symbol_mapping: dict[str, Any] | None = None, - schedule: Any = dace.dtypes.ScheduleType.Default, - unroll_map: bool = False, - location: Any = None, - debuginfo: Any = None, - input_nodes: dict[str, dace.nodes.AccessNode] | None = None, - output_nodes: dict[str, dace.nodes.AccessNode] | None = None, - ) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]: - if not symbol_mapping: - symbol_mapping = {sym: sym for sym in sdfg.free_symbols} - - nsdfg_node = state.add_nested_sdfg( - sdfg, - None, - set(inputs.keys()), - set(outputs.keys()), - symbol_mapping, - name=sdfg.name, - schedule=schedule, - location=location, - debuginfo=debuginfo, - ) - - map_entry, map_exit = state.add_map( - f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo - ) - - if input_nodes is None: - input_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() - } - if output_nodes is None: - output_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() - } - if not inputs: - state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) - for name, memlet in inputs.items(): - state.add_memlet_path( - input_nodes[memlet.data], - map_entry, - nsdfg_node, - memlet=memlet, - src_conn=None, - dst_conn=name, - propagate=True, - ) - if not outputs: - state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet()) - for name, memlet in outputs.items(): - state.add_memlet_path( - nsdfg_node, - map_exit, - output_nodes[memlet.data], - memlet=memlet, - src_conn=name, - dst_conn=None, - propagate=True, - ) - - return nsdfg_node, map_entry, map_exit + return context.body, map_ranges, [r.value.data for r in results] def _visit_domain( self, node: itir.FunCall, context: Context diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index d301c3e3cf..b28703feef 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ 
b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -19,9 +19,11 @@ import dace import numpy as np +from dace.transformation.dataflow import MapFusion +from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen -from gt4py.next import Dimension, type_inference as next_typing +from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.ir import FunCall, Lambda @@ -29,12 +31,14 @@ from gt4py.next.type_system import type_specifications as ts from .utility import ( + add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, - create_memlet_at, create_memlet_full, filter_neighbor_tables, + flatten_list, map_nested_sdfg_symbols, + unique_name, unique_var_name, ) @@ -56,6 +60,21 @@ def itir_type_as_dace_type(type_: next_typing.Type): raise NotImplementedError() +def get_reduce_identity_value(op_name_: str, type_: Any): + if op_name_ == "plus": + init_value = type_(0) + elif op_name_ == "multiplies": + init_value = type_(1) + elif op_name_ == "minimum": + init_value = type_("inf") + elif op_name_ == "maximum": + init_value = type_("-inf") + else: + raise NotImplementedError() + + return init_value + + _MATH_BUILTINS_MAPPING = { "abs": "abs({})", "sin": "math.sin({})", @@ -136,6 +155,21 @@ class Context: body: dace.SDFG state: dace.SDFGState symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr] + # if we encounter a reduction node, the reduction state needs to be pushed to child nodes + reduce_limit: int + reduce_wcr: Optional[str] + + def __init__( + self, + body: dace.SDFG, + state: dace.SDFGState, + symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr], + ): + self.body = body + self.state = state + self.symbol_map = symbol_map + self.reduce_limit = 0 + self.reduce_wcr = None def builtin_neighbors( @@ -167,13 +201,15 @@ def builtin_neighbors( table_name = connectivity_identifier(offset_dim) table_array = sdfg.arrays[table_name] + # generate unique map index name to avoid conflict with other maps inside same state + index_name = unique_name("__neigh_idx") me, mx = state.add_map( f"{offset_dim}_neighbors_map", - ndrange={"neigh_idx": f"0:{table.max_neighbors}"}, + ndrange={index_name: f"0:{table.max_neighbors}"}, ) shift_tasklet = state.add_tasklet( "shift", - code="__result = __table[__idx, neigh_idx]", + code=f"__result = __table[__idx, {index_name}]", inputs={"__table", "__idx"}, outputs={"__result"}, ) @@ -208,10 +244,8 @@ def builtin_neighbors( ) # select full shape only in the neighbor-axis dimension field_subset = [ - f"0:{sdfg.arrays[iterator.field.data].shape[idx]}" - if dim == table.neighbor_axis.value - else f"i_{dim}" - for idx, dim in enumerate(iterator.dimensions) + f"0:{shape}" if dim == table.neighbor_axis.value else f"i_{dim}" + for dim, shape in zip(sorted(iterator.dimensions), sdfg.arrays[iterator.field.data].shape) ] state.add_memlet_path( iterator.field, @@ -227,7 +261,7 @@ def builtin_neighbors( data_access_tasklet, mx, result_access, - memlet=dace.Memlet(data=result_name, subset="neigh_idx"), + memlet=dace.Memlet(data=result_name, subset=index_name), src_conn="__result", ) @@ -349,6 +383,8 @@ def visit_Lambda( value = IteratorExpr(field, indices, arg.dtype, arg.dimensions) symbol_map[param] = value context = Context(context_sdfg, context_state, symbol_map) + context.reduce_limit = 
prev_context.reduce_limit + context.reduce_wcr = prev_context.reduce_wcr self.context = context # Add input parameters as arrays @@ -388,27 +424,36 @@ def visit_Lambda( context.body.add_array(name, shape=shape, strides=strides, dtype=dtype) # Translate the function's body - result: ValueExpr | SymbolExpr = self.visit(node.expr)[0] - # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - if isinstance(result, ValueExpr): - result_name = unique_var_name() - self.context.body.add_scalar(result_name, result.dtype, transient=True) - result_access = self.context.state.add_access(result_name) - self.context.state.add_edge( - result.value, None, result_access, None, dace.Memlet(f"{result.value.data}[0]") - ) - result = ValueExpr(value=result_access, dtype=result.dtype) - else: - result = self.add_expr_tasklet([], result.value, result.dtype, "forward")[0] - self.context.body.arrays[result.value.data].transient = False - self.context = prev_context + results: list[ValueExpr] = [] + # We are flattening the returned list of value expressions because the multiple outputs of a lamda + # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. + for expr in flatten_list(self.visit(node.expr)): + if isinstance(expr, ValueExpr): + result_name = unique_var_name() + self.context.body.add_scalar(result_name, expr.dtype, transient=True) + result_access = self.context.state.add_access(result_name) + self.context.state.add_edge( + expr.value, + None, + result_access, + None, + # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution + dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr), + ) + result = ValueExpr(value=result_access, dtype=expr.dtype) + else: + # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors + result = self.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0] + self.context.body.arrays[result.value.data].transient = False + results.append(result) + self.context = prev_context for node in context.state.nodes(): if isinstance(node, dace.nodes.AccessNode): if context.state.out_degree(node) == 0 and context.state.in_degree(node) == 0: context.state.remove_node(node) - return context, inputs, [result] + return context, inputs, results def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr: if node.id not in self.context.symbol_map: @@ -531,15 +576,69 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: if not isinstance(iterator, IteratorExpr): # already a list of ValueExpr return iterator - sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) - flat_index = [ - ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions - ] - args: list[ValueExpr] = [ValueExpr(iterator.field, iterator.dtype), *flat_index] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") + args: list[ValueExpr] + sorted_dims = sorted(iterator.dimensions) + if self.context.reduce_limit: + # we are visiting a child node of reduction, so the neighbor index can be used for indirect addressing + result_name = unique_var_name() + self.context.body.add_array( + result_name, + dtype=iterator.dtype, + shape=(self.context.reduce_limit,), + transient=True, + ) + result_access = 
self.context.state.add_access(result_name) + + # generate unique map index name to avoid conflict with other maps inside same state + index_name = unique_name("__deref_idx") + me, mx = self.context.state.add_map( + "deref_map", + ndrange={index_name: f"0:{self.context.reduce_limit}"}, + ) + + # if dim is not found in iterator indices, we take the neighbor index over the reduction domain + flat_index = [ + f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name + for dim in sorted_dims + ] + args = [ValueExpr(iterator.field, iterator.dtype)] + [ + ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices + ] + internals = [f"{arg.value.data}_v" for arg in args] + + deref_tasklet = self.context.state.add_tasklet( + name="deref", + inputs=set(internals), + outputs={"__result"}, + code=f"__result = {args[0].value.data}_v[{', '.join(flat_index)}]", + ) + + for arg, internal in zip(args, internals): + input_memlet = create_memlet_full( + arg.value.data, self.context.body.arrays[arg.value.data] + ) + self.context.state.add_memlet_path( + arg.value, me, deref_tasklet, memlet=input_memlet, dst_conn=internal + ) + + self.context.state.add_memlet_path( + deref_tasklet, + mx, + result_access, + memlet=dace.Memlet(data=result_name, subset=index_name), + src_conn="__result", + ) + + return [ValueExpr(value=result_access, dtype=iterator.dtype)] + + else: + args = [ValueExpr(iterator.field, iterator.dtype)] + [ + ValueExpr(iterator.indices[dim], iterator.dtype) for dim in sorted_dims + ] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]}[{', '.join(internals[1:])}]" + return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") def _split_shift_args( self, args: list[itir.Expr] @@ -603,18 +702,31 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: element = tail[1].value assert isinstance(element, int) - table: NeighborTableOffsetProvider = self.offset_provider[offset] - shifted_dim = table.origin_axis.value - target_dim = table.neighbor_axis.value + if isinstance(self.offset_provider[offset], NeighborTableOffsetProvider): + table = self.offset_provider[offset] + shifted_dim = table.origin_axis.value + target_dim = table.neighbor_axis.value - conn = self.context.state.add_access(connectivity_identifier(offset)) + conn = self.context.state.add_access(connectivity_identifier(offset)) + + args = [ + ValueExpr(conn, table.table.dtype), + ValueExpr(iterator.indices[shifted_dim], dace.int64), + ] + + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]}[{internals[1]}, {element}]" + else: + offset_provider = self.offset_provider[offset] + assert isinstance(offset_provider, StridedNeighborOffsetProvider) + + shifted_dim = offset_provider.origin_axis.value + target_dim = offset_provider.neighbor_axis.value + offset_value = iterator.indices[shifted_dim] + args = [ValueExpr(offset_value, dace.int64)] + internals = [f"{offset_value.data}_v"] + expr = f"{internals[0]} * {offset_provider.max_neighbors} + {element}" - args = [ - ValueExpr(conn, table.table.dtype), - ValueExpr(iterator.indices[shifted_dim], dace.int64), - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{internals[1]}, {element}]" shifted_value = self.add_expr_tasklet( list(zip(args, internals)), expr, dace.dtypes.int64, "ind_addr" )[0].value @@ -626,47 +738,156 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: return IteratorExpr(iterator.field, 
shifted_index, iterator.dtype, iterator.dimensions) def _visit_reduce(self, node: itir.FunCall): - assert ( - isinstance(node.args[0], itir.FunCall) - and isinstance(node.args[0].fun, itir.SymRef) - and node.args[0].fun.id == "neighbors" - ) - args = self.visit(node.args) - assert len(args) == 1 - args = args[0] - assert len(args) == 1 - assert isinstance(node.fun, itir.FunCall) - op_name = node.fun.args[0] - assert isinstance(op_name, itir.SymRef) - init = node.fun.args[1] - - nreduce = self.context.body.arrays[args[0].value.data].shape[0] - result_name = unique_var_name() result_access = self.context.state.add_access(result_name) - self.context.body.add_scalar(result_name, args[0].dtype, transient=True) - op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") - reduce_tasklet = self.context.state.add_tasklet( - "reduce", - code=f"__result = {init}\nfor __idx in range({nreduce}):\n __result = {op_str}", - inputs={"__values"}, - outputs={"__result"}, - ) - self.context.state.add_edge( - args[0].value, - None, - reduce_tasklet, - "__values", - dace.Memlet(data=args[0].value.data, subset=f"0:{nreduce}"), - ) - self.context.state.add_edge( - reduce_tasklet, - "__result", - result_access, - None, - dace.Memlet(data=result_name, subset="0"), - ) - return [ValueExpr(result_access, args[0].dtype)] + + if len(node.args) == 1: + assert ( + isinstance(node.args[0], itir.FunCall) + and isinstance(node.args[0].fun, itir.SymRef) + and node.args[0].fun.id == "neighbors" + ) + args = self.visit(node.args) + assert len(args) == 1 + args = args[0] + assert len(args) == 1 + neighbors_expr = args[0] + result_dtype = neighbors_expr.dtype + assert isinstance(node.fun, itir.FunCall) + op_name = node.fun.args[0] + assert isinstance(op_name, itir.SymRef) + init = node.fun.args[1] + + nreduce = self.context.body.arrays[neighbors_expr.value.data].shape[0] + + self.context.body.add_scalar(result_name, result_dtype, transient=True) + op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") + reduce_tasklet = self.context.state.add_tasklet( + "reduce", + code=f"__result = {init}\nfor __idx in range({nreduce}):\n __result = {op_str}", + inputs={"__values"}, + outputs={"__result"}, + ) + self.context.state.add_edge( + args[0].value, + None, + reduce_tasklet, + "__values", + dace.Memlet(data=neighbors_expr.value.data, subset=f"0:{nreduce}"), + ) + self.context.state.add_edge( + reduce_tasklet, + "__result", + result_access, + None, + dace.Memlet(data=result_name, subset="0"), + ) + else: + assert isinstance(node.fun, itir.FunCall) + assert isinstance(node.fun.args[0], itir.Lambda) + fun_node = node.fun.args[0] + + args = [] + for node_arg in node.args: + if ( + isinstance(node_arg, itir.FunCall) + and isinstance(node_arg.fun, itir.SymRef) + and node_arg.fun.id == "neighbors" + ): + expr = self.visit(node_arg) + args.append(*expr) + else: + args.append(None) + + # first visit only arguments for neighbor selection, all other arguments are none + neighbor_args = [arg for arg in args if arg] + + # check that all neighbors expression have the same range + assert ( + len( + set([self.context.body.arrays[expr.value.data].shape for expr in neighbor_args]) + ) + == 1 + ) + + nreduce = self.context.body.arrays[neighbor_args[0].value.data].shape[0] + nreduce_domain = {"__idx": f"0:{nreduce}"} + + result_dtype = neighbor_args[0].dtype + self.context.body.add_scalar(result_name, result_dtype, transient=True) + + assert isinstance(fun_node.expr, itir.FunCall) + op_name = 
fun_node.expr.fun + assert isinstance(op_name, itir.SymRef) + + # initialize the reduction result based on type of operation + init_value = get_reduce_identity_value(op_name.id, result_dtype) + init_state = self.context.body.add_state_before(self.context.state, "init") + init_tasklet = init_state.add_tasklet( + "init_reduce", {}, {"__out"}, f"__out = {init_value}" + ) + init_state.add_edge( + init_tasklet, + "__out", + init_state.add_access(result_name), + None, + dace.Memlet.simple(result_name, "0"), + ) + + # set reduction state to enable dereference of neighbors in input fields and to set WCR on reduce tasklet + self.context.reduce_limit = nreduce + self.context.reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format( + "x", "y" + ) + + # visit child nodes for input arguments + for i, node_arg in enumerate(node.args): + if not args[i]: + args[i] = self.visit(node_arg)[0] + + lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:]) + lambda_context, inner_inputs, inner_outputs = self.visit(lambda_node, args=args) + + # clear context + self.context.reduce_limit = 0 + self.context.reduce_wcr = None + + # the connectivity arrays (neighbor tables) are not needed inside the reduce lambda SDFG + neighbor_tables = filter_neighbor_tables(self.offset_provider) + for conn, _ in neighbor_tables: + var = connectivity_identifier(conn) + lambda_context.body.remove_data(var) + # cleanup symbols previously used for shape and stride of connectivity arrays + p = RemoveUnusedSymbols() + p.apply_pass(lambda_context.body, {}) + + input_memlets = [ + dace.Memlet.simple(expr.value.data, "__idx") for arg, expr in zip(node.args, args) + ] + output_memlet = dace.Memlet.simple(result_name, "0") + + input_mapping = {param: arg for (param, _), arg in zip(inner_inputs, input_memlets)} + output_mapping = {inner_outputs[0].value.data: output_memlet} + symbol_mapping = map_nested_sdfg_symbols( + self.context.body, lambda_context.body, input_mapping + ) + + nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( + self.context.state, + sdfg=lambda_context.body, + map_ranges=nreduce_domain, + inputs=input_mapping, + outputs=output_mapping, + symbol_mapping=symbol_mapping, + input_nodes={arg.value.data: arg.value for arg in args}, + output_nodes={result_name: result_access}, + ) + + # we apply map fusion only to the nested-SDFG which is generated for the reduction operator + # the purpose is to keep the ITIR-visitor program simple and to clean up the generated SDFG + self.context.body.apply_transformations_repeated([MapFusion], validate=False) + + return [ValueExpr(result_access, result_dtype)] def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) @@ -720,7 +941,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = create_memlet_at(result_access.data, ("0",)) + memlet = dace.Memlet.simple(result_access.data, "0") self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 85b1445dd9..1fdd022a49 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -11,11 +11,12 @@ # distribution for a copy of the license or check . 
# # SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any +import itertools +from typing import Any, Sequence import dace +from gt4py.next import Dimension from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.type_system import type_specifications as ts @@ -49,7 +50,7 @@ def connectivity_identifier(name: str): def create_memlet_full(source_identifier: str, source_array: dace.data.Array): bounds = [(0, size) for size in source_array.shape] subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds) - return dace.Memlet(data=source_identifier, subset=subset) + return dace.Memlet.simple(source_identifier, subset) def create_memlet_at(source_identifier: str, index: tuple[str, ...]): @@ -57,6 +58,10 @@ def create_memlet_at(source_identifier: str, index: tuple[str, ...]): return dace.Memlet(data=source_identifier, subset=subset) +def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]: + return sorted(enumerate(dims), key=lambda v: v[1].value) + + def map_nested_sdfg_symbols( parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet] ) -> dict[str, str]: @@ -81,10 +86,91 @@ def map_nested_sdfg_symbols( return symbol_mapping +def add_mapped_nested_sdfg( + state: dace.SDFGState, + map_ranges: dict[str, str | dace.subsets.Subset] | list[tuple[str, str | dace.subsets.Subset]], + inputs: dict[str, dace.Memlet], + outputs: dict[str, dace.Memlet], + sdfg: dace.SDFG, + symbol_mapping: dict[str, Any] | None = None, + schedule: Any = dace.dtypes.ScheduleType.Default, + unroll_map: bool = False, + location: Any = None, + debuginfo: Any = None, + input_nodes: dict[str, dace.nodes.AccessNode] | None = None, + output_nodes: dict[str, dace.nodes.AccessNode] | None = None, +) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]: + if not symbol_mapping: + symbol_mapping = {sym: sym for sym in sdfg.free_symbols} + + nsdfg_node = state.add_nested_sdfg( + sdfg, + None, + set(inputs.keys()), + set(outputs.keys()), + symbol_mapping, + name=sdfg.name, + schedule=schedule, + location=location, + debuginfo=debuginfo, + ) + + map_entry, map_exit = state.add_map( + f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo + ) + + if input_nodes is None: + input_nodes = { + memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() + } + if output_nodes is None: + output_nodes = { + memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() + } + if not inputs: + state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) + for name, memlet in inputs.items(): + state.add_memlet_path( + input_nodes[memlet.data], + map_entry, + nsdfg_node, + memlet=memlet, + src_conn=None, + dst_conn=name, + propagate=True, + ) + if not outputs: + state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet()) + for name, memlet in outputs.items(): + state.add_memlet_path( + nsdfg_node, + map_exit, + output_nodes[memlet.data], + memlet=memlet, + src_conn=name, + dst_conn=None, + propagate=True, + ) + + return nsdfg_node, map_entry, map_exit + + _unique_id = 0 -def unique_var_name(): +def unique_name(prefix): global _unique_id _unique_id += 1 - return f"__var_{_unique_id}" + return f"{prefix}_{_unique_id}" + + +def unique_var_name(): + return unique_name("__var") + + +def flatten_list(node_list: list[Any]) -> list[Any]: + return list( + itertools.chain.from_iterable( + [flatten_list(e) if e.__class__ == list else [e] for e in node_list] + ) + ) diff --git 
a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py index 2934e48e7a..6b8c02e41c 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py @@ -194,8 +194,8 @@ def subtest_stencil_info(self, exec_info, stencil_info, last_called_stencil=Fals else: assert stencil_info["total_run_cpp_time"] > stencil_info["run_cpp_time"] - @given(data=hyp_st.data()) @pytest.mark.parametrize("backend", ALL_BACKENDS) + @given(data=hyp_st.data()) def test_backcompatibility(self, data, backend, worker_id): # set backend as instance attribute self.backend = backend @@ -237,8 +237,8 @@ def test_backcompatibility(self, data, backend, worker_id): assert type(self.advection).__name__ not in exec_info assert type(self.diffusion).__name__ not in exec_info - @given(data=hyp_st.data()) @pytest.mark.parametrize("backend", ALL_BACKENDS) + @given(data=hyp_st.data()) def test_aggregate(self, data, backend, worker_id): # set backend as instance attribute self.backend = backend diff --git a/tests/next_tests/__init__.py b/tests/next_tests/__init__.py index bd9b968948..54bc4d9c69 100644 --- a/tests/next_tests/__init__.py +++ b/tests/next_tests/__init__.py @@ -12,6 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from . import exclusion_matrices + + +__all__ = ["exclusion_matrices", "get_processor_id"] + def get_processor_id(processor): if hasattr(processor, "__module__") and hasattr(processor, "__name__"): diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py new file mode 100644 index 0000000000..98ac9352c3 --- /dev/null +++ b/tests/next_tests/exclusion_matrices.py @@ -0,0 +1,99 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later +import pytest + + +"""Contains definition of test-exclusion matrices, see ADR 15.""" + +# Skip definitions +XFAIL = pytest.xfail +SKIP = pytest.skip + +# Processor ids as returned by next_tests.get_processor_id() +DACE = "dace_iterator.run_dace_iterator" +GTFN_CPU = "otf_compile_executor.run_gtfn" +GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative" +GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries" +GTFN_FORMAT_SOURCECODE = "gtfn.format_sourcecode" + +# Test markers +REQUIRES_ATLAS = "requires_atlas" +USES_APPLIED_SHIFTS = "uses_applied_shifts" +USES_CAN_DEREF = "uses_can_deref" +USES_CONSTANT_FIELDS = "uses_constant_fields" +USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets" +USES_IF_STMTS = "uses_if_stmts" +USES_INDEX_FIELDS = "uses_index_fields" +USES_LIFT_EXPRESSIONS = "uses_lift_expressions" +USES_NEGATIVE_MODULO = "uses_negative_modulo" +USES_ORIGIN = "uses_origin" +USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" +USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" +USES_SPARSE_FIELDS = "uses_sparse_fields" +USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" +USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset" +USES_TUPLE_ARGS = "uses_tuple_args" +USES_TUPLE_RETURNS = "uses_tuple_returns" +USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields" + +# Skip messages (available format keys: 'marker', 'backend') +UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" +BINDINGS_UNSUPPORTED_MESSAGE = "'{marker}' not supported by '{backend}' bindings" +REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE = ( + "We cannot unroll a reduction on a sparse field only (not clear if it is legal ITIR)" +) +# Common list of feature markers to skip +GTFN_SKIP_TEST_LIST = [ + (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), + (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), + (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), +] + +#: Skip matrix, contains for each backend processor a list of tuples with following fields: +#: (, ) +BACKEND_SKIP_TEST_MATRIX = { + DACE: GTFN_SKIP_TEST_LIST + + [ + (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), + (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_REDUCTION_OVER_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + ], + GTFN_CPU: GTFN_SKIP_TEST_LIST + + [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], + GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + + [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], + GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST + + [ + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], + GTFN_FORMAT_SOURCECODE: [ + (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), + 
], +} diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index a8c35cc28f..383716484e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -22,7 +22,16 @@ import gt4py.next as gtx from gt4py.next.ffront import decorator from gt4py.next.iterator import embedded, ir as itir -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu, roundtrip +from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip + + +try: + from gt4py.next.program_processors.runners import dace_iterator +except ModuleNotFoundError as e: + if "dace" in str(e): + dace_iterator = None + else: + raise e import next_tests @@ -32,20 +41,40 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non raise ValueError("No backend selected! Backend selection is mandatory in tests.") +OPTIONAL_PROCESSORS = [] +if dace_iterator: + OPTIONAL_PROCESSORS.append(dace_iterator.run_dace_iterator) + + @pytest.fixture( params=[ roundtrip.executor, gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ], + ] + + OPTIONAL_PROCESSORS, ids=lambda p: next_tests.get_processor_id(p), ) def fieldview_backend(request): + """ + Fixture creating field-view operator backend on-demand for tests. + + Notes: + Check ADR 15 for details on the test-exclusion matrices. + """ + backend = request.param + backend_id = next_tests.get_processor_id(backend) + + for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( + backend_id, [] + ): + if request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=backend_id)) + backup_backend = decorator.DEFAULT_BACKEND decorator.DEFAULT_BACKEND = no_backend - yield request.param + yield backend decorator.DEFAULT_BACKEND = backup_backend diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index 71e31542f7..1402649127 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -24,7 +24,7 @@ from gt4py.next.errors.exceptions import TypeError_ from gt4py.next.ffront.decorator import field_operator, program, scan_operator from gt4py.next.ffront.fbuiltins import broadcast, int32, int64 -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -169,15 +169,8 @@ def testee( ) +@pytest.mark.uses_scan_in_field_operator def test_call_scan_operator_from_field_operator(cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("Calling scan from field operator not fully supported.") - @scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan(state: float, x: float, y: float) -> float: return state + x + 2.0 * y diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py 
b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f50f16ea0f..61b34460ef 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -33,7 +33,7 @@ where, ) from gt4py.next.ffront.experimental import as_offset -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -68,10 +68,8 @@ def testee(a: cases.IJKField) -> cases.IJKField: cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) +@pytest.mark.uses_tuple_returns def test_multicopy(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee(a: cases.IJKField, b: cases.IJKField) -> tuple[cases.IJKField, cases.IJKField]: return a, b @@ -162,9 +160,6 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField: def test_tuples(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee(a: cases.IJKFloatField, b: cases.IJKFloatField) -> cases.IJKFloatField: inps = a, b @@ -211,10 +206,8 @@ def testee(a: int32) -> cases.VField: ) +@pytest.mark.uses_index_fields def test_scalar_arg_with_field(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: index fields, constant fields") - @gtx.field_operator def testee(a: cases.IJKField, b: int32) -> cases.IJKField: tmp = b * a @@ -272,16 +265,8 @@ def testee(qc: cases.IKFloatField, scalar: float): cases.verify(cartesian_case, testee, qc, scalar, inout=qc, ref=expected) +@pytest.mark.uses_scan_in_field_operator def test_tuple_scalar_scan(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - ]: - pytest.xfail("Scalar tuple arguments are not supported in gtfn yet.") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple arguments") - @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan( state: float, qc_in: float, tuple_scalar: tuple[float, tuple[float, float]] @@ -301,10 +286,8 @@ def testee_op( cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected) +@pytest.mark.uses_index_fields def test_scalar_scan_vertical_offset(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") - @gtx.scan_operator(axis=KDim, forward=True, init=(0.0)) def testee_scan(state: float, inp: float) -> float: return inp @@ -382,12 +365,8 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], np.float32]: ) +@pytest.mark.uses_dynamic_offsets def test_offset_field(cartesian_case): - if cartesian_case.backend == gtfn_cpu.run_gtfn_with_temporaries: - pytest.xfail("Dynamic offsets not supported in gtfn") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: offset fields") - ref = np.full( (cartesian_case.default_sizes[IDim], 
cartesian_case.default_sizes[KDim]), True, dtype=bool ) @@ -421,9 +400,6 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD def test_nested_tuple_return(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IField, b: cases.IField @@ -438,10 +414,8 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + a + b) +@pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions over lift expressions") - @gtx.field_operator def testee(a: cases.EField) -> cases.EField: tmp = neighbor_sum(a(V2E), axis=V2EDim) @@ -481,10 +455,8 @@ def testee(inp: cases.EField) -> cases.EField: ) +@pytest.mark.uses_tuple_returns def test_tuple_return_2(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField]: tmp = neighbor_sum(a(V2E), axis=V2EDim) @@ -502,10 +474,8 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField ) +@pytest.mark.uses_constant_fields def test_tuple_with_local_field_in_reduction_shifted(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuples") - @gtx.field_operator def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: tup = e(V2E), v @@ -522,10 +492,8 @@ def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: ) +@pytest.mark.uses_tuple_args def test_tuple_arg(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple args") - @gtx.field_operator def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.IField: return 3 * a[0][0] + a[0][1] + a[1] @@ -555,6 +523,7 @@ def simple_scan_operator(carry: float) -> float: cases.verify(cartesian_case, simple_scan_operator, out=out, ref=expected) +@pytest.mark.uses_lift_expressions def test_solve_triag(cartesian_case): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, @@ -564,8 +533,6 @@ def test_solve_triag(cartesian_case): pytest.xfail("Nested `scan`s requires creating temporaries.") if cartesian_case.backend == gtfn_cpu.run_gtfn_with_temporaries: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) def tridiag_forward( @@ -627,10 +594,8 @@ def testee(left: int32, right: int32) -> cases.IField: @pytest.mark.parametrize("left, right", [(2, 3), (3, 2)]) +@pytest.mark.uses_tuple_returns def test_ternary_operator_tuple(cartesian_case, left, right): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee( a: cases.IField, b: cases.IField, left: int32, right: int32 @@ -646,10 +611,8 @@ def testee( ) +@pytest.mark.uses_reduction_over_lift_expressions def 
test_ternary_builtin_neighbor_sum(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions over lift expressions") - @gtx.field_operator def testee(a: cases.EField, b: cases.EField) -> cases.VField: tmp = neighbor_sum(b(V2E) if 2 < 3 else a(V2E), axis=V2EDim) @@ -688,11 +651,10 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.parametrize("forward", [True, False]) +@pytest.mark.uses_tuple_returns def test_scan_nested_tuple_output(forward, cartesian_case): if cartesian_case.backend in [gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") init = (1, (2, 3)) k_size = cartesian_case.default_sizes[KDim] @@ -720,9 +682,8 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): ) +@pytest.mark.uses_tuple_args def test_scan_nested_tuple_input(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple args") init = 1.0 k_size = cartesian_case.default_sizes[KDim] inp1 = gtx.np_as_located_field(KDim)(np.ones((k_size,))) @@ -878,9 +839,6 @@ def program_domain( def test_domain_tuple(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def fieldop_domain_tuple( a: cases.IJField, b: cases.IJField @@ -939,10 +897,8 @@ def return_undefined(): return undefined_symbol +@pytest.mark.uses_zero_dimensional_fields def test_zero_dims_fields(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: zero-dimensional fields") - @gtx.field_operator def implicit_broadcast_scalar(inp: cases.EmptyField): return inp @@ -970,10 +926,8 @@ def fieldop_implicit_broadcast_2(inp: cases.IField) -> cases.IField: ) +@pytest.mark.uses_tuple_returns def test_tuple_unpacking(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def unpack( inp: cases.IField, @@ -986,9 +940,8 @@ def unpack( ) +@pytest.mark.uses_tuple_returns def test_tuple_unpacking_star_multi(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") OutType = tuple[ cases.IField, cases.IField, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index f2c8525346..dbc35ddfdf 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -17,7 +17,6 @@ import gt4py.next as gtx from gt4py.next import int32, neighbor_sum -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import V2E, Edge, V2EDim, Vertex, unstructured_case @@ -28,9 +27,6 @@ def test_external_local_field(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: 
reductions over non-field expressions") - @gtx.field_operator def testee( inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 7acc0e1447..0ae874f3a6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -18,7 +18,7 @@ import gt4py.next as gtx from gt4py.next import broadcast, float64, int32, int64, max_over, min_over, neighbor_sum, where -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -46,8 +46,6 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions") if unstructured_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, @@ -69,9 +67,6 @@ def testee(edge_f: cases.EField) -> cases.VField: def test_minover_execution(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions") - @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) @@ -99,12 +94,8 @@ def fencil(edge_f: cases.EField, out: cases.VField): ) +@pytest.mark.uses_constant_fields def test_reduction_expression_in_call(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - # -edge_f(V2E) * tmp_nbh * 2 gets inlined with the neighbor_sum operation in the reduction in itir, - # so in addition to the skipped reason, currently itir is a lambda instead of the 'plus' operation - pytest.skip("Not supported in DaCe backend: Reductions not directly on a field.") - @gtx.field_operator def reduce_expr(edge_f: cases.EField) -> cases.VField: tmp_nbh_tup = edge_f(V2E), edge_f(V2E) @@ -124,9 +115,6 @@ def fencil(edge_f: cases.EField, out: cases.VField): def test_reduction_with_common_expression(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.skip("Not supported in DaCe backend: Reductions not directly on a field.") - @gtx.field_operator def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) @@ -138,10 +126,8 @@ def testee(flux: cases.EField) -> cases.VField: ) +@pytest.mark.uses_tuple_returns def test_conditional_nested_tuple(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def conditional_nested_tuple( mask: cases.IBoolField, a: cases.IFloatField, b: cases.IFloatField diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 9ceab7f2d0..f7121dc82f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -22,7 +22,6 @@ from gt4py.next.ffront import 
dialect_ast_enums, fbuiltins, field_operator_ast as foast from gt4py.next.ffront.decorator import FieldOperator from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction -from gt4py.next.program_processors.runners import dace_iterator from gt4py.next.type_system import type_translation from next_tests.integration_tests import cases diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 54374077b4..85826c1ac0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -37,7 +37,7 @@ tanh, trunc, ) -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, cartesian_case, unstructured_case @@ -84,17 +84,8 @@ def floorDiv(inp1: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, floorDiv, ref=lambda inp1: inp1 // 2) +@pytest.mark.uses_negative_modulo def test_mod(cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail( - "Modulo not properly supported for negative numbers." - ) # see https://github.com/GridTools/gt4py/issues/1219 - @gtx.field_operator def mod_fieldop(inp1: cases.IField) -> cases.IField: return inp1 % 2 diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index d7c50e83f0..d86bc21679 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -20,7 +20,6 @@ import pytest import gt4py.next as gtx -from gt4py.next.program_processors.runners import dace_iterator from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, Ioff, JDim, cartesian_case, fieldview_backend @@ -130,9 +129,6 @@ def fo_from_fo_program(in_field: cases.IFloatField, out: cases.IFloatField): def test_tuple_program_return_constructed_inside(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IFloatField, b: cases.IFloatField @@ -159,9 +155,6 @@ def prog( def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IFloatField, b: cases.IFloatField @@ -189,9 +182,6 @@ def prog( def test_tuple_program_return_constructed_inside_nested(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IFloatField, b: cases.IFloatField, c: cases.IFloatField diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 
a49dd1fdcf..f9fd2c1353 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -19,7 +19,6 @@ import pytest from gt4py.next import Field, errors, field_operator, float64, index_field, np_as_located_field -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -46,15 +45,8 @@ @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_simple_if(condition, cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -71,15 +63,8 @@ def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) +@pytest.mark.uses_if_stmts def test_simple_if_conditional(condition1, condition2, cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def simple_if( a: cases.IField, @@ -112,15 +97,8 @@ def simple_if( @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_local_if(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def local_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -138,15 +116,8 @@ def local_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_temporary_if(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -167,15 +138,8 @@ def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IFi @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_return(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -196,15 +160,8 @@ def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IFi @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_stmt_if_branch_returns(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - 
pytest.xfail("If-stmts are not supported yet.") - @field_operator def if_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -222,15 +179,8 @@ def if_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> case @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_stmt_else_branch_returns(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def else_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -250,15 +200,8 @@ def else_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> ca @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_stmt_both_branches_return(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def both_branches_return(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -278,15 +221,8 @@ def both_branches_return(a: cases.IField, b: cases.IField, condition: bool) -> c @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) -def test_nested_if_stmt_conditinal(cartesian_case, condition1, condition2): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - +@pytest.mark.uses_if_stmts +def test_nested_if_stmt_conditional(cartesian_case, condition1, condition2): @field_operator def nested_if_conditional_return( inp: cases.IField, condition1: bool, condition2: bool @@ -322,15 +258,8 @@ def nested_if_conditional_return( @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_nested_if(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def nested_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -364,15 +293,8 @@ def nested_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) +@pytest.mark.uses_if_stmts def test_if_without_else(cartesian_case, condition1, condition2): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def if_without_else( a: cases.IField, b: cases.IField, condition1: bool, condition2: bool diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 13fcf3b87f..ca29c5b18b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -52,7 +52,6 @@ xor_, 
) from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data @@ -171,10 +170,6 @@ def arithmetic_and_logical_test_data(): @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, expected, as_column): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) inps = asfield(*asarray(*inputs)) out = asfield((np.zeros_like(*asarray(expected))))[0] @@ -207,10 +202,6 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins(program_processor, builtin_name, inputs, as_column): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) if builtin_name == "gamma": # numpy has no gamma function @@ -254,10 +245,9 @@ def foo(a): @pytest.mark.parametrize("stencil", [_can_deref, _can_deref_lifted]) +@pytest.mark.uses_can_deref def test_can_deref(program_processor, stencil): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: can_deref") Node = gtx.Dimension("Node") diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index d20ec2ee3d..c2517f1a07 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -18,7 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import program_processor, run_processor @@ -27,15 +26,14 @@ @fundef -def test_conditional(inp): +def stencil_conditional(inp): tmp = if_(eq(deref(inp), 0), make_tuple(1.0, 2.0), make_tuple(3.0, 4.0)) return tuple_get(0, tmp) + tuple_get(1, tmp) +@pytest.mark.uses_tuple_returns def test_conditional_w_tuple(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5] @@ -46,7 +44,7 @@ def test_conditional_w_tuple(program_processor): IDim: range(0, shape[0]), } run_processor( - test_conditional[dom], + stencil_conditional[dom], program_processor, inp, out=out, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index f4ebc596e5..75b935677b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -34,8 +34,6 @@ 
from gt4py.next.program_processors.formatters.gtfn import ( format_sourcecode as gtfn_format_sourcecode, ) -from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.integration_tests.cases import IDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -54,16 +52,13 @@ def conditional_indirection(inp, cond): return deref(compute_shift(cond)(inp)) +@pytest.mark.uses_applied_shifts def test_simple_indirection(program_processor): program_processor, validate = program_processor if program_processor in [ type_check.check, - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, gtfn_format_sourcecode, - run_dace_iterator, ]: pytest.xfail( "We only support applied shifts in type_inference." @@ -97,13 +92,9 @@ def direct_indirection(inp, cond): return deref(shift(I, deref(cond))(inp)) +@pytest.mark.uses_dynamic_offsets def test_direct_offset_for_indirection(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: shift offsets not literals") - - if program_processor == gtfn_cpu.run_gtfn_with_temporaries: - pytest.xfail("Dynamic offsets not supported in temporaries pass.") shape = [4] inp = gtx.np_as_located_field(IDim)(np.asarray(range(shape[0]), dtype=np.float64)) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py index 2076cdd864..d0dc8ec475 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py @@ -18,7 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import fundef -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import program_processor, run_processor @@ -59,10 +58,6 @@ def test_single_argument(program_processor, dom): def test_2_arguments(program_processor, dom): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) @fundef def fun(inp0, inp1): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index e0460b67b1..e02dab0a72 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -18,16 +18,14 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import cartesian_domain, deref, named_range, scan, shift from gt4py.next.iterator.runtime import fundef, offset -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.integration_tests.cases import IDim, KDim from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +@pytest.mark.uses_index_fields def test_scan_in_stencil(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: shift inside lambda") isize = 1 ksize = 3 diff 
--git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 7bfaa7f643..0ac38e9b9f 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -18,8 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import program_processor, run_processor @@ -49,15 +47,9 @@ def fencil(size, out, inp): ) +@pytest.mark.uses_strided_neighbor_offset def test_strided_offset_provider(program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("gtx.StridedNeighborOffsetProvider not implemented in bindings.") LocA_size = 2 max_neighbors = LocA2LocAB_offset_provider.max_neighbors diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 7cc4e95949..cc12183a24 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -19,7 +19,6 @@ from gt4py.next.iterator import transforms from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 5a6ffe2891..67b439507c 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -18,13 +18,8 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator -from next_tests.unit_tests.conftest import ( - program_processor, - program_processor_no_gtfn_exec, - run_processor, -) +from next_tests.unit_tests.conftest import program_processor, run_processor IDim = gtx.Dimension("IDim") @@ -54,10 +49,9 @@ def tuple_output2(inp1, inp2): "stencil", [tuple_output1, tuple_output2], ) +@pytest.mark.uses_tuple_returns def test_tuple_output(program_processor, stencil): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5, 7, 9] rng = np.random.default_rng() @@ -94,10 +88,9 @@ def tuple_of_tuple_output2(inp1, inp2, inp3, inp4): return make_tuple(deref(inp1), deref(inp2)), make_tuple(deref(inp3), deref(inp4)) +@pytest.mark.uses_tuple_returns def 
test_tuple_of_tuple_of_field_output(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") @fundef def stencil(inp1, inp2, inp3, inp4): @@ -157,8 +150,6 @@ def stencil(inp1, inp2, inp3, inp4): ) def test_tuple_of_field_output_constructed_inside(program_processor, stencil): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") @fendef def fencil(size0, size1, size2, inp1, inp2, out1, out2): @@ -204,8 +195,6 @@ def fencil(size0, size1, size2, inp1, inp2, out1, out2): def test_asymetric_nested_tuple_of_field_output_constructed_inside(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") @fundef def stencil(inp1, inp2, inp3): @@ -265,10 +254,8 @@ def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): "stencil", [tuple_output1, tuple_output2], ) -def test_field_of_extra_dim_output(program_processor_no_gtfn_exec, stencil): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") +def test_field_of_extra_dim_output(program_processor, stencil): + program_processor, validate = program_processor shape = [5, 7, 9] rng = np.random.default_rng() @@ -299,10 +286,9 @@ def tuple_input(inp): return tuple_get(0, inp_deref) + tuple_get(1, inp_deref) +@pytest.mark.uses_tuple_args def test_tuple_field_input(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5, 7, 9] rng = np.random.default_rng() @@ -326,10 +312,8 @@ def test_tuple_field_input(program_processor): @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") -def test_field_of_extra_dim_input(program_processor_no_gtfn_exec): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") +def test_field_of_extra_dim_input(program_processor): + program_processor, validate = program_processor shape = [5, 7, 9] rng = np.random.default_rng() @@ -362,10 +346,9 @@ def tuple_tuple_input(inp): ) +@pytest.mark.uses_tuple_args def test_tuple_of_tuple_of_field_input(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5, 7, 9] rng = np.random.default_rng() @@ -404,10 +387,8 @@ def test_tuple_of_tuple_of_field_input(program_processor): @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") -def test_field_of_2_extra_dim_input(program_processor_no_gtfn_exec): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") +def test_field_of_2_extra_dim_input(program_processor): + program_processor, validate = program_processor shape = [5, 7, 9] rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 2580c6ba7f..8db9a4c36e 100644 
--- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -18,7 +18,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu, roundtrip +from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( fieldview_backend, @@ -211,6 +211,7 @@ class setup: return setup() +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): if fieldview_backend in [ gtfn_cpu.run_gtfn, @@ -218,8 +219,6 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): gtfn_cpu.run_gtfn_with_temporaries, ]: pytest.xfail("Needs implementation of scan projector.") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") solve_nonhydro_stencil_52_like_z_q.with_backend(fieldview_backend)( test_setup.z_alpha, @@ -233,6 +232,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail( @@ -241,8 +241,6 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): ) if fieldview_backend == roundtrip.executor: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuples, scans") solve_nonhydro_stencil_52_like_z_q_tup.with_backend(fieldview_backend)( test_setup.z_alpha, @@ -256,11 +254,10 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)( test_setup.z_alpha, test_setup.z_beta, @@ -274,13 +271,12 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): assert np.allclose(test_setup.w_ref, test_setup.w) +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend): if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if fieldview_backend == roundtrip.executor: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuples, scans") solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge.with_backend(fieldview_backend)( test_setup.z_alpha, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 14d929e822..16d839a8ab 100644 --- 
a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -19,7 +19,6 @@ from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @@ -75,6 +74,7 @@ def naive_lap(inp): return out +@pytest.mark.uses_origin def test_anton_toy(program_processor, lift_mode): program_processor, validate = program_processor @@ -87,8 +87,6 @@ def test_anton_toy(program_processor, lift_mode): if lift_mode != transforms.LiftMode.FORCE_INLINE: pytest.xfail("TODO: issue with temporaries that crashes the application") - if program_processor == run_dace_iterator: - pytest.xfail("TODO: not supported in DaCe backend") shape = [5, 7, 9] rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 2446d6664f..04cf8c6f9c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -18,11 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.formatters.gtfn import ( - format_sourcecode as gtfn_format_sourcecode, -) -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn, run_gtfn_imperative from next_tests.integration_tests.cases import IDim, KDim from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @@ -79,11 +74,10 @@ def basic_stencils(request): return request.param +@pytest.mark.uses_origin def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): program_processor, validate = program_processor stencil, ref_fun, inp_fun = basic_stencils - if program_processor == run_dace_iterator and inp_fun: - pytest.xfail("Not supported in DaCe backend: origin") shape = [5, 7] inp = ( @@ -95,13 +89,6 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): ref = ref_fun(inp) - if ( - program_processor == run_dace_iterator - and stencil.__name__ == "shift_stencil" - and inp.origin - ): - pytest.xfail("Not supported in DaCe backend: origin") - run_processor( stencil[{IDim: range(0, shape[0]), KDim: range(0, shape[1])}], program_processor, @@ -162,12 +149,10 @@ def k_level_condition_upper_tuple(k_idx, k_level): ), ], ) +@pytest.mark.uses_tuple_args def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_function, ref_function): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple arguments") - k_size = 5 inp = inp_function(k_size) ref = ref_function(inp) @@ -361,10 +346,6 @@ def sum_shifted_fencil(out, inp0, inp1, k_size): def test_different_vertical_sizes(program_processor): program_processor, validate = program_processor - if program_processor == 
run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) k_size = 10 inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) @@ -401,10 +382,9 @@ def sum_fencil(out, inp0, inp1, k_size): ) +@pytest.mark.uses_origin def test_different_vertical_sizes_with_origin(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: origin") k_size = 10 inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 2d35fb1e50..42de13ef44 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -15,8 +15,6 @@ import numpy as np import pytest -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator - pytest.importorskip("atlas4py") @@ -136,15 +134,9 @@ def nabla( ) +@pytest.mark.requires_atlas def test_compute_zavgS(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() pp = gtx.np_as_located_field(Vertex)(setup.input_field) @@ -201,15 +193,9 @@ def compute_zavgS2_fencil( ) +@pytest.mark.requires_atlas def test_compute_zavgS2(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() pp = gtx.np_as_located_field(Vertex)(setup.input_field) @@ -244,15 +230,9 @@ def test_compute_zavgS2(program_processor, lift_mode): assert_close(1000788897.3202186, np.max(zavgS[1])) +@pytest.mark.requires_atlas def test_nabla(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") if lift_mode != LiftMode.FORCE_INLINE: pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") setup = nabla_setup() @@ -310,15 +290,9 @@ def nabla2( ) +@pytest.mark.requires_atlas def test_nabla2(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() sign = gtx.np_as_located_field(Vertex, V2EDim)(setup.sign_field) @@ -400,13 +374,6 @@ def test_nabla_sign(program_processor, lift_mode): program_processor, validate = program_processor if lift_mode != LiftMode.FORCE_INLINE: pytest.xfail("test is broken due to bad lift semantics in iterator IR") - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas 
tables") setup = nabla_setup() is_pole_edge = gtx.np_as_located_field(Edge)(setup.is_pole_edge_field) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 1dfad40e48..7bd028b7c3 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -24,12 +24,7 @@ from next_tests.integration_tests.multi_feature_tests.iterator_tests.hdiff_reference import ( hdiff_reference, ) -from next_tests.unit_tests.conftest import ( - lift_mode, - program_processor, - program_processor_no_dace_exec, - run_processor, -) +from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor I = offset("I") @@ -76,8 +71,9 @@ def hdiff(inp, coeff, out, x, y): ) -def test_hdiff(hdiff_reference, program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_origin +def test_hdiff(hdiff_reference, program_processor, lift_mode): + program_processor, validate = program_processor if program_processor in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 4474121876..f11046cb5d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -25,12 +25,7 @@ from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests.cases import IDim, JDim, KDim -from next_tests.unit_tests.conftest import ( - lift_mode, - program_processor, - program_processor_no_dace_exec, - run_processor, -) +from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @fundef @@ -120,8 +115,9 @@ def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): @pytest.mark.parametrize("fencil", [fen_solve_tridiag, fen_solve_tridiag2]) -def test_tridiag(fencil, tridiag_reference, program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_lift_expressions +def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): + program_processor, validate = program_processor if ( program_processor in [ diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index e781014c0c..92b93ddb63 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -32,7 +32,6 @@ from gt4py.next.iterator.runtime import fundef from gt4py.next.program_processors.formatters import gtfn from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.toy_connectivity import ( C2E, @@ -51,13 +50,7 @@ v2e_arr, v2v_arr, ) -from next_tests.unit_tests.conftest import ( - lift_mode, - program_processor, - program_processor_no_dace_exec, - 
program_processor_no_gtfn_exec, - run_processor, -) +from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor def edge_index_field(): # TODO replace by gtx.index_field once supported in bindings @@ -93,8 +86,8 @@ def sum_edges_to_vertices_reduce(in_edges): "stencil", [sum_edges_to_vertices, sum_edges_to_vertices_list_get_neighbors, sum_edges_to_vertices_reduce], ) -def test_sum_edges_to_vertices(program_processor_no_dace_exec, lift_mode, stencil): - program_processor, validate = program_processor_no_dace_exec +def test_sum_edges_to_vertices(program_processor, lift_mode, stencil): + program_processor, validate = program_processor inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(list(sum(row) for row in v2e_arr)) @@ -116,10 +109,8 @@ def map_neighbors(in_edges): return reduce(plus, 0)(map_(plus)(neighbors(V2E, in_edges), neighbors(V2E, in_edges))) -def test_map_neighbors(program_processor_no_gtfn_exec, lift_mode): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: map_ builtin, neighbors, reduce") +def test_map_neighbors(program_processor, lift_mode): + program_processor, validate = program_processor inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -141,12 +132,9 @@ def map_make_const_list(in_edges): return reduce(plus, 0)(map_(multiplies)(neighbors(V2E, in_edges), make_const_list(2))) -def test_map_make_const_list(program_processor_no_gtfn_exec, lift_mode): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: map_ builtin, neighbors, reduce, make_const_list" - ) +@pytest.mark.uses_constant_fields +def test_map_make_const_list(program_processor, lift_mode): + program_processor, validate = program_processor inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -194,10 +182,10 @@ def sparse_stencil(non_sparse, inp): return reduce(lambda a, b, c: a + c, 0)(neighbors(V2E, non_sparse), deref(inp)) -def test_sparse_input_field(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_sparse_input_field(program_processor, lift_mode): + program_processor, validate = program_processor - non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18)) + non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18, dtype=np.int32)) inp = gtx.np_as_located_field(Vertex, V2EDim)(np.asarray([[1, 2, 3, 4]] * 9, dtype=np.int32)) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -217,10 +205,10 @@ def test_sparse_input_field(program_processor_no_dace_exec, lift_mode): assert np.allclose(out, ref) -def test_sparse_input_field_v2v(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_sparse_input_field_v2v(program_processor, lift_mode): + program_processor, validate = program_processor - non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18)) + non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18, dtype=np.int32)) inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -248,8 +236,9 @@ def slice_sparse_stencil(sparse): return list_get(1, deref(sparse)) -def 
test_slice_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_slice_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -276,10 +265,10 @@ def slice_twice_sparse_stencil(sparse): @pytest.mark.xfail(reason="Field with more than one sparse dimension is not implemented.") -def test_slice_twice_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_slice_twice_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim, V2VDim)(v2v_arr[v2v_arr]) - out = gtx.np_as_located_field(Vertex)(np.zeros([9])) + out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = v2v_arr[v2v_arr][:, 2, 1] run_processor( @@ -302,8 +291,9 @@ def shift_sliced_sparse_stencil(sparse): return list_get(1, deref(shift(V2V, 0)(sparse))) -def test_shift_sliced_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_shift_sliced_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -329,8 +319,9 @@ def slice_shifted_sparse_stencil(sparse): return list_get(1, deref(shift(V2V, 0)(sparse))) -def test_slice_shifted_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_slice_shifted_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -361,8 +352,8 @@ def lift_stencil(inp): return deref(shift(V2V, 2)(lift(deref_stencil)(inp))) -def test_lift(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_lift(program_processor, lift_mode): + program_processor, validate = program_processor inp = vertex_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(np.asarray(range(9))) @@ -384,8 +375,9 @@ def sparse_shifted_stencil(inp): return list_get(2, list_get(0, neighbors(V2V, inp))) -def test_shift_sparse_input_field(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_shift_sparse_input_field(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(np.asarray(range(9))) @@ -413,8 +405,9 @@ def shift_sparse_stencil2(inp): return list_get(1, list_get(3, neighbors(V2E, inp))) -def test_shift_sparse_input_field2(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_shift_sparse_input_field2(program_processor, lift_mode): + program_processor, validate = program_processor if program_processor in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, @@ 
-463,14 +456,10 @@ def sum_(a, b): return reduce(sum_, 0)(neighbors(V2V, lift(lambda x: reduce(sum_, 0)(deref(x)))(inp))) -def test_sparse_shifted_stencil_reduce(program_processor_no_gtfn_exec, lift_mode): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == gtfn.format_sourcecode: - pytest.xfail("We cannot unroll a reduction on a sparse field only.") - # With our current understanding, this iterator IR program is illegal, however we might want to fix it and therefore keep the test for now. - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: illegal iterator IR") - +@pytest.mark.uses_sparse_fields +@pytest.mark.uses_reduction_with_only_sparse_fields +def test_sparse_shifted_stencil_reduce(program_processor, lift_mode): + program_processor, validate = program_processor if lift_mode != transforms.LiftMode.FORCE_INLINE: pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 09d58a4376..7a62778be1 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -23,12 +23,16 @@ from gt4py.next.iterator import ir as itir, pretty_parser, pretty_printer, runtime, transforms from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.formatters import gtfn, lisp, type_check -from gt4py.next.program_processors.runners import ( - dace_iterator, - double_roundtrip, - gtfn_cpu, - roundtrip, -) +from gt4py.next.program_processors.runners import double_roundtrip, gtfn_cpu, roundtrip + + +try: + from gt4py.next.program_processors.runners import dace_iterator +except ModuleNotFoundError as e: + if "dace" in str(e): + dace_iterator = None + else: + raise e import next_tests @@ -60,6 +64,11 @@ def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str return pretty +OPTIONAL_PROCESSORS = [] +if dace_iterator: + OPTIONAL_PROCESSORS.append((dace_iterator.run_dace_iterator, True)) + + @pytest.fixture( params=[ # (processor, do_validate) @@ -73,30 +82,27 @@ def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str (gtfn_cpu.run_gtfn_imperative, True), (gtfn_cpu.run_gtfn_with_temporaries, True), (gtfn.format_sourcecode, False), - (dace_iterator.run_dace_iterator, True), - ], + ] + + OPTIONAL_PROCESSORS, ids=lambda p: next_tests.get_processor_id(p[0]), ) def program_processor(request): - return request.param - + """ + Fixture creating program processors on-demand for tests. -@pytest.fixture -def program_processor_no_dace_exec(program_processor): - if program_processor[0] == dace_iterator.run_dace_iterator: - pytest.xfail("DaCe backend not yet supported.") - return program_processor + Notes: + Check ADR 15 for details on the test-exclusion matrices. 
+ """ + backend, _ = request.param + backend_id = next_tests.get_processor_id(backend) - -@pytest.fixture -def program_processor_no_gtfn_exec(program_processor): - if ( - program_processor[0] == gtfn_cpu.run_gtfn - or program_processor[0] == gtfn_cpu.run_gtfn_imperative - or program_processor[0] == gtfn_cpu.run_gtfn_with_temporaries + for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( + backend_id, [] ): - pytest.xfail("gtfn backend not yet supported.") - return program_processor + if request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=backend_id)) + + return request.param def run_processor( diff --git a/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py new file mode 100644 index 0000000000..e87f869352 --- /dev/null +++ b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py @@ -0,0 +1,44 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import ast +import typing +from typing import TypeAlias + +import pytest + +import gt4py.next as gtx +from gt4py.next import float32, float64 +from gt4py.next.ffront.fbuiltins import astype +from gt4py.next.ffront.func_to_foast import FieldOperatorParser + + +TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. +vpfloat: TypeAlias = float32 +wpfloat: TypeAlias = float64 + + +@pytest.mark.parametrize("test_input,expected", [(vpfloat, "float32"), (wpfloat, "float64")]) +def test_type_alias_replacement(test_input, expected): + def fieldop_with_typealias( + a: gtx.Field[[TDim], test_input], b: gtx.Field[[TDim], float32] + ) -> gtx.Field[[TDim], test_input]: + return test_input("3.1418") + astype(a, test_input) + + foast_tree = FieldOperatorParser.apply_to_function(fieldop_with_typealias) + + assert ( + foast_tree.body.stmts[0].value.left.func.id == expected + and foast_tree.body.stmts[0].value.right.args[1].id == expected + )