From 34516a65da9d5d0de930e2efc082a436e9da2485 Mon Sep 17 00:00:00 2001
From: Rico Haeuselmann
Date: Tue, 19 Sep 2023 09:09:34 +0200
Subject: [PATCH 01/10] test[cartesian]: update dependencies and fix use of
 hypothesis decorators

This PR tried to fix an error found in the Daily CI task after updating to
hypothesis 6.82.1. The breaking change was fixed a couple of days later
directly in hypothesis, but the changes in this PR are likely to improve
the quality of the code anyway.
---
 .pre-commit-config.yaml                       | 22 +++---
 constraints.txt                               | 78 +++++++++----------
 requirements-dev.txt                          | 78 +++++++++----------
 src/gt4py/eve/extended_typing.py              |  5 +-
 .../feature_tests/test_exec_info.py           |  4 +-
 5 files changed, 95 insertions(+), 92 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d70f335bef..b1092fafd0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -62,7 +62,7 @@ repos:
   ## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1]
   ## print(f"rev: '{version}' # version from constraints.txt")
   ##]]]
-  rev: '23.7.0' # version from constraints.txt
+  rev: '23.9.1' # version from constraints.txt
   ##[[[end]]]
   hooks:
   - id: black
@@ -97,7 +97,7 @@ repos:
     ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1]))
     ##]]]
     - darglint==1.8.1
-    - flake8-bugbear==23.7.10
+    - flake8-bugbear==23.9.16
    - flake8-builtins==2.1.0
    - flake8-debugger==4.1.2
    - flake8-docstrings==1.7.0
@@ -146,9 +146,9 @@ repos:
   ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1]
   ## print(f"#========= FROM constraints.txt: v{version} =========")
   ##]]]
-  #========= FROM constraints.txt: v1.5.0 =========
+  #========= FROM constraints.txt: v1.5.1 =========
   ##[[[end]]]
-  rev: v1.5.0 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date)
+  rev: v1.5.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date)
   hooks:
   - id: mypy
     additional_dependencies: # versions from constraints.txt
@@ -162,26 +162,26 @@ repos:
     ##]]]
     - astunparse==1.6.3
    - attrs==23.1.0
-    - black==23.7.0
+    - black==23.9.1
    - boltons==23.0.0
    - cached-property==1.5.2
-    - click==8.1.6
-    - cmake==3.27.2
+    - click==8.1.7
+    - cmake==3.27.5
    - cytoolz==0.12.2
-    - deepdiff==6.3.1
-    - devtools==0.11.0
+    - deepdiff==6.5.0
+    - devtools==0.12.2
    - frozendict==2.3.8
    - gridtools-cpp==2.3.1
    - importlib-resources==6.0.1
    - jinja2==3.1.2
    - lark==1.1.7
    - mako==1.2.4
-    - nanobind==1.5.0
+    - nanobind==1.5.2
    - ninja==1.11.1
    - numpy==1.24.4
    - packaging==23.1
    - pybind11==2.11.1
-    - setuptools==68.1.0
+    - setuptools==68.2.2
    - tabulate==0.9.0
    - typing-extensions==4.5.0
    - xxhash==3.0.0
diff --git a/constraints.txt b/constraints.txt
index 35e3d9e330..b334851af1 100644
--- a/constraints.txt
+++ b/constraints.txt
@@ -6,14 +6,14 @@
 #
 aenum==3.1.15             # via dace
 alabaster==0.7.13         # via sphinx
-asttokens==2.2.1          # via devtools
+asttokens==2.4.0          # via devtools
 astunparse==1.6.3 ; python_version < "3.9"  # via dace, gt4py (pyproject.toml)
 attrs==23.1.0             # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing
 babel==2.12.1             # via sphinx
-black==23.7.0             # via gt4py (pyproject.toml)
+black==23.9.1             # via gt4py (pyproject.toml)
 blinker==1.6.2            # via flask
 boltons==23.0.0           # via gt4py (pyproject.toml)
-build==0.10.0             # via pip-tools
+build==1.0.3              # via pip-tools
 cached-property==1.5.2    # via gt4py (pyproject.toml)
 cachetools==5.3.1         # via tox
 certifi==2023.7.22        # via requests
@@ -22,17 +22,17 @@
 cfgv==3.4.0               # via pre-commit
 chardet==5.2.0            # via 
tox charset-normalizer==3.2.0 # via requests clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.6 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.2 # via gt4py (pyproject.toml) +click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools +cmake==3.27.5 # via gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage==7.3.0 # via -r requirements-dev.in, pytest-cov +coverage==7.3.1 # via -r requirements-dev.in, pytest-cov cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) dace==0.14.4 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.3.1 # via gt4py (pyproject.toml) -devtools==0.11.0 # via gt4py (pyproject.toml) +deepdiff==6.5.0 # via gt4py (pyproject.toml) +devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace distlib==0.3.7 # via virtualenv docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme @@ -41,11 +41,11 @@ exceptiongroup==1.1.3 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==1.2.0 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.3.0 # via factory-boy +faker==19.6.1 # via factory-boy fastjsonschema==2.18.0 # via nbformat -filelock==3.12.2 # via tox, virtualenv +filelock==3.12.4 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.7.10 # via -r requirements-dev.in +flake8-bugbear==23.9.16 # via -r requirements-dev.in flake8-builtins==2.1.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in @@ -53,14 +53,14 @@ flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.2 # via dace +flask==2.3.3 # via dace frozendict==2.3.8 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.82.4 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.26 # via pre-commit +hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.29 # via pre-commit idna==3.4 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via flask, sphinx +importlib-metadata==6.8.0 # via build, flask, sphinx importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest @@ -70,7 +70,7 @@ jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx jsonschema==4.19.0 # via nbformat jsonschema-specifications==2023.7.1 # via jsonschema jupyter-core==5.3.1 # via nbformat -jupytext==1.15.0 # via -r requirements-dev.in +jupytext==1.15.2 # via -r requirements-dev.in lark==1.1.7 # via gt4py (pyproject.toml) mako==1.2.4 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins @@ -79,9 +79,9 @@ mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.0 # via -r requirements-dev.in +mypy==1.5.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.0 
# via gt4py (pyproject.toml) +nanobind==1.5.2 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace ninja==1.11.1 # via gt4py (pyproject.toml) @@ -94,36 +94,36 @@ pip-tools==7.3.0 # via -r requirements-dev.in pipdeptree==2.13.0 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv -pluggy==1.2.0 # via pytest, tox +pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.3.3 # via -r requirements-dev.in +pre-commit==3.4.0 # via -r requirements-dev.in psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) pycodestyle==2.11.0 # via flake8, flake8-debugger pycparser==2.21 # via cffi pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, flake8-rst-docstrings, sphinx -pyproject-api==1.5.3 # via tox +pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.0 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.5.1 # via -r requirements-dev.in pytest-xdist==3.3.1 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker -pytz==2023.3 # via babel +pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit referencing==0.30.2 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.9.2 # via jsonschema, referencing -ruff==0.0.284 # via -r requirements-dev.in +rpds-py==0.10.3 # via jsonschema, referencing +ruff==0.0.290 # via -r requirements-dev.in six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis -sphinx==6.2.1 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.2.2 # via -r requirements-dev.in +sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery +sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -136,8 +136,8 @@ tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox toolz==0.12.0 # via cytoolz -tox==4.9.0 # via -r requirements-dev.in -traitlets==5.9.0 # via jupyter-core, nbformat +tox==4.11.3 # via -r requirements-dev.in +traitlets==5.10.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all @@ -182,14 +182,14 @@ types-kazoo==0.1.3 # via types-all types-markdown==3.4.2.10 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.1 # via types-all +types-mock==5.1.0.2 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all 
types-orjson==3.6.2 # via types-all types-paramiko==3.3.0.0 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.2 # via types-all +types-pillow==10.0.0.3 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all types-protobuf==4.24.0.1 # via types-all @@ -205,17 +205,17 @@ types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.0.1 # via types-all, types-tzlocal +types-pytz==2023.3.1.0 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.4 # via types-all +types-redis==4.6.0.6 # via types-all types-requests==2.31.0.2 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.1.0.0 # via types-cffi +types-setuptools==68.2.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all -types-singledispatch==4.0.0.2 # via types-all +types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all types-tabulate==0.9.0.3 # via types-all types-termcolor==1.1.6.2 # via types-all @@ -230,13 +230,13 @@ types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy urllib3==2.0.4 # via requests -virtualenv==20.24.3 # via pre-commit, tox +virtualenv==20.24.5 # via pre-commit, tox websockets==11.0.3 # via dace werkzeug==2.3.7 # via flask -wheel==0.41.1 # via astunparse, pip-tools +wheel==0.41.2 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.16.2 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==23.2.1 # via pip-tools -setuptools==68.1.0 # via gt4py (pyproject.toml), nodeenv, pip-tools +setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools diff --git a/requirements-dev.txt b/requirements-dev.txt index a167b2979a..d6dcc12d21 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,14 +6,14 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.2.1 # via devtools +asttokens==2.4.0 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing babel==2.12.1 # via sphinx -black==23.7.0 # via gt4py (pyproject.toml) +black==23.9.1 # via gt4py (pyproject.toml) blinker==1.6.2 # via flask boltons==23.0.0 # via gt4py (pyproject.toml) -build==0.10.0 # via pip-tools +build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) cachetools==5.3.1 # via tox certifi==2023.7.22 # via requests @@ -22,17 +22,17 @@ cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox charset-normalizer==3.2.0 # via requests clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.6 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.2 # via gt4py (pyproject.toml) +click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools +cmake==3.27.5 # via gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage[toml]==7.3.0 # via -r requirements-dev.in, pytest-cov +coverage[toml]==7.3.1 # via -r 
requirements-dev.in, pytest-cov cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) dace==0.14.4 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.3.1 # via gt4py (pyproject.toml) -devtools==0.11.0 # via gt4py (pyproject.toml) +deepdiff==6.5.0 # via gt4py (pyproject.toml) +devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace distlib==0.3.7 # via virtualenv docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme @@ -41,11 +41,11 @@ exceptiongroup==1.1.3 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==1.2.0 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.3.0 # via factory-boy +faker==19.6.1 # via factory-boy fastjsonschema==2.18.0 # via nbformat -filelock==3.12.2 # via tox, virtualenv +filelock==3.12.4 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.7.10 # via -r requirements-dev.in +flake8-bugbear==23.9.16 # via -r requirements-dev.in flake8-builtins==2.1.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in @@ -53,14 +53,14 @@ flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.2 # via dace +flask==2.3.3 # via dace frozendict==2.3.8 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.82.4 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.26 # via pre-commit +hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.29 # via pre-commit idna==3.4 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via flask, sphinx +importlib-metadata==6.8.0 # via build, flask, sphinx importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest @@ -70,7 +70,7 @@ jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx jsonschema==4.19.0 # via nbformat jsonschema-specifications==2023.7.1 # via jsonschema jupyter-core==5.3.1 # via nbformat -jupytext==1.15.0 # via -r requirements-dev.in +jupytext==1.15.2 # via -r requirements-dev.in lark==1.1.7 # via gt4py (pyproject.toml) mako==1.2.4 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins @@ -79,9 +79,9 @@ mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.0 # via -r requirements-dev.in +mypy==1.5.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.0 # via gt4py (pyproject.toml) +nanobind==1.5.2 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace ninja==1.11.1 # via gt4py (pyproject.toml) @@ -94,36 +94,36 @@ pip-tools==7.3.0 # via -r requirements-dev.in pipdeptree==2.13.0 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv -pluggy==1.2.0 # via pytest, tox +pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace 
-pre-commit==3.3.3 # via -r requirements-dev.in +pre-commit==3.4.0 # via -r requirements-dev.in psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) pycodestyle==2.11.0 # via flake8, flake8-debugger pycparser==2.21 # via cffi pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, flake8-rst-docstrings, sphinx -pyproject-api==1.5.3 # via tox +pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.0 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.5.1 # via -r requirements-dev.in pytest-xdist[psutil]==3.3.1 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker -pytz==2023.3 # via babel +pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit referencing==0.30.2 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.9.2 # via jsonschema, referencing -ruff==0.0.284 # via -r requirements-dev.in +rpds-py==0.10.3 # via jsonschema, referencing +ruff==0.0.290 # via -r requirements-dev.in six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis -sphinx==6.2.1 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.2.2 # via -r requirements-dev.in +sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery +sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -136,8 +136,8 @@ tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox toolz==0.12.0 # via cytoolz -tox==4.9.0 # via -r requirements-dev.in -traitlets==5.9.0 # via jupyter-core, nbformat +tox==4.11.3 # via -r requirements-dev.in +traitlets==5.10.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all @@ -182,14 +182,14 @@ types-kazoo==0.1.3 # via types-all types-markdown==3.4.2.10 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.1 # via types-all +types-mock==5.1.0.2 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all types-paramiko==3.3.0.0 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.2 # via types-all +types-pillow==10.0.0.3 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all types-protobuf==4.24.0.1 # via types-all @@ -205,17 +205,17 @@ types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange 
types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.0.1 # via types-all, types-tzlocal +types-pytz==2023.3.1.0 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.4 # via types-all +types-redis==4.6.0.6 # via types-all types-requests==2.31.0.2 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.1.0.0 # via types-cffi +types-setuptools==68.2.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all -types-singledispatch==4.0.0.2 # via types-all +types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all types-tabulate==0.9.0.3 # via types-all types-termcolor==1.1.6.2 # via types-all @@ -230,13 +230,13 @@ types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy urllib3==2.0.4 # via requests -virtualenv==20.24.3 # via pre-commit, tox +virtualenv==20.24.5 # via pre-commit, tox websockets==11.0.3 # via dace werkzeug==2.3.7 # via flask -wheel==0.41.1 # via astunparse, pip-tools +wheel==0.41.2 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.16.2 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==23.2.1 # via pip-tools -setuptools==68.1.0 # via gt4py (pyproject.toml), nodeenv, pip-tools +setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 34829317d6..3b8373ade1 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -552,11 +552,14 @@ def is_value_hashable_typing( return type_annotation is None -def is_protocol(type_: Type) -> bool: +def _is_protocol(type_: type, /) -> bool: """Check if a type is a Protocol definition.""" return getattr(type_, "_is_protocol", False) +is_protocol = getattr(_typing_extensions, "is_protocol", _is_protocol) + + def get_partial_type_hints( obj: Union[ object, diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py index 2934e48e7a..6b8c02e41c 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py @@ -194,8 +194,8 @@ def subtest_stencil_info(self, exec_info, stencil_info, last_called_stencil=Fals else: assert stencil_info["total_run_cpp_time"] > stencil_info["run_cpp_time"] - @given(data=hyp_st.data()) @pytest.mark.parametrize("backend", ALL_BACKENDS) + @given(data=hyp_st.data()) def test_backcompatibility(self, data, backend, worker_id): # set backend as instance attribute self.backend = backend @@ -237,8 +237,8 @@ def test_backcompatibility(self, data, backend, worker_id): assert type(self.advection).__name__ not in exec_info assert type(self.diffusion).__name__ not in exec_info - @given(data=hyp_st.data()) @pytest.mark.parametrize("backend", ALL_BACKENDS) + @given(data=hyp_st.data()) def test_aggregate(self, data, backend, worker_id): # set backend as instance attribute self.backend = backend From ac6bf945d8b6e7677e3a247339b7698efc8806bd Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 19 Sep 2023 10:47:44 +0200 Subject: [PATCH 02/10] 
feat[next]: extend DaCe support of reduction operator (#1332) Adding generic implementation of neighbor-reduction to DaCe backend based on map with Write-Conflict Resolution (WCR) on output memlet. This PR enables use of lambdas as reduction function. --- .../runners/dace_iterator/itir_to_sdfg.py | 73 +---- .../runners/dace_iterator/itir_to_tasklet.py | 310 +++++++++++++++--- .../runners/dace_iterator/utility.py | 77 ++++- .../ffront_tests/test_external_local_field.py | 3 - .../ffront_tests/test_gt4py_builtins.py | 7 +- .../test_with_toy_connectivity.py | 32 +- 6 files changed, 351 insertions(+), 151 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 4f93777215..56031d8555 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -32,6 +32,7 @@ is_scan, ) from .utility import ( + add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, create_memlet_at, @@ -321,7 +322,7 @@ def visit_StencilClosure( array_mapping = {**input_mapping, **conn_mapping} symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, array_mapping) - nsdfg_node, map_entry, map_exit = self._add_mapped_nested_sdfg( + nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( closure_state, sdfg=nsdfg, map_ranges=map_domain or {"__dummy": "0"}, @@ -584,76 +585,6 @@ def _visit_parallel_stencil_closure( return context.body, map_domain, [r.value.data for r in results] - def _add_mapped_nested_sdfg( - self, - state: dace.SDFGState, - map_ranges: dict[str, str | dace.subsets.Subset] - | list[tuple[str, str | dace.subsets.Subset]], - inputs: dict[str, dace.Memlet], - outputs: dict[str, dace.Memlet], - sdfg: dace.SDFG, - symbol_mapping: dict[str, Any] | None = None, - schedule: Any = dace.dtypes.ScheduleType.Default, - unroll_map: bool = False, - location: Any = None, - debuginfo: Any = None, - input_nodes: dict[str, dace.nodes.AccessNode] | None = None, - output_nodes: dict[str, dace.nodes.AccessNode] | None = None, - ) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]: - if not symbol_mapping: - symbol_mapping = {sym: sym for sym in sdfg.free_symbols} - - nsdfg_node = state.add_nested_sdfg( - sdfg, - None, - set(inputs.keys()), - set(outputs.keys()), - symbol_mapping, - name=sdfg.name, - schedule=schedule, - location=location, - debuginfo=debuginfo, - ) - - map_entry, map_exit = state.add_map( - f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo - ) - - if input_nodes is None: - input_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() - } - if output_nodes is None: - output_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() - } - if not inputs: - state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) - for name, memlet in inputs.items(): - state.add_memlet_path( - input_nodes[memlet.data], - map_entry, - nsdfg_node, - memlet=memlet, - src_conn=None, - dst_conn=name, - propagate=True, - ) - if not outputs: - state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet()) - for name, memlet in outputs.items(): - state.add_memlet_path( - nsdfg_node, - map_exit, - output_nodes[memlet.data], - memlet=memlet, - src_conn=name, - dst_conn=None, - propagate=True, - ) - - return nsdfg_node, map_entry, map_exit - def _visit_domain( self, node: itir.FunCall, context: Context ) 
-> tuple[tuple[str, tuple[ValueExpr, ValueExpr]], ...]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index d301c3e3cf..2e7a598d9a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -19,6 +19,8 @@ import dace import numpy as np +from dace.transformation.dataflow import MapFusion +from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen from gt4py.next import Dimension, type_inference as next_typing @@ -29,12 +31,14 @@ from gt4py.next.type_system import type_specifications as ts from .utility import ( + add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, create_memlet_at, create_memlet_full, filter_neighbor_tables, map_nested_sdfg_symbols, + unique_name, unique_var_name, ) @@ -56,6 +60,21 @@ def itir_type_as_dace_type(type_: next_typing.Type): raise NotImplementedError() +def get_reduce_identity_value(op_name_: str, type_: Any): + if op_name_ == "plus": + init_value = type_(0) + elif op_name_ == "multiplies": + init_value = type_(1) + elif op_name_ == "minimum": + init_value = type_("inf") + elif op_name_ == "maximum": + init_value = type_("-inf") + else: + raise NotImplementedError() + + return init_value + + _MATH_BUILTINS_MAPPING = { "abs": "abs({})", "sin": "math.sin({})", @@ -136,6 +155,21 @@ class Context: body: dace.SDFG state: dace.SDFGState symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr] + # if we encounter a reduction node, the reduction state needs to be pushed to child nodes + reduce_limit: int + reduce_wcr: Optional[str] + + def __init__( + self, + body: dace.SDFG, + state: dace.SDFGState, + symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr], + ): + self.body = body + self.state = state + self.symbol_map = symbol_map + self.reduce_limit = 0 + self.reduce_wcr = None def builtin_neighbors( @@ -167,13 +201,15 @@ def builtin_neighbors( table_name = connectivity_identifier(offset_dim) table_array = sdfg.arrays[table_name] + # generate unique map index name to avoid conflict with other maps inside same state + index_name = unique_name("__neigh_idx") me, mx = state.add_map( f"{offset_dim}_neighbors_map", - ndrange={"neigh_idx": f"0:{table.max_neighbors}"}, + ndrange={index_name: f"0:{table.max_neighbors}"}, ) shift_tasklet = state.add_tasklet( "shift", - code="__result = __table[__idx, neigh_idx]", + code=f"__result = __table[__idx, {index_name}]", inputs={"__table", "__idx"}, outputs={"__result"}, ) @@ -227,7 +263,7 @@ def builtin_neighbors( data_access_tasklet, mx, result_access, - memlet=dace.Memlet(data=result_name, subset="neigh_idx"), + memlet=dace.Memlet(data=result_name, subset=index_name), src_conn="__result", ) @@ -349,6 +385,8 @@ def visit_Lambda( value = IteratorExpr(field, indices, arg.dtype, arg.dimensions) symbol_map[param] = value context = Context(context_sdfg, context_state, symbol_map) + context.reduce_limit = prev_context.reduce_limit + context.reduce_wcr = prev_context.reduce_wcr self.context = context # Add input parameters as arrays @@ -395,7 +433,12 @@ def visit_Lambda( self.context.body.add_scalar(result_name, result.dtype, transient=True) result_access = self.context.state.add_access(result_name) self.context.state.add_edge( - result.value, None, result_access, None, dace.Memlet(f"{result.value.data}[0]") + result.value, + None, + result_access, + 
None, + # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution + dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr), ) result = ValueExpr(value=result_access, dtype=result.dtype) else: @@ -531,15 +574,71 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: if not isinstance(iterator, IteratorExpr): # already a list of ValueExpr return iterator - sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) - flat_index = [ - ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions - ] - args: list[ValueExpr] = [ValueExpr(iterator.field, iterator.dtype), *flat_index] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") + args: list[ValueExpr] + if self.context.reduce_limit: + # we are visiting a child node of reduction, so the neighbor index can be used for indirect addressing + result_name = unique_var_name() + self.context.body.add_array( + result_name, + dtype=iterator.dtype, + shape=(self.context.reduce_limit,), + transient=True, + ) + result_access = self.context.state.add_access(result_name) + + # generate unique map index name to avoid conflict with other maps inside same state + index_name = unique_name("__deref_idx") + me, mx = self.context.state.add_map( + "deref_map", + ndrange={index_name: f"0:{self.context.reduce_limit}"}, + ) + + # if dim is not found in iterator indices, we take the neighbor index over the reduction domain + array_index = [ + f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name + for dim in sorted(iterator.dimensions) + ] + args = [ValueExpr(iterator.field, iterator.dtype)] + [ + ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices + ] + internals = [f"{arg.value.data}_v" for arg in args] + + deref_tasklet = self.context.state.add_tasklet( + name="deref", + inputs=set(internals), + outputs={"__result"}, + code=f"__result = {args[0].value.data}_v[{', '.join(array_index)}]", + ) + + for arg, internal in zip(args, internals): + input_memlet = create_memlet_full( + arg.value.data, self.context.body.arrays[arg.value.data] + ) + self.context.state.add_memlet_path( + arg.value, me, deref_tasklet, memlet=input_memlet, dst_conn=internal + ) + + self.context.state.add_memlet_path( + deref_tasklet, + mx, + result_access, + memlet=dace.Memlet(data=result_name, subset=index_name), + src_conn="__result", + ) + + return [ValueExpr(value=result_access, dtype=iterator.dtype)] + + else: + sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) + flat_index = [ + ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions + ] + + args = [ValueExpr(iterator.field, int), *flat_index] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]}[{', '.join(internals[1:])}]" + return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") def _split_shift_args( self, args: list[itir.Expr] @@ -626,47 +725,156 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) def _visit_reduce(self, node: itir.FunCall): - assert ( - isinstance(node.args[0], itir.FunCall) - and isinstance(node.args[0].fun, itir.SymRef) - and node.args[0].fun.id == "neighbors" - ) - args = self.visit(node.args) - assert len(args) == 
1 - args = args[0] - assert len(args) == 1 - assert isinstance(node.fun, itir.FunCall) - op_name = node.fun.args[0] - assert isinstance(op_name, itir.SymRef) - init = node.fun.args[1] - - nreduce = self.context.body.arrays[args[0].value.data].shape[0] - result_name = unique_var_name() result_access = self.context.state.add_access(result_name) - self.context.body.add_scalar(result_name, args[0].dtype, transient=True) - op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") - reduce_tasklet = self.context.state.add_tasklet( - "reduce", - code=f"__result = {init}\nfor __idx in range({nreduce}):\n __result = {op_str}", - inputs={"__values"}, - outputs={"__result"}, - ) - self.context.state.add_edge( - args[0].value, - None, - reduce_tasklet, - "__values", - dace.Memlet(data=args[0].value.data, subset=f"0:{nreduce}"), - ) - self.context.state.add_edge( - reduce_tasklet, - "__result", - result_access, - None, - dace.Memlet(data=result_name, subset="0"), - ) - return [ValueExpr(result_access, args[0].dtype)] + + if len(node.args) == 1: + assert ( + isinstance(node.args[0], itir.FunCall) + and isinstance(node.args[0].fun, itir.SymRef) + and node.args[0].fun.id == "neighbors" + ) + args = self.visit(node.args) + assert len(args) == 1 + args = args[0] + assert len(args) == 1 + neighbors_expr = args[0] + result_dtype = neighbors_expr.dtype + assert isinstance(node.fun, itir.FunCall) + op_name = node.fun.args[0] + assert isinstance(op_name, itir.SymRef) + init = node.fun.args[1] + + nreduce = self.context.body.arrays[neighbors_expr.value.data].shape[0] + + self.context.body.add_scalar(result_name, result_dtype, transient=True) + op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") + reduce_tasklet = self.context.state.add_tasklet( + "reduce", + code=f"__result = {init}\nfor __idx in range({nreduce}):\n __result = {op_str}", + inputs={"__values"}, + outputs={"__result"}, + ) + self.context.state.add_edge( + args[0].value, + None, + reduce_tasklet, + "__values", + dace.Memlet(data=neighbors_expr.value.data, subset=f"0:{nreduce}"), + ) + self.context.state.add_edge( + reduce_tasklet, + "__result", + result_access, + None, + dace.Memlet(data=result_name, subset="0"), + ) + else: + assert isinstance(node.fun, itir.FunCall) + assert isinstance(node.fun.args[0], itir.Lambda) + fun_node = node.fun.args[0] + + args = [] + for node_arg in node.args: + if ( + isinstance(node_arg, itir.FunCall) + and isinstance(node_arg.fun, itir.SymRef) + and node_arg.fun.id == "neighbors" + ): + expr = self.visit(node_arg) + args.append(*expr) + else: + args.append(None) + + # first visit only arguments for neighbor selection, all other arguments are none + neighbor_args = [arg for arg in args if arg] + + # check that all neighbors expression have the same range + assert ( + len( + set([self.context.body.arrays[expr.value.data].shape for expr in neighbor_args]) + ) + == 1 + ) + + nreduce = self.context.body.arrays[neighbor_args[0].value.data].shape[0] + nreduce_domain = {"__idx": f"0:{nreduce}"} + + result_dtype = neighbor_args[0].dtype + self.context.body.add_scalar(result_name, result_dtype, transient=True) + + assert isinstance(fun_node.expr, itir.FunCall) + op_name = fun_node.expr.fun + assert isinstance(op_name, itir.SymRef) + + # initialize the reduction result based on type of operation + init_value = get_reduce_identity_value(op_name.id, result_dtype) + init_state = self.context.body.add_state_before(self.context.state, "init") + init_tasklet = 
init_state.add_tasklet( + "init_reduce", {}, {"__out"}, f"__out = {init_value}" + ) + init_state.add_edge( + init_tasklet, + "__out", + init_state.add_access(result_name), + None, + dace.Memlet.simple(result_name, "0"), + ) + + # set reduction state to enable dereference of neighbors in input fields and to set WCR on reduce tasklet + self.context.reduce_limit = nreduce + self.context.reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format( + "x", "y" + ) + + # visit child nodes for input arguments + for i, node_arg in enumerate(node.args): + if not args[i]: + args[i] = self.visit(node_arg)[0] + + lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:]) + lambda_context, inner_inputs, inner_outputs = self.visit(lambda_node, args=args) + + # clear context + self.context.reduce_limit = 0 + self.context.reduce_wcr = None + + # the connectivity arrays (neighbor tables) are not needed inside the reduce lambda SDFG + neighbor_tables = filter_neighbor_tables(self.offset_provider) + for conn, _ in neighbor_tables: + var = connectivity_identifier(conn) + lambda_context.body.remove_data(var) + # cleanup symbols previously used for shape and stride of connectivity arrays + p = RemoveUnusedSymbols() + p.apply_pass(lambda_context.body, {}) + + input_memlets = [ + create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args) + ] + output_memlet = dace.Memlet.simple(result_name, "0") + + input_mapping = {param: arg for (param, _), arg in zip(inner_inputs, input_memlets)} + output_mapping = {inner_outputs[0].value.data: output_memlet} + symbol_mapping = map_nested_sdfg_symbols( + self.context.body, lambda_context.body, input_mapping + ) + + nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( + self.context.state, + sdfg=lambda_context.body, + map_ranges=nreduce_domain, + inputs=input_mapping, + outputs=output_mapping, + symbol_mapping=symbol_mapping, + input_nodes={arg.value.data: arg.value for arg in args}, + output_nodes={result_name: result_access}, + ) + + # we apply map fusion only to the nested-SDFG which is generated for the reduction operator + # the purpose is to keep the ITIR-visitor program simple and to clean up the generated SDFG + self.context.body.apply_transformations_repeated([MapFusion], validate=False) + + return [ValueExpr(result_access, result_dtype)] def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 85b1445dd9..889a1ab150 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -81,10 +81,83 @@ def map_nested_sdfg_symbols( return symbol_mapping +def add_mapped_nested_sdfg( + state: dace.SDFGState, + map_ranges: dict[str, str | dace.subsets.Subset] | list[tuple[str, str | dace.subsets.Subset]], + inputs: dict[str, dace.Memlet], + outputs: dict[str, dace.Memlet], + sdfg: dace.SDFG, + symbol_mapping: dict[str, Any] | None = None, + schedule: Any = dace.dtypes.ScheduleType.Default, + unroll_map: bool = False, + location: Any = None, + debuginfo: Any = None, + input_nodes: dict[str, dace.nodes.AccessNode] | None = None, + output_nodes: dict[str, dace.nodes.AccessNode] | None = None, +) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]: + if not symbol_mapping: + symbol_mapping = {sym: sym 
for sym in sdfg.free_symbols} + + nsdfg_node = state.add_nested_sdfg( + sdfg, + None, + set(inputs.keys()), + set(outputs.keys()), + symbol_mapping, + name=sdfg.name, + schedule=schedule, + location=location, + debuginfo=debuginfo, + ) + + map_entry, map_exit = state.add_map( + f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo + ) + + if input_nodes is None: + input_nodes = { + memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() + } + if output_nodes is None: + output_nodes = { + memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() + } + if not inputs: + state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) + for name, memlet in inputs.items(): + state.add_memlet_path( + input_nodes[memlet.data], + map_entry, + nsdfg_node, + memlet=memlet, + src_conn=None, + dst_conn=name, + propagate=True, + ) + if not outputs: + state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet()) + for name, memlet in outputs.items(): + state.add_memlet_path( + nsdfg_node, + map_exit, + output_nodes[memlet.data], + memlet=memlet, + src_conn=name, + dst_conn=None, + propagate=True, + ) + + return nsdfg_node, map_entry, map_exit + + _unique_id = 0 -def unique_var_name(): +def unique_name(prefix): global _unique_id _unique_id += 1 - return f"__var_{_unique_id}" + return f"{prefix}_{_unique_id}" + + +def unique_var_name(): + return unique_name("__var") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index f2c8525346..7f2b11afff 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -28,9 +28,6 @@ def test_external_local_field(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions over non-field expressions") - @gtx.field_operator def testee( inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 7acc0e1447..ee88b3764e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -101,9 +101,7 @@ def fencil(edge_f: cases.EField, out: cases.VField): def test_reduction_expression_in_call(unstructured_case): if unstructured_case.backend == dace_iterator.run_dace_iterator: - # -edge_f(V2E) * tmp_nbh * 2 gets inlined with the neighbor_sum operation in the reduction in itir, - # so in addition to the skipped reason, currently itir is a lambda instead of the 'plus' operation - pytest.skip("Not supported in DaCe backend: Reductions not directly on a field.") + pytest.xfail("Not supported in DaCe backend: make_const_list") @gtx.field_operator def reduce_expr(edge_f: cases.EField) -> cases.VField: @@ -124,9 +122,6 @@ def fencil(edge_f: cases.EField, out: cases.VField): def test_reduction_with_common_expression(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.skip("Not supported in DaCe backend: Reductions not directly on a field.") - @gtx.field_operator def testee(flux: cases.EField) -> 
cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index e781014c0c..ee07372731 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -93,8 +93,8 @@ def sum_edges_to_vertices_reduce(in_edges): "stencil", [sum_edges_to_vertices, sum_edges_to_vertices_list_get_neighbors, sum_edges_to_vertices_reduce], ) -def test_sum_edges_to_vertices(program_processor_no_dace_exec, lift_mode, stencil): - program_processor, validate = program_processor_no_dace_exec +def test_sum_edges_to_vertices(program_processor, lift_mode, stencil): + program_processor, validate = program_processor inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(list(sum(row) for row in v2e_arr)) @@ -116,10 +116,8 @@ def map_neighbors(in_edges): return reduce(plus, 0)(map_(plus)(neighbors(V2E, in_edges), neighbors(V2E, in_edges))) -def test_map_neighbors(program_processor_no_gtfn_exec, lift_mode): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: map_ builtin, neighbors, reduce") +def test_map_neighbors(program_processor, lift_mode): + program_processor, validate = program_processor inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -144,9 +142,7 @@ def map_make_const_list(in_edges): def test_map_make_const_list(program_processor_no_gtfn_exec, lift_mode): program_processor, validate = program_processor_no_gtfn_exec if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: map_ builtin, neighbors, reduce, make_const_list" - ) + pytest.xfail("Not supported in DaCe backend: make_const_list") inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -194,10 +190,10 @@ def sparse_stencil(non_sparse, inp): return reduce(lambda a, b, c: a + c, 0)(neighbors(V2E, non_sparse), deref(inp)) -def test_sparse_input_field(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_sparse_input_field(program_processor, lift_mode): + program_processor, validate = program_processor - non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18)) + non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18, dtype=np.int32)) inp = gtx.np_as_located_field(Vertex, V2EDim)(np.asarray([[1, 2, 3, 4]] * 9, dtype=np.int32)) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -217,10 +213,10 @@ def test_sparse_input_field(program_processor_no_dace_exec, lift_mode): assert np.allclose(out, ref) -def test_sparse_input_field_v2v(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_sparse_input_field_v2v(program_processor, lift_mode): + program_processor, validate = program_processor - non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18)) + non_sparse = gtx.np_as_located_field(Edge)(np.zeros(18, dtype=np.int32)) inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = 
gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype))
@@ -276,10 +272,10 @@ def slice_twice_sparse_stencil(sparse):
 
 
 @pytest.mark.xfail(reason="Field with more than one sparse dimension is not implemented.")
-def test_slice_twice_sparse(program_processor_no_dace_exec, lift_mode):
-    program_processor, validate = program_processor_no_dace_exec
+def test_slice_twice_sparse(program_processor, lift_mode):
+    program_processor, validate = program_processor
     inp = gtx.np_as_located_field(Vertex, V2VDim, V2VDim)(v2v_arr[v2v_arr])
-    out = gtx.np_as_located_field(Vertex)(np.zeros([9]))
+    out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype))
     ref = v2v_arr[v2v_arr][:, 2, 1]
 
     run_processor(

From d03ef4f15c63f7abb388c646c0eca33e1749c352 Mon Sep 17 00:00:00 2001
From: edopao
Date: Tue, 26 Sep 2023 12:52:56 +0200
Subject: [PATCH 03/10] test[next]: check for DaCe dependency in test
 execution (#1336)

Expanding the pytest fixture for unit tests with markers to exclude tests
based on feature support in the selected backend. In addition, a check is
added to the DaCe backend so that tests are skipped if the dace module is
not installed. This is required for the Spack build of icon4py, which uses
the base installation of gt4py, where the dace module is optional.
---
 .../ADRs/0015-Test_Exclusion_Matrices.md      |  80 ++++++++++++++
 docs/development/ADRs/Index.md                |   4 +-
 pyproject.toml                                |  22 +++-
 tests/next_tests/exclusion_matrices.py        |  89 +++++++++++++++
 .../ffront_tests/ffront_test_utils.py         |  31 +++++-
 .../ffront_tests/test_arg_call_interface.py   |  11 +-
 .../ffront_tests/test_execution.py            |  86 ++++-----------
 .../ffront_tests/test_external_local_field.py |   1 -
 .../ffront_tests/test_gt4py_builtins.py       |  15 +--
 .../test_math_builtin_execution.py            |   1 -
 .../ffront_tests/test_math_unary_builtins.py  |  13 +--
 .../ffront_tests/test_program.py              |  13 +--
 .../ffront_tests/test_scalar_if.py            | 102 +++--------------
 .../iterator_tests/test_builtins.py           |  12 +--
 .../iterator_tests/test_conditional.py        |   8 +-
 .../test_horizontal_indirection.py            |  13 +--
 .../iterator_tests/test_implicit_fencil.py    |   5 -
 .../feature_tests/iterator_tests/test_scan.py |   4 +-
 .../test_strided_offset_provider.py           |  10 +-
 .../iterator_tests/test_trivial.py            |   1 -
 .../iterator_tests/test_tuple.py              |  43 +++-----
 .../ffront_tests/test_icon_like_scan.py       |  14 +--
 .../iterator_tests/test_anton_toy.py          |   4 +-
 .../iterator_tests/test_column_stencil.py     |  26 +----
 .../iterator_tests/test_fvm_nabla.py          |  41 +------
 .../iterator_tests/test_hdiff.py              |  12 +--
 .../iterator_tests/test_vertical_advection.py |  12 +--
 .../test_with_toy_connectivity.py             |  37 +++----
 tests/next_tests/unit_tests/conftest.py       |  39 ++++---
 29 files changed, 346 insertions(+), 403 deletions(-)
 create mode 100644 docs/development/ADRs/0015-Test_Exclusion_Matrices.md
 create mode 100644 tests/next_tests/exclusion_matrices.py

diff --git a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md
new file mode 100644
index 0000000000..920504db9a
--- /dev/null
+++ b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md
@@ -0,0 +1,80 @@
+---
+tags: []
+---
+
+# Test-Exclusion Matrices
+
+- **Status**: valid
+- **Authors**: Edoardo Paone (@edopao), Enrique G. Paredes (@egparedes)
+- **Created**: 2023-09-21
+- **Updated**: 2023-09-21
+
+In the context of Field View testing, where a backend under development may lack support for specific
+ITIR features, we decided to use `pytest` fixtures to exclude unsupported tests.
+
+## Context
+
+It should be possible to run Field View tests on different backends. However, specific tests could be unsupported
+on a certain backend, or the backend implementation could be only partially ready.
+Therefore, we need a mechanism to specify the features required by each test and selectively enable
+the supported backends, while keeping the test code clean.
+
+## Decision
+
+It was decided to apply fixtures and markers from the `pytest` module. The fixture is the same one used to execute
+the test on different backends (`fieldview_backend` and `program_processor`), but it is extended with a check on the
+available feature markers. If a test is annotated with a feature marker, the fixture will check whether this feature
+is supported on the selected backend. If no marker is specified, the test is supposed to run on all backends.
+(An illustrative sketch of this fixture-side check is included below, after the Consequences section.)
+
+In the example below, `test_offset_field` requires the backend to support dynamic offsets in the translation from ITIR:
+
+```python
+@pytest.mark.uses_dynamic_offsets
+def test_offset_field(cartesian_case):
+```
+
+In order to selectively enable the backends, the dictionary `next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX`
+lists for each backend the features that are not supported.
+The fixture will check if the annotated feature is present in the exclusion matrix for the selected backend.
+If so, the exclusion matrix will also specify the action `pytest` should take (e.g. `SKIP` or `XFAIL`).
+
+The test-exclusion matrix is a dictionary, where `key` is the backend name and each entry is a tuple with the following fields:
+
+`(<marker>, <skip_definition>, <skip_message>)`
+
+The backend string, used both as dictionary key and as string formatter in the skip message, is retrieved
+by calling `tests.next_tests.get_processor_id()`, which returns the so-called processor name.
+The following backend processors are defined:
+
+```python
+DACE = "dace_iterator.run_dace_iterator"
+GTFN_CPU = "otf_compile_executor.run_gtfn"
+GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative"
+GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries"
+```
+
+Following the previous example, the GTFN backend with temporaries does not yet support dynamic offsets in ITIR:
+
+```python
+BACKEND_SKIP_TEST_MATRIX = {
+    GTFN_CPU_WITH_TEMPORARIES: [
+        ("uses_dynamic_offsets", pytest.XFAIL, "'{marker}' tests not supported by '{backend}' backend"),
+    ]
+}
+```
+
+## Consequences
+
+Positive outcomes of this decision:
+
+- The solution provides a central place to specify test exclusion.
+- The test code remains free of if-statements for backend exclusion.
+- The exclusion matrix gives an overview of the feature-readiness of different backends.
+
+Negative outcomes:
+
+- There is not (yet) any code-style check to enforce this solution, so code reviews should be aware of the ADR.
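+
+As promised in the Decision section, here is an illustrative sketch of the fixture-side check. This is an
+editorial sketch, not normative: the processor parameter list and the backend setup/teardown of the real
+fixture are elided, and the `pytest`, `next_tests`, and `exclusion_matrices` imports of
+`ffront_test_utils.py` are assumed. It mirrors the `fieldview_backend` fixture introduced by this patch:
+
+```python
+@pytest.fixture(params=[...], ids=lambda p: next_tests.get_processor_id(p))  # param list elided here
+def fieldview_backend(request):
+    backend = request.param
+    backend_id = next_tests.get_processor_id(backend)
+    # Look up the selected backend in the exclusion matrix (see ADR 15) and
+    # skip/xfail the test if it carries one of the unsupported feature markers.
+    for marker, skip_mark, msg in exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get(backend_id, []):
+        if request.node.get_closest_marker(marker):
+            skip_mark(msg.format(marker=marker, backend=backend_id))
+    yield backend
+```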
+
+## References
+
+- [pytest - Using markers to pass data to fixtures](https://docs.pytest.org/en/6.2.x/fixture.html#using-markers-to-pass-data-to-fixtures)
diff --git a/docs/development/ADRs/Index.md b/docs/development/ADRs/Index.md
index 1bbfd62d81..09d2273ee9 100644
--- a/docs/development/ADRs/Index.md
+++ b/docs/development/ADRs/Index.md
@@ -51,9 +51,9 @@ _None_
 - [0011 - On The Fly Compilation](0011-On_The_Fly_Compilation.md)
 - [0012 - GridTools C++ OTF](0011-_GridTools_Cpp_OTF.md)
 
-### Miscellanea
+### Testing
 
-_None_
+- [0015 - Exclusion Matrices](0015-Test_Exclusion_Matrices.md)
 
 ### Superseded
diff --git a/pyproject.toml b/pyproject.toml
index e915622857..e2d2a7dfe9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -326,9 +326,25 @@ module = 'gt4py.next.iterator.runtime'
 [tool.pytest.ini_options]
 markers = [
-    'requires_atlas', # mark tests that require 'atlas4py' bindings package
-    'requires_dace', # mark tests that require 'dace' package
-    'requires_gpu:' # mark tests that require a NVidia GPU (assume 'cupy' and 'cudatoolkit' are installed)
+    'requires_atlas: tests that require `atlas4py` bindings package',
+    'requires_dace: tests that require `dace` package',
+    'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
+    'uses_applied_shifts: tests that require backend support for applied-shifts',
+    'uses_can_deref: tests that require backend support for can_deref',
+    'uses_constant_fields: tests that require backend support for constant fields',
+    'uses_dynamic_offsets: tests that require backend support for dynamic offsets',
+    'uses_if_stmts: tests that require backend support for if-statements',
+    'uses_index_fields: tests that require backend support for index fields',
+    'uses_lift_expressions: tests that require backend support for lift expressions',
+    'uses_negative_modulo: tests that require backend support for modulo on negative numbers',
+    'uses_origin: tests that require backend support for domain origin',
+    'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions',
+    'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
+    'uses_sparse_fields: tests that require backend support for sparse fields',
+    'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
+    'uses_tuple_args: tests that require backend support for tuple arguments',
+    'uses_tuple_returns: tests that require backend support for tuple results',
+    'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields'
 ]
 norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
 testpaths = 'tests'
diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py
new file mode 100644
index 0000000000..d0a44080ad
--- /dev/null
+++ b/tests/next_tests/exclusion_matrices.py
@@ -0,0 +1,89 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check <https://www.gnu.org/licenses/gpl.html>.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+import pytest
+
+
+"""
+Contains definition of test-exclusion matrices, see ADR 15.
+"""
+
+# Skip definitions
+XFAIL = pytest.xfail
+SKIP = pytest.skip
+
+# Skip messages (available format keys: 'marker', 'backend')
+UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend"
+BINDINGS_UNSUPPORTED_MESSAGE = "'{marker}' not supported by '{backend}' bindings"
+
+# Processor ids as returned by next_tests.get_processor_id()
+DACE = "dace_iterator.run_dace_iterator"
+GTFN_CPU = "otf_compile_executor.run_gtfn"
+GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative"
+GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries"
+
+# Test markers
+REQUIRES_ATLAS = "requires_atlas"
+USES_APPLIED_SHIFTS = "uses_applied_shifts"
+USES_CAN_DEREF = "uses_can_deref"
+USES_CONSTANT_FIELDS = "uses_constant_fields"
+USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets"
+USES_IF_STMTS = "uses_if_stmts"
+USES_INDEX_FIELDS = "uses_index_fields"
+USES_LIFT_EXPRESSIONS = "uses_lift_expressions"
+USES_NEGATIVE_MODULO = "uses_negative_modulo"
+USES_ORIGIN = "uses_origin"
+USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions"
+USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator"
+USES_SPARSE_FIELDS = "uses_sparse_fields"
+USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset"
+USES_TUPLE_ARGS = "uses_tuple_args"
+USES_TUPLE_RETURNS = "uses_tuple_returns"
+USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields"
+
+# Common list of feature markers to skip
+GTFN_SKIP_TEST_LIST = [
+    (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
+    (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE),
+    (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE),
+    (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
+    (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE),
+    (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
+]
+
+"""
+Skip matrix, contains for each backend processor a list of tuples with the following fields:
+(<marker>, <skip_definition>, <skip_message>)
+"""
+BACKEND_SKIP_TEST_MATRIX = {
+    DACE: GTFN_SKIP_TEST_LIST
+    + [
+        (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE),
+        (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
+        (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE),
+        (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
+        (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE),
+        (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE),
+        (USES_REDUCTION_OVER_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE),
+        (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
+        (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE),
+        (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE),
+        (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
+    ],
+    GTFN_CPU: GTFN_SKIP_TEST_LIST,
+    GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST,
+    GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST
+    + [
+        (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE),
+    ],
+}
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
index a8c35cc28f..d3863f5a28 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
@@ -22,7 +22,17 @@ import gt4py.next as gtx
 from gt4py.next.ffront import decorator
 from gt4py.next.iterator import embedded, ir as itir
-from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu, roundtrip
+from 
gt4py.next.program_processors.runners import gtfn_cpu, roundtrip +from tests.next_tests import exclusion_matrices + + +try: + from gt4py.next.program_processors.runners import dace_iterator +except ModuleNotFoundError as e: + if "dace" in str(e): + dace_iterator = None + else: + raise e import next_tests @@ -32,20 +42,33 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non raise ValueError("No backend selected! Backend selection is mandatory in tests.") +OPTIONAL_PROCESSORS = [] +if dace_iterator: + OPTIONAL_PROCESSORS.append(dace_iterator.run_dace_iterator) + + @pytest.fixture( params=[ roundtrip.executor, gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ], + ] + + OPTIONAL_PROCESSORS, ids=lambda p: next_tests.get_processor_id(p), ) def fieldview_backend(request): + backend = request.param + backend_id = next_tests.get_processor_id(backend) + + """See ADR 15.""" + for marker, skip_mark, msg in exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get(backend_id, []): + if request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=backend_id)) + backup_backend = decorator.DEFAULT_BACKEND decorator.DEFAULT_BACKEND = no_backend - yield request.param + yield backend decorator.DEFAULT_BACKEND = backup_backend diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index 71e31542f7..1402649127 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -24,7 +24,7 @@ from gt4py.next.errors.exceptions import TypeError_ from gt4py.next.ffront.decorator import field_operator, program, scan_operator from gt4py.next.ffront.fbuiltins import broadcast, int32, int64 -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -169,15 +169,8 @@ def testee( ) +@pytest.mark.uses_scan_in_field_operator def test_call_scan_operator_from_field_operator(cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("Calling scan from field operator not fully supported.") - @scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan(state: float, x: float, y: float) -> float: return state + x + 2.0 * y diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f50f16ea0f..865950eeab 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -33,7 +33,7 @@ where, ) from gt4py.next.ffront.experimental import as_offset -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -68,10 +68,8 @@ def testee(a: cases.IJKField) -> cases.IJKField: cases.verify_with_default_data(cartesian_case, testee, ref=lambda 
a: a) +@pytest.mark.uses_tuple_returns def test_multicopy(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee(a: cases.IJKField, b: cases.IJKField) -> tuple[cases.IJKField, cases.IJKField]: return a, b @@ -161,10 +159,8 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField: cases.verify(cartesian_case, testee, a, b, out=out, ref=a.ndarray[1:] + b.ndarray[2:]) +@pytest.mark.uses_tuple_returns def test_tuples(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee(a: cases.IJKFloatField, b: cases.IJKFloatField) -> cases.IJKFloatField: inps = a, b @@ -211,10 +207,8 @@ def testee(a: int32) -> cases.VField: ) +@pytest.mark.uses_index_fields def test_scalar_arg_with_field(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: index fields, constant fields") - @gtx.field_operator def testee(a: cases.IJKField, b: int32) -> cases.IJKField: tmp = b * a @@ -272,16 +266,8 @@ def testee(qc: cases.IKFloatField, scalar: float): cases.verify(cartesian_case, testee, qc, scalar, inout=qc, ref=expected) +@pytest.mark.uses_scan_in_field_operator def test_tuple_scalar_scan(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - ]: - pytest.xfail("Scalar tuple arguments are not supported in gtfn yet.") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple arguments") - @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan( state: float, qc_in: float, tuple_scalar: tuple[float, tuple[float, float]] @@ -301,10 +287,8 @@ def testee_op( cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected) +@pytest.mark.uses_index_fields def test_scalar_scan_vertical_offset(cartesian_case): # noqa: F811 # fixtures - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") - @gtx.scan_operator(axis=KDim, forward=True, init=(0.0)) def testee_scan(state: float, inp: float) -> float: return inp @@ -382,12 +366,8 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], np.float32]: ) +@pytest.mark.uses_dynamic_offsets def test_offset_field(cartesian_case): - if cartesian_case.backend == gtfn_cpu.run_gtfn_with_temporaries: - pytest.xfail("Dynamic offsets not supported in gtfn") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: offset fields") - ref = np.full( (cartesian_case.default_sizes[IDim], cartesian_case.default_sizes[KDim]), True, dtype=bool ) @@ -420,10 +400,8 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD assert np.allclose(out, ref) +@pytest.mark.uses_tuple_returns def test_nested_tuple_return(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IField, b: cases.IField @@ -438,10 +416,8 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + 
a + b) +@pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions over lift expressions") - @gtx.field_operator def testee(a: cases.EField) -> cases.EField: tmp = neighbor_sum(a(V2E), axis=V2EDim) @@ -481,10 +457,8 @@ def testee(inp: cases.EField) -> cases.EField: ) +@pytest.mark.uses_tuple_returns def test_tuple_return_2(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField]: tmp = neighbor_sum(a(V2E), axis=V2EDim) @@ -502,10 +476,8 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField ) +@pytest.mark.uses_tuple_returns def test_tuple_with_local_field_in_reduction_shifted(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuples") - @gtx.field_operator def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: tup = e(V2E), v @@ -522,10 +494,8 @@ def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: ) +@pytest.mark.uses_tuple_args def test_tuple_arg(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple args") - @gtx.field_operator def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.IField: return 3 * a[0][0] + a[0][1] + a[1] @@ -555,6 +525,7 @@ def simple_scan_operator(carry: float) -> float: cases.verify(cartesian_case, simple_scan_operator, out=out, ref=expected) +@pytest.mark.uses_lift_expressions def test_solve_triag(cartesian_case): if cartesian_case.backend in [ gtfn_cpu.run_gtfn, @@ -564,8 +535,6 @@ def test_solve_triag(cartesian_case): pytest.xfail("Nested `scan`s requires creating temporaries.") if cartesian_case.backend == gtfn_cpu.run_gtfn_with_temporaries: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) def tridiag_forward( @@ -627,10 +596,8 @@ def testee(left: int32, right: int32) -> cases.IField: @pytest.mark.parametrize("left, right", [(2, 3), (3, 2)]) +@pytest.mark.uses_tuple_returns def test_ternary_operator_tuple(cartesian_case, left, right): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def testee( a: cases.IField, b: cases.IField, left: int32, right: int32 @@ -646,10 +613,8 @@ def testee( ) +@pytest.mark.uses_reduction_over_lift_expressions def test_ternary_builtin_neighbor_sum(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions over lift expressions") - @gtx.field_operator def testee(a: cases.EField, b: cases.EField) -> cases.VField: tmp = neighbor_sum(b(V2E) if 2 < 3 else a(V2E), axis=V2EDim) @@ -688,11 +653,10 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.parametrize("forward", [True, False]) +@pytest.mark.uses_tuple_returns def test_scan_nested_tuple_output(forward, cartesian_case): if cartesian_case.backend in 
[gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") init = (1, (2, 3)) k_size = cartesian_case.default_sizes[KDim] @@ -720,9 +684,8 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): ) +@pytest.mark.uses_tuple_args def test_scan_nested_tuple_input(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple args") init = 1.0 k_size = cartesian_case.default_sizes[KDim] inp1 = gtx.np_as_located_field(KDim)(np.ones((k_size,))) @@ -877,10 +840,8 @@ def program_domain( ) +@pytest.mark.uses_tuple_returns def test_domain_tuple(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def fieldop_domain_tuple( a: cases.IJField, b: cases.IJField @@ -939,10 +900,8 @@ def return_undefined(): return undefined_symbol +@pytest.mark.uses_zero_dimensional_fields def test_zero_dims_fields(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: zero-dimensional fields") - @gtx.field_operator def implicit_broadcast_scalar(inp: cases.EmptyField): return inp @@ -970,10 +929,8 @@ def fieldop_implicit_broadcast_2(inp: cases.IField) -> cases.IField: ) +@pytest.mark.uses_tuple_returns def test_tuple_unpacking(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def unpack( inp: cases.IField, @@ -986,9 +943,8 @@ def unpack( ) +@pytest.mark.uses_tuple_returns def test_tuple_unpacking_star_multi(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") OutType = tuple[ cases.IField, cases.IField, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 7f2b11afff..dbc35ddfdf 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -17,7 +17,6 @@ import gt4py.next as gtx from gt4py.next import int32, neighbor_sum -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import V2E, Edge, V2EDim, Vertex, unstructured_case diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index ee88b3764e..0ae874f3a6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -18,7 +18,7 @@ import gt4py.next as gtx from gt4py.next import broadcast, float64, int32, int64, max_over, min_over, neighbor_sum, where -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from 
next_tests.integration_tests.cases import ( @@ -46,8 +46,6 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions") if unstructured_case.backend in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, @@ -69,9 +67,6 @@ def testee(edge_f: cases.EField) -> cases.VField: def test_minover_execution(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions") - @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) @@ -99,10 +94,8 @@ def fencil(edge_f: cases.EField, out: cases.VField): ) +@pytest.mark.uses_constant_fields def test_reduction_expression_in_call(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: make_const_list") - @gtx.field_operator def reduce_expr(edge_f: cases.EField) -> cases.VField: tmp_nbh_tup = edge_f(V2E), edge_f(V2E) @@ -133,10 +126,8 @@ def testee(flux: cases.EField) -> cases.VField: ) +@pytest.mark.uses_tuple_returns def test_conditional_nested_tuple(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def conditional_nested_tuple( mask: cases.IBoolField, a: cases.IFloatField, b: cases.IFloatField diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 9ceab7f2d0..f7121dc82f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -22,7 +22,6 @@ from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast from gt4py.next.ffront.decorator import FieldOperator from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction -from gt4py.next.program_processors.runners import dace_iterator from gt4py.next.type_system import type_translation from next_tests.integration_tests import cases diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 54374077b4..85826c1ac0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -37,7 +37,7 @@ tanh, trunc, ) -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, cartesian_case, unstructured_case @@ -84,17 +84,8 @@ def floorDiv(inp1: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, floorDiv, ref=lambda inp1: inp1 // 2) +@pytest.mark.uses_negative_modulo def test_mod(cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail( - "Modulo not properly supported for 
negative numbers." - ) # see https://github.com/GridTools/gt4py/issues/1219 - @gtx.field_operator def mod_fieldop(inp1: cases.IField) -> cases.IField: return inp1 % 2 diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index d7c50e83f0..f489126fa7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -20,7 +20,6 @@ import pytest import gt4py.next as gtx -from gt4py.next.program_processors.runners import dace_iterator from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, Ioff, JDim, cartesian_case, fieldview_backend @@ -129,10 +128,8 @@ def fo_from_fo_program(in_field: cases.IFloatField, out: cases.IFloatField): ) +@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IFloatField, b: cases.IFloatField @@ -158,10 +155,8 @@ def prog( assert np.allclose((a, b), (out_a, out_b)) +@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IFloatField, b: cases.IFloatField @@ -188,10 +183,8 @@ def prog( assert out_a[0] == 0 and out_b[0] == 0 +@pytest.mark.uses_tuple_returns def test_tuple_program_return_constructed_inside_nested(cartesian_case): - if cartesian_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") - @gtx.field_operator def pack_tuple( a: cases.IFloatField, b: cases.IFloatField, c: cases.IFloatField diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index a49dd1fdcf..f9fd2c1353 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -19,7 +19,6 @@ import pytest from gt4py.next import Field, errors, field_operator, float64, index_field, np_as_located_field -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -46,15 +45,8 @@ @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_simple_if(condition, cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -71,15 +63,8 @@ def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) +@pytest.mark.uses_if_stmts def test_simple_if_conditional(condition1, condition2, cartesian_case): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - 
gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def simple_if( a: cases.IField, @@ -112,15 +97,8 @@ def simple_if( @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_local_if(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def local_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -138,15 +116,8 @@ def local_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_temporary_if(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -167,15 +138,8 @@ def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IFi @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_return(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -196,15 +160,8 @@ def temporary_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IFi @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_stmt_if_branch_returns(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def if_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -222,15 +179,8 @@ def if_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> case @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_stmt_else_branch_returns(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def else_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -250,15 +200,8 @@ def else_branch_returns(a: cases.IField, b: cases.IField, condition: bool) -> ca @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_if_stmt_both_branches_return(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def both_branches_return(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ 
-278,15 +221,8 @@ def both_branches_return(a: cases.IField, b: cases.IField, condition: bool) -> c @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) -def test_nested_if_stmt_conditinal(cartesian_case, condition1, condition2): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - +@pytest.mark.uses_if_stmts +def test_nested_if_stmt_conditional(cartesian_case, condition1, condition2): @field_operator def nested_if_conditional_return( inp: cases.IField, condition1: bool, condition2: bool @@ -322,15 +258,8 @@ def nested_if_conditional_return( @pytest.mark.parametrize("condition", [True, False]) +@pytest.mark.uses_if_stmts def test_nested_if(cartesian_case, condition): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def nested_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField: if condition: @@ -364,15 +293,8 @@ def nested_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) +@pytest.mark.uses_if_stmts def test_if_without_else(cartesian_case, condition1, condition2): - if cartesian_case.backend in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - dace_iterator.run_dace_iterator, - ]: - pytest.xfail("If-stmts are not supported yet.") - @field_operator def if_without_else( a: cases.IField, b: cases.IField, condition1: bool, condition2: bool diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 13fcf3b87f..ca29c5b18b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -52,7 +52,6 @@ xor_, ) from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data @@ -171,10 +170,6 @@ def arithmetic_and_logical_test_data(): @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, expected, as_column): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) inps = asfield(*asarray(*inputs)) out = asfield((np.zeros_like(*asarray(expected))))[0] @@ -207,10 +202,6 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins(program_processor, builtin_name, inputs, as_column): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) if builtin_name == "gamma": # 
numpy has no gamma function @@ -254,10 +245,9 @@ def foo(a): @pytest.mark.parametrize("stencil", [_can_deref, _can_deref_lifted]) +@pytest.mark.uses_can_deref def test_can_deref(program_processor, stencil): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: can_deref") Node = gtx.Dimension("Node") diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index d20ec2ee3d..c2517f1a07 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -18,7 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import program_processor, run_processor @@ -27,15 +26,14 @@ @fundef -def test_conditional(inp): +def stencil_conditional(inp): tmp = if_(eq(deref(inp), 0), make_tuple(1.0, 2.0), make_tuple(3.0, 4.0)) return tuple_get(0, tmp) + tuple_get(1, tmp) +@pytest.mark.uses_tuple_returns def test_conditional_w_tuple(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5] @@ -46,7 +44,7 @@ def test_conditional_w_tuple(program_processor): IDim: range(0, shape[0]), } run_processor( - test_conditional[dom], + stencil_conditional[dom], program_processor, inp, out=out, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index f4ebc596e5..75b935677b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -34,8 +34,6 @@ from gt4py.next.program_processors.formatters.gtfn import ( format_sourcecode as gtfn_format_sourcecode, ) -from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.integration_tests.cases import IDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -54,16 +52,13 @@ def conditional_indirection(inp, cond): return deref(compute_shift(cond)(inp)) +@pytest.mark.uses_applied_shifts def test_simple_indirection(program_processor): program_processor, validate = program_processor if program_processor in [ type_check.check, - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, gtfn_format_sourcecode, - run_dace_iterator, ]: pytest.xfail( "We only support applied shifts in type_inference." 
@@ -97,13 +92,9 @@ def direct_indirection(inp, cond): return deref(shift(I, deref(cond))(inp)) +@pytest.mark.uses_dynamic_offsets def test_direct_offset_for_indirection(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: shift offsets not literals") - - if program_processor == gtfn_cpu.run_gtfn_with_temporaries: - pytest.xfail("Dynamic offsets not supported in temporaries pass.") shape = [4] inp = gtx.np_as_located_field(IDim)(np.asarray(range(shape[0]), dtype=np.float64)) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py index 2076cdd864..d0dc8ec475 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py @@ -18,7 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import fundef -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import program_processor, run_processor @@ -59,10 +58,6 @@ def test_single_argument(program_processor, dom): def test_2_arguments(program_processor, dom): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) @fundef def fun(inp0, inp1): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index e0460b67b1..e02dab0a72 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -18,16 +18,14 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import cartesian_domain, deref, named_range, scan, shift from gt4py.next.iterator.runtime import fundef, offset -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.integration_tests.cases import IDim, KDim from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +@pytest.mark.uses_index_fields def test_scan_in_stencil(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: shift inside lambda") isize = 1 ksize = 3 diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 7bfaa7f643..0ac38e9b9f 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -18,8 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import program_processor, 
run_processor @@ -49,15 +47,9 @@ def fencil(size, out, inp): ) +@pytest.mark.uses_strided_neighbor_offset def test_strided_offset_provider(program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("gtx.StridedNeighborOffsetProvider not implemented in bindings.") LocA_size = 2 max_neighbors = LocA2LocAB_offset_provider.max_neighbors diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 7cc4e95949..cc12183a24 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -19,7 +19,6 @@ from gt4py.next.iterator import transforms from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 5a6ffe2891..bd5a717bb2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -18,13 +18,8 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator -from next_tests.unit_tests.conftest import ( - program_processor, - program_processor_no_gtfn_exec, - run_processor, -) +from next_tests.unit_tests.conftest import program_processor, run_processor IDim = gtx.Dimension("IDim") @@ -54,10 +49,9 @@ def tuple_output2(inp1, inp2): "stencil", [tuple_output1, tuple_output2], ) +@pytest.mark.uses_tuple_returns def test_tuple_output(program_processor, stencil): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5, 7, 9] rng = np.random.default_rng() @@ -94,10 +88,9 @@ def tuple_of_tuple_output2(inp1, inp2, inp3, inp4): return make_tuple(deref(inp1), deref(inp2)), make_tuple(deref(inp3), deref(inp4)) +@pytest.mark.uses_tuple_returns def test_tuple_of_tuple_of_field_output(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") @fundef def stencil(inp1, inp2, inp3, inp4): @@ -155,10 +148,9 @@ def stencil(inp1, inp2, inp3, inp4): "stencil", [tuple_output1, tuple_output2], ) +@pytest.mark.uses_tuple_returns def test_tuple_of_field_output_constructed_inside(program_processor, stencil): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") @fendef def fencil(size0, size1, size2, inp1, inp2, out1, out2): @@ -202,10 +194,9 @@ def fencil(size0, size1, size2, inp1, inp2, out1, out2): assert np.allclose(inp2, out2) +@pytest.mark.uses_tuple_returns def 
test_asymetric_nested_tuple_of_field_output_constructed_inside(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") @fundef def stencil(inp1, inp2, inp3): @@ -265,10 +256,8 @@ def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): "stencil", [tuple_output1, tuple_output2], ) -def test_field_of_extra_dim_output(program_processor_no_gtfn_exec, stencil): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") +def test_field_of_extra_dim_output(program_processor, stencil): + program_processor, validate = program_processor shape = [5, 7, 9] rng = np.random.default_rng() @@ -299,10 +288,9 @@ def tuple_input(inp): return tuple_get(0, inp_deref) + tuple_get(1, inp_deref) +@pytest.mark.uses_tuple_returns def test_tuple_field_input(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5, 7, 9] rng = np.random.default_rng() @@ -326,10 +314,8 @@ def test_tuple_field_input(program_processor): @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") -def test_field_of_extra_dim_input(program_processor_no_gtfn_exec): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") +def test_field_of_extra_dim_input(program_processor): + program_processor, validate = program_processor shape = [5, 7, 9] rng = np.random.default_rng() @@ -362,10 +348,9 @@ def tuple_tuple_input(inp): ) +@pytest.mark.uses_tuple_returns def test_tuple_of_tuple_of_field_input(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") shape = [5, 7, 9] rng = np.random.default_rng() @@ -404,10 +389,8 @@ def test_tuple_of_tuple_of_field_input(program_processor): @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") -def test_field_of_2_extra_dim_input(program_processor_no_gtfn_exec): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple returns") +def test_field_of_2_extra_dim_input(program_processor): + program_processor, validate = program_processor shape = [5, 7, 9] rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 2580c6ba7f..8db9a4c36e 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -18,7 +18,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu, roundtrip +from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( fieldview_backend, @@ -211,6 +211,7 @@ class setup: return setup() +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): if 
fieldview_backend in [ gtfn_cpu.run_gtfn, @@ -218,8 +219,6 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): gtfn_cpu.run_gtfn_with_temporaries, ]: pytest.xfail("Needs implementation of scan projector.") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") solve_nonhydro_stencil_52_like_z_q.with_backend(fieldview_backend)( test_setup.z_alpha, @@ -233,6 +232,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail( @@ -241,8 +241,6 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): ) if fieldview_backend == roundtrip.executor: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuples, scans") solve_nonhydro_stencil_52_like_z_q_tup.with_backend(fieldview_backend)( test_setup.z_alpha, @@ -256,11 +254,10 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: scans") solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)( test_setup.z_alpha, test_setup.z_beta, @@ -274,13 +271,12 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): assert np.allclose(test_setup.w_ref, test_setup.w) +@pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend): if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if fieldview_backend == roundtrip.executor: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - if fieldview_backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuples, scans") solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge.with_backend(fieldview_backend)( test_setup.z_alpha, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 14d929e822..16d839a8ab 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -19,7 +19,6 @@ from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @@ -75,6 +74,7 @@ def naive_lap(inp): return out +@pytest.mark.uses_origin def 
test_anton_toy(program_processor, lift_mode): program_processor, validate = program_processor @@ -87,8 +87,6 @@ def test_anton_toy(program_processor, lift_mode): if lift_mode != transforms.LiftMode.FORCE_INLINE: pytest.xfail("TODO: issue with temporaries that crashes the application") - if program_processor == run_dace_iterator: - pytest.xfail("TODO: not supported in DaCe backend") shape = [5, 7, 9] rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 2446d6664f..41d6c8f0f9 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -18,11 +18,6 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.program_processors.formatters.gtfn import ( - format_sourcecode as gtfn_format_sourcecode, -) -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator -from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn, run_gtfn_imperative from next_tests.integration_tests.cases import IDim, KDim from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @@ -79,11 +74,10 @@ def basic_stencils(request): return request.param +@pytest.mark.uses_origin def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): program_processor, validate = program_processor stencil, ref_fun, inp_fun = basic_stencils - if program_processor == run_dace_iterator and inp_fun: - pytest.xfail("Not supported in DaCe backend: origin") shape = [5, 7] inp = ( @@ -95,13 +89,6 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): ref = ref_fun(inp) - if ( - program_processor == run_dace_iterator - and stencil.__name__ == "shift_stencil" - and inp.origin - ): - pytest.xfail("Not supported in DaCe backend: origin") - run_processor( stencil[{IDim: range(0, shape[0]), KDim: range(0, shape[1])}], program_processor, @@ -162,12 +149,10 @@ def k_level_condition_upper_tuple(k_idx, k_level): ), ], ) +@pytest.mark.uses_tuple_returns def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_function, ref_function): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: tuple arguments") - k_size = 5 inp = inp_function(k_size) ref = ref_function(inp) @@ -361,10 +346,6 @@ def sum_shifted_fencil(out, inp0, inp1, k_size): def test_different_vertical_sizes(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) k_size = 10 inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) @@ -401,10 +382,9 @@ def sum_fencil(out, inp0, inp1, k_size): ) +@pytest.mark.uses_origin def test_different_vertical_sizes_with_origin(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: origin") k_size = 10 inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py 
b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 2d35fb1e50..42de13ef44 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -15,8 +15,6 @@ import numpy as np import pytest -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator - pytest.importorskip("atlas4py") @@ -136,15 +134,9 @@ def nabla( ) +@pytest.mark.requires_atlas def test_compute_zavgS(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() pp = gtx.np_as_located_field(Vertex)(setup.input_field) @@ -201,15 +193,9 @@ def compute_zavgS2_fencil( ) +@pytest.mark.requires_atlas def test_compute_zavgS2(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() pp = gtx.np_as_located_field(Vertex)(setup.input_field) @@ -244,15 +230,9 @@ def test_compute_zavgS2(program_processor, lift_mode): assert_close(1000788897.3202186, np.max(zavgS[1])) +@pytest.mark.requires_atlas def test_nabla(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") if lift_mode != LiftMode.FORCE_INLINE: pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") setup = nabla_setup() @@ -310,15 +290,9 @@ def nabla2( ) +@pytest.mark.requires_atlas def test_nabla2(program_processor, lift_mode): program_processor, validate = program_processor - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() sign = gtx.np_as_located_field(Vertex, V2EDim)(setup.sign_field) @@ -400,13 +374,6 @@ def test_nabla_sign(program_processor, lift_mode): program_processor, validate = program_processor if lift_mode != LiftMode.FORCE_INLINE: pytest.xfail("test is broken due to bad lift semantics in iterator IR") - if program_processor in [ - gtfn_cpu.run_gtfn, - gtfn_cpu.run_gtfn_imperative, - gtfn_cpu.run_gtfn_with_temporaries, - run_dace_iterator, - ]: - pytest.xfail("TODO: bindings don't support Atlas tables") setup = nabla_setup() is_pole_edge = gtx.np_as_located_field(Edge)(setup.is_pole_edge_field) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 1dfad40e48..7bd028b7c3 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -24,12 +24,7 @@ from next_tests.integration_tests.multi_feature_tests.iterator_tests.hdiff_reference import ( hdiff_reference, ) -from next_tests.unit_tests.conftest 
import ( - lift_mode, - program_processor, - program_processor_no_dace_exec, - run_processor, -) +from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor I = offset("I") @@ -76,8 +71,9 @@ def hdiff(inp, coeff, out, x, y): ) -def test_hdiff(hdiff_reference, program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_origin +def test_hdiff(hdiff_reference, program_processor, lift_mode): + program_processor, validate = program_processor if program_processor in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 4474121876..f11046cb5d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -25,12 +25,7 @@ from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests.cases import IDim, JDim, KDim -from next_tests.unit_tests.conftest import ( - lift_mode, - program_processor, - program_processor_no_dace_exec, - run_processor, -) +from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor @fundef @@ -120,8 +115,9 @@ def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): @pytest.mark.parametrize("fencil", [fen_solve_tridiag, fen_solve_tridiag2]) -def test_tridiag(fencil, tridiag_reference, program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_lift_expressions +def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): + program_processor, validate = program_processor if ( program_processor in [ diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index ee07372731..27c9f6d124 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -32,7 +32,6 @@ from gt4py.next.iterator.runtime import fundef from gt4py.next.program_processors.formatters import gtfn from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator from next_tests.toy_connectivity import ( C2E, @@ -54,7 +53,6 @@ from next_tests.unit_tests.conftest import ( lift_mode, program_processor, - program_processor_no_dace_exec, program_processor_no_gtfn_exec, run_processor, ) @@ -139,10 +137,9 @@ def map_make_const_list(in_edges): return reduce(plus, 0)(map_(multiplies)(neighbors(V2E, in_edges), make_const_list(2))) +@pytest.mark.uses_constant_fields def test_map_make_const_list(program_processor_no_gtfn_exec, lift_mode): program_processor, validate = program_processor_no_gtfn_exec - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: make_const_list") inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -244,8 +241,9 @@ def slice_sparse_stencil(sparse): return list_get(1, deref(sparse)) -def 
test_slice_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_slice_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -298,8 +296,9 @@ def shift_sliced_sparse_stencil(sparse): return list_get(1, deref(shift(V2V, 0)(sparse))) -def test_shift_sliced_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_shift_sliced_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -325,8 +324,9 @@ def slice_shifted_sparse_stencil(sparse): return list_get(1, deref(shift(V2V, 0)(sparse))) -def test_slice_shifted_sparse(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_slice_shifted_sparse(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) @@ -357,8 +357,8 @@ def lift_stencil(inp): return deref(shift(V2V, 2)(lift(deref_stencil)(inp))) -def test_lift(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +def test_lift(program_processor, lift_mode): + program_processor, validate = program_processor inp = vertex_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(np.asarray(range(9))) @@ -380,8 +380,9 @@ def sparse_shifted_stencil(inp): return list_get(2, list_get(0, neighbors(V2V, inp))) -def test_shift_sparse_input_field(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_shift_sparse_input_field(program_processor, lift_mode): + program_processor, validate = program_processor inp = gtx.np_as_located_field(Vertex, V2VDim)(v2v_arr) out = gtx.np_as_located_field(Vertex)(np.zeros([9], dtype=inp.dtype)) ref = np.asarray(np.asarray(range(9))) @@ -409,8 +410,9 @@ def shift_sparse_stencil2(inp): return list_get(1, list_get(3, neighbors(V2E, inp))) -def test_shift_sparse_input_field2(program_processor_no_dace_exec, lift_mode): - program_processor, validate = program_processor_no_dace_exec +@pytest.mark.uses_sparse_fields +def test_shift_sparse_input_field2(program_processor, lift_mode): + program_processor, validate = program_processor if program_processor in [ gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative, @@ -459,13 +461,12 @@ def sum_(a, b): return reduce(sum_, 0)(neighbors(V2V, lift(lambda x: reduce(sum_, 0)(deref(x)))(inp))) +@pytest.mark.uses_sparse_fields def test_sparse_shifted_stencil_reduce(program_processor_no_gtfn_exec, lift_mode): program_processor, validate = program_processor_no_gtfn_exec if program_processor == gtfn.format_sourcecode: pytest.xfail("We cannot unroll a reduction on a sparse field only.") # With our current understanding, this iterator IR program is illegal, however we might want to fix it and therefore keep the test for now. 
- if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: illegal iterator IR") if lift_mode != transforms.LiftMode.FORCE_INLINE: pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 09d58a4376..04c34dfaab 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -23,12 +23,17 @@ from gt4py.next.iterator import ir as itir, pretty_parser, pretty_printer, runtime, transforms from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.formatters import gtfn, lisp, type_check -from gt4py.next.program_processors.runners import ( - dace_iterator, - double_roundtrip, - gtfn_cpu, - roundtrip, -) +from gt4py.next.program_processors.runners import double_roundtrip, gtfn_cpu, roundtrip +from tests.next_tests import exclusion_matrices + + +try: + from gt4py.next.program_processors.runners import dace_iterator +except ModuleNotFoundError as e: + if "dace" in str(e): + dace_iterator = None + else: + raise e import next_tests @@ -60,6 +65,11 @@ def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str return pretty +OPTIONAL_PROCESSORS = [] +if dace_iterator: + OPTIONAL_PROCESSORS.append((dace_iterator.run_dace_iterator, True)) + + @pytest.fixture( params=[ # (processor, do_validate) @@ -73,19 +83,20 @@ def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str (gtfn_cpu.run_gtfn_imperative, True), (gtfn_cpu.run_gtfn_with_temporaries, True), (gtfn.format_sourcecode, False), - (dace_iterator.run_dace_iterator, True), - ], + ] + + OPTIONAL_PROCESSORS, ids=lambda p: next_tests.get_processor_id(p[0]), ) def program_processor(request): - return request.param + backend, _ = request.param + backend_id = next_tests.get_processor_id(backend) + """See ADR 15.""" + for marker, skip_mark, msg in exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get(backend_id, []): + if request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=backend_id)) -@pytest.fixture -def program_processor_no_dace_exec(program_processor): - if program_processor[0] == dace_iterator.run_dace_iterator: - pytest.xfail("DaCe backend not yet supported.") - return program_processor + return request.param @pytest.fixture From 54bca831100455741e5bed459ef8378a1418b5ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 4 Oct 2023 15:57:35 +0200 Subject: [PATCH 04/10] Fixes and additions to test exclusion matrices functionality. (#1345) Fixes and additions to test exclusion matrices. Changes: - Fix import path of exclusion matrices. - Fix wrong locations of docstrings. - Remove deprecated fixtures. - Add missing marker to parametrize forgotten custom case. 
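For reference, a minimal sketch of how a backend fixture consumes the fixed-up matrix (pytest assumed; `apply_skip_matrix` is a hypothetical helper name and the matrix entry shown is illustrative only):

```python
# Minimal sketch, assuming pytest: a test declares a feature marker (e.g.
# @pytest.mark.uses_sparse_fields) and the fixture consults the matrix to
# xfail/skip it for the current backend. Matrix content here is illustrative.
import pytest

BACKEND_SKIP_TEST_MATRIX = {
    "dace_iterator.run_dace_iterator": [
        ("uses_sparse_fields", pytest.xfail, "'{marker}' tests not supported by '{backend}' backend"),
    ],
}


def apply_skip_matrix(request, backend_id: str) -> None:
    # xfail/skip when the test carries a marker listed for this backend
    for marker, skip_mark, msg in BACKEND_SKIP_TEST_MATRIX.get(backend_id, []):
        if request.node.get_closest_marker(marker):
            skip_mark(msg.format(marker=marker, backend=backend_id))
```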
--- .../ADRs/0015-Test_Exclusion_Matrices.md | 2 +- tests/next_tests/__init__.py | 5 ++++ tests/next_tests/exclusion_matrices.py | 26 +++++++++++-------- .../ffront_tests/ffront_test_utils.py | 12 ++++++--- .../test_with_toy_connectivity.py | 20 +++++--------- tests/next_tests/unit_tests/conftest.py | 23 +++++++--------- 6 files changed, 45 insertions(+), 43 deletions(-) diff --git a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md index 920504db9a..6c6a043560 100644 --- a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md +++ b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md @@ -43,7 +43,7 @@ The test-exclusion matrix is a dictionary, where `key` is the backend name and e `(, , )` The backend string, used both as dictionary key and as string formatter in the skip message, is retrieved -by calling `tests.next_tests.get_processor_id()`, which returns the so-called processor name. +by calling `next_tests.get_processor_id()`, which returns the so-called processor name. The following backend processors are defined: ```python diff --git a/tests/next_tests/__init__.py b/tests/next_tests/__init__.py index bd9b968948..54bc4d9c69 100644 --- a/tests/next_tests/__init__.py +++ b/tests/next_tests/__init__.py @@ -12,6 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from . import exclusion_matrices + + +__all__ = ["exclusion_matrices", "get_processor_id"] + def get_processor_id(processor): if hasattr(processor, "__module__") and hasattr(processor, "__name__"): diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index d0a44080ad..27ccb29095 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -14,23 +14,18 @@ import pytest -""" -Contains definition of test-exclusion matrices, see ADR 15. 
-""" +"""Contains definition of test-exclusion matrices, see ADR 15.""" # Skip definitions XFAIL = pytest.xfail SKIP = pytest.skip -# Skip messages (available format keys: 'marker', 'backend') -UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" -BINDINGS_UNSUPPORTED_MESSAGE = "'{marker}' not supported by '{backend}' bindings" - # Processor ids as returned by next_tests.get_processor_id() DACE = "dace_iterator.run_dace_iterator" GTFN_CPU = "otf_compile_executor.run_gtfn" GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative" GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries" +GTFN_FORMAT_SOURCECODE = "gtfn.format_sourcecode" # Test markers REQUIRES_ATLAS = "requires_atlas" @@ -46,25 +41,31 @@ USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" USES_SPARSE_FIELDS = "uses_sparse_fields" +USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset" USES_TUPLE_ARGS = "uses_tuple_args" USES_TUPLE_RETURNS = "uses_tuple_returns" USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields" +# Skip messages (available format keys: 'marker', 'backend') +UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" +BINDINGS_UNSUPPORTED_MESSAGE = "'{marker}' not supported by '{backend}' bindings" +REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE = ( + "We cannot unroll a reduction on a sparse field only (not clear if it is legal ITIR)" +) # Common list of feature markers to skip GTFN_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), + (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ] -""" -Skip matrix, contains for each backend processor a list of tuples with following fields: -(, ) -""" +#: Skip matrix, contains for each backend processor a list of tuples with following fields: +#: (, ) BACKEND_SKIP_TEST_MATRIX = { DACE: GTFN_SKIP_TEST_LIST + [ @@ -86,4 +87,7 @@ + [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), ], + GTFN_FORMAT_SOURCECODE: [ + (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), + ], } diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index d3863f5a28..383716484e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -23,7 +23,6 @@ from gt4py.next.ffront import decorator from gt4py.next.iterator import embedded, ir as itir from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip -from tests.next_tests import exclusion_matrices try: @@ -58,11 +57,18 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non ids=lambda p: next_tests.get_processor_id(p), ) def fieldview_backend(request): + """ + Fixture creating field-view operator backend on-demand for tests. + + Notes: + Check ADR 15 for details on the test-exclusion matrices. 
+ """ backend = request.param backend_id = next_tests.get_processor_id(backend) - """See ADR 15.""" - for marker, skip_mark, msg in exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get(backend_id, []): + for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( + backend_id, [] + ): if request.node.get_closest_marker(marker): skip_mark(msg.format(marker=marker, backend=backend_id)) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 27c9f6d124..92b93ddb63 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -50,12 +50,7 @@ v2e_arr, v2v_arr, ) -from next_tests.unit_tests.conftest import ( - lift_mode, - program_processor, - program_processor_no_gtfn_exec, - run_processor, -) +from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor def edge_index_field(): # TODO replace by gtx.index_field once supported in bindings @@ -138,8 +133,8 @@ def map_make_const_list(in_edges): @pytest.mark.uses_constant_fields -def test_map_make_const_list(program_processor_no_gtfn_exec, lift_mode): - program_processor, validate = program_processor_no_gtfn_exec +def test_map_make_const_list(program_processor, lift_mode): + program_processor, validate = program_processor inp = edge_index_field() out = gtx.np_as_located_field(Vertex)(np.zeros([9], inp.dtype)) ref = 2 * np.sum(v2e_arr, axis=1) @@ -462,12 +457,9 @@ def sum_(a, b): @pytest.mark.uses_sparse_fields -def test_sparse_shifted_stencil_reduce(program_processor_no_gtfn_exec, lift_mode): - program_processor, validate = program_processor_no_gtfn_exec - if program_processor == gtfn.format_sourcecode: - pytest.xfail("We cannot unroll a reduction on a sparse field only.") - # With our current understanding, this iterator IR program is illegal, however we might want to fix it and therefore keep the test for now. - +@pytest.mark.uses_reduction_with_only_sparse_fields +def test_sparse_shifted_stencil_reduce(program_processor, lift_mode): + program_processor, validate = program_processor if lift_mode != transforms.LiftMode.FORCE_INLINE: pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 04c34dfaab..7a62778be1 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -24,7 +24,6 @@ from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.formatters import gtfn, lisp, type_check from gt4py.next.program_processors.runners import double_roundtrip, gtfn_cpu, roundtrip -from tests.next_tests import exclusion_matrices try: @@ -88,28 +87,24 @@ def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str ids=lambda p: next_tests.get_processor_id(p[0]), ) def program_processor(request): + """ + Fixture creating program processors on-demand for tests. + + Notes: + Check ADR 15 for details on the test-exclusion matrices. 
+ """ backend, _ = request.param backend_id = next_tests.get_processor_id(backend) - """See ADR 15.""" - for marker, skip_mark, msg in exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get(backend_id, []): + for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( + backend_id, [] + ): if request.node.get_closest_marker(marker): skip_mark(msg.format(marker=marker, backend=backend_id)) return request.param -@pytest.fixture -def program_processor_no_gtfn_exec(program_processor): - if ( - program_processor[0] == gtfn_cpu.run_gtfn - or program_processor[0] == gtfn_cpu.run_gtfn_imperative - or program_processor[0] == gtfn_cpu.run_gtfn_with_temporaries - ): - pytest.xfail("gtfn backend not yet supported.") - return program_processor - - def run_processor( program: runtime.FendefDispatcher, processor: ppi.ProgramExecutor | ppi.ProgramFormatter, From 0d821b150177d8a805df37887fd74427a751e5af Mon Sep 17 00:00:00 2001 From: ninaburg <83002751+ninaburg@users.noreply.github.com> Date: Thu, 5 Oct 2023 15:16:48 +0200 Subject: [PATCH 05/10] feat[next]: Add support for using Type Aliases (#1335) * Add Type Alias replacement pass + tests * Fix: actual type not added in symbol list if already present * Address requested changes * Pre-commit fixes * Address requested changes * Prevent multiple float32 or float64 definitions in symtable * pre-commit run changes and 'returns' arg type modifications * Use 'from_type_hint' to avoid 'ScalarKind' construct --------- Co-authored-by: Nina Burgdorfer --- .../foast_passes/type_alias_replacement.py | 105 ++++++++++++++++++ src/gt4py/next/ffront/func_to_foast.py | 2 + .../test_type_alias_replacement.py | 44 ++++++++ 3 files changed, 151 insertions(+) create mode 100644 src/gt4py/next/ffront/foast_passes/type_alias_replacement.py create mode 100644 tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py diff --git a/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py new file mode 100644 index 0000000000..c5857999ee --- /dev/null +++ b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py @@ -0,0 +1,105 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass +from typing import Any, cast + +import gt4py.next.ffront.field_operator_ast as foast +from gt4py.eve import NodeTranslator, traits +from gt4py.eve.concepts import SourceLocation, SymbolName, SymbolRef +from gt4py.next.ffront import dialect_ast_enums +from gt4py.next.ffront.fbuiltins import TYPE_BUILTIN_NAMES +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system.type_translation import from_type_hint + + +@dataclass +class TypeAliasReplacement(NodeTranslator, traits.VisitorWithSymbolTableTrait): + """ + Replace Type Aliases with their actual type. + + After this pass, the type aliases used for explicit construction of literal + values and for casting field values are replaced by their actual types. 
+ """ + + closure_vars: dict[str, Any] + + @classmethod + def apply( + cls, node: foast.FunctionDefinition | foast.FieldOperator, closure_vars: dict[str, Any] + ) -> tuple[foast.FunctionDefinition, dict[str, Any]]: + foast_node = cls(closure_vars=closure_vars).visit(node) + new_closure_vars = closure_vars.copy() + for key, value in closure_vars.items(): + if isinstance(value, type) and key not in TYPE_BUILTIN_NAMES: + new_closure_vars[value.__name__] = closure_vars[key] + return foast_node, new_closure_vars + + def is_type_alias(self, node_id: SymbolName | SymbolRef) -> bool: + return ( + node_id in self.closure_vars + and isinstance(self.closure_vars[node_id], type) + and node_id not in TYPE_BUILTIN_NAMES + ) + + def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: + if self.is_type_alias(node.id): + return foast.Name( + id=self.closure_vars[node.id].__name__, location=node.location, type=node.type + ) + return node + + def _update_closure_var_symbols( + self, closure_vars: list[foast.Symbol], location: SourceLocation + ) -> list[foast.Symbol]: + new_closure_vars: list[foast.Symbol] = [] + existing_type_names: set[str] = set() + + for var in closure_vars: + if self.is_type_alias(var.id): + actual_type_name = self.closure_vars[var.id].__name__ + # Avoid multiple definitions of a type in closure_vars + if actual_type_name not in existing_type_names: + new_closure_vars.append( + foast.Symbol( + id=actual_type_name, + type=ts.FunctionType( + pos_or_kw_args={}, + kw_only_args={}, + pos_only_args=[ts.DeferredType(constraint=ts.ScalarType)], + returns=cast( + ts.DataType, from_type_hint(self.closure_vars[var.id]) + ), + ), + namespace=dialect_ast_enums.Namespace.CLOSURE, + location=location, + ) + ) + existing_type_names.add(actual_type_name) + elif var.id not in existing_type_names: + new_closure_vars.append(var) + existing_type_names.add(var.id) + + return new_closure_vars + + def visit_FunctionDefinition( + self, node: foast.FunctionDefinition, **kwargs + ) -> foast.FunctionDefinition: + return foast.FunctionDefinition( + id=node.id, + params=node.params, + body=self.visit(node.body, **kwargs), + closure_vars=self._update_closure_var_symbols(node.closure_vars, node.location), + location=node.location, + ) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 082939c938..c7c4c3a23f 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -33,6 +33,7 @@ from gt4py.next.ffront.foast_passes.closure_var_type_deduction import ClosureVarTypeDeduction from gt4py.next.ffront.foast_passes.dead_closure_var_elimination import DeadClosureVarElimination from gt4py.next.ffront.foast_passes.iterable_unpack import UnpackedAssignPass +from gt4py.next.ffront.foast_passes.type_alias_replacement import TypeAliasReplacement from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -91,6 +92,7 @@ def _postprocess_dialect_ast( closure_vars: dict[str, Any], annotations: dict[str, Any], ) -> foast.FunctionDefinition: + foast_node, closure_vars = TypeAliasReplacement.apply(foast_node, closure_vars) foast_node = ClosureVarFolding.apply(foast_node, closure_vars) foast_node = DeadClosureVarElimination.apply(foast_node) foast_node = ClosureVarTypeDeduction.apply(foast_node, closure_vars) diff --git a/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py 
b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py new file mode 100644 index 0000000000..e87f869352 --- /dev/null +++ b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py @@ -0,0 +1,44 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import ast +import typing +from typing import TypeAlias + +import pytest + +import gt4py.next as gtx +from gt4py.next import float32, float64 +from gt4py.next.ffront.fbuiltins import astype +from gt4py.next.ffront.func_to_foast import FieldOperatorParser + + +TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. +vpfloat: TypeAlias = float32 +wpfloat: TypeAlias = float64 + + +@pytest.mark.parametrize("test_input,expected", [(vpfloat, "float32"), (wpfloat, "float64")]) +def test_type_alias_replacement(test_input, expected): + def fieldop_with_typealias( + a: gtx.Field[[TDim], test_input], b: gtx.Field[[TDim], float32] + ) -> gtx.Field[[TDim], test_input]: + return test_input("3.1418") + astype(a, test_input) + + foast_tree = FieldOperatorParser.apply_to_function(fieldop_with_typealias) + + assert ( + foast_tree.body.stmts[0].value.left.func.id == expected + and foast_tree.body.stmts[0].value.right.args[1].id == expected + ) From 6c69398e576f5b8598f96dad617931dda32f62bf Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 16 Oct 2023 10:56:00 +0200 Subject: [PATCH 06/10] feat[next-dace]: Add support for GPU execution (#1347) This PR adds support for GPU execution in DaCe Backend. Additionally, it also introduces a build cache for each visited ITIR program and corresponding binary DaCe program. 
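A rough sketch of the caching flow (a sketch only: `compile_or_reuse` and `make_sdfg` are hypothetical stand-ins for the ITIR-to-SDFG lowering and compilation steps; the cache-id scheme mirrors the `get_cache_id` helper added below):

```python
# Hedged sketch of the per-program build cache: one compiled SDFG program per
# cache id, derived from the stringified ITIR program, argument types and
# column axis.
from typing import Any, Callable


def get_cache_id(*cache_args: Any) -> int:
    return sum(hash(str(arg)) for arg in cache_args)


_build_cache: dict[int, Any] = {}


def compile_or_reuse(cache_id: int, make_sdfg: Callable[[], Any]) -> Any:
    if cache_id in _build_cache:
        return _build_cache[cache_id]  # reuse: skip lowering and compilation
    sdfg_program = make_sdfg()  # visit ITIR, build and compile the SDFG
    _build_cache[cache_id] = sdfg_program
    return sdfg_program
```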
--- .../runners/dace_iterator/__init__.py | 101 ++++++++++++++++-- 1 file changed, 91 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index f78d90095c..25609b1035 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -16,6 +16,8 @@ import dace import numpy as np +from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.iterator.ir as itir from gt4py.next import common @@ -29,6 +31,14 @@ from .utility import connectivity_identifier, filter_neighbor_tables +""" Default build configuration in DaCe backend """ +_build_type = "Release" +# removing -ffast-math from DaCe default compiler args in order to support isfinite/isinf/isnan built-ins +_cpu_args = ( + "-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -Wno-unused-parameter -Wno-unused-label" +) + + def convert_arg(arg: Any): if common.is_field(arg): sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) @@ -85,17 +95,67 @@ def get_stride_args( return stride_args +_build_cache_cpu: dict[int, CompiledSDFG] = {} +_build_cache_gpu: dict[int, CompiledSDFG] = {} + + +def get_cache_id(*cache_args) -> int: + return sum([hash(str(arg)) for arg in cache_args]) + + @program_executor def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: + # build parameters + auto_optimize = kwargs.get("auto_optimize", False) + build_type = kwargs.get("build_type", "RelWithDebInfo") + run_on_gpu = kwargs.get("run_on_gpu", False) + build_cache = kwargs.get("build_cache", None) + # ITIR parameters column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] - neighbor_tables = filter_neighbor_tables(offset_provider) - program = preprocess_program(program, offset_provider) arg_types = [type_translation.from_value(arg) for arg in args] - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) - sdfg: dace.SDFG = sdfg_genenerator.visit(program) - sdfg.simplify() + neighbor_tables = filter_neighbor_tables(offset_provider) + + cache_id = get_cache_id(program, *arg_types, column_axis) + if build_cache is not None and cache_id in build_cache: + # retrieve SDFG program from build cache + sdfg_program = build_cache[cache_id] + sdfg = sdfg_program.sdfg + else: + # visit ITIR and generate SDFG + program = preprocess_program(program, offset_provider) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) + sdfg = sdfg_genenerator.visit(program) + sdfg.simplify() + + # set array storage for GPU execution + if run_on_gpu: + device = dace.DeviceType.GPU + sdfg._name = f"{sdfg.name}_gpu" + for _, _, array in sdfg.arrays_recursive(): + if not array.transient: + array.storage = dace.dtypes.StorageType.GPU_Global + else: + device = dace.DeviceType.CPU + + # run DaCe auto-optimization heuristics + if auto_optimize: + # TODO Investigate how symbol definitions improve autoopt transformations, + # in which case the cache table should take the symbols map into account. 
+ symbols: dict[str, int] = {} + sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols) + + # compile SDFG and retrieve SDFG program + sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "build_type", value=build_type) + dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args) + sdfg_program = sdfg.compile(validate=False) + + # store SDFG program in build cache + if build_cache is not None: + build_cache[cache_id] = sdfg_program dace_args = get_args(program.params, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} @@ -105,8 +165,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: dace_strides = get_stride_args(sdfg.arrays, dace_field_args) dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args) - sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" - all_args = { **dace_args, **dace_conn_args, @@ -120,9 +178,32 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: for key, value in all_args.items() if key in sdfg.signature_arglist(with_types=False) } + with dace.config.temporary_config(): dace.config.Config.set("compiler", "allow_view_arguments", value=True) - dace.config.Config.set("compiler", "build_type", value="Debug") - dace.config.Config.set("compiler", "cpu", "args", value="-O0") dace.config.Config.set("frontend", "check_args", value=True) - sdfg(**expected_args) + sdfg_program(**expected_args) + + +@program_executor +def run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_dace_iterator( + program, + *args, + **kwargs, + build_cache=_build_cache_cpu, + build_type=_build_type, + run_on_gpu=False, + ) + + +@program_executor +def run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_dace_iterator( + program, + *args, + **kwargs, + build_cache=_build_cache_gpu, + build_type=_build_type, + run_on_gpu=True, + ) From d07104da19d0e3b467210e9afb940214cddad79a Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 16 Oct 2023 11:07:46 +0200 Subject: [PATCH 07/10] fix[next-dace]: scan_dim consistent with canonical field domain (#1346) The DaCe backend is reordering the dimensions of field domain based on alphabetical order - we call this the canonical representation of field domain. Therefore, array strides, sizes and offsets need to be shuffled, everywhere, to be consistent with the alphabetical order of dimensions. This PR corrects indexing of field domain in get_scan_dim() which was not consistent with the canonical representation. 
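A small self-contained illustration of the fix (`Dim` is a simplified stand-in for gt4py's `Dimension`; the sorting matches the `get_sorted_dims` utility added in this patch):

```python
# The scan-dimension index must be looked up in the alphabetically sorted
# ("canonical") dimension order, not in the field type's declaration order.
from dataclasses import dataclass


@dataclass(frozen=True)
class Dim:
    value: str  # dimension name, e.g. "K"


def get_sorted_dims(dims):
    return sorted(enumerate(dims), key=lambda v: v[1].value)


dims = [Dim("K"), Dim("I"), Dim("J")]  # declaration order of the field type
column_axis = Dim("K")
sorted_dims = [dim for _, dim in get_sorted_dims(dims)]
assert sorted_dims.index(column_axis) == 2  # canonical index, used after the fix
assert dims.index(column_axis) == 0  # the previous, inconsistent lookup
```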
Additional minor edit: * rename map_domain -> map_ranges * replace dace.Memlet() with dace.Memlet.simple() --- .../runners/dace_iterator/itir_to_sdfg.py | 55 ++++++++++--------- .../runners/dace_iterator/itir_to_tasklet.py | 12 ++-- .../runners/dace_iterator/utility.py | 9 ++- 3 files changed, 40 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 56031d8555..2b4ad721b8 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -38,6 +38,7 @@ create_memlet_at, create_memlet_full, filter_neighbor_tables, + get_sorted_dims, map_nested_sdfg_symbols, unique_var_name, ) @@ -79,9 +80,10 @@ def get_scan_dim( - scan_dim_dtype: data type along the scan dimension """ output_type = cast(ts.FieldType, storage_types[output.id]) + sorted_dims = [dim for _, dim in get_sorted_dims(output_type.dims)] return ( column_axis.value, - output_type.dims.index(column_axis), + sorted_dims.index(column_axis), output_type.dtype, ) @@ -246,7 +248,7 @@ def visit_StencilClosure( ) access = closure_init_state.add_access(out_name) value = ValueExpr(access, dtype) - memlet = create_memlet_at(out_name, ("0",)) + memlet = dace.Memlet.simple(out_name, "0") closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) program_arg_syms[name] = value else: @@ -274,7 +276,7 @@ def visit_StencilClosure( transient_to_arg_name_mapping[nsdfg_output_name] = output_name # scan operator should always be the first function call in a closure if is_scan(node.stencil): - nsdfg, map_domain, scan_dim_index = self._visit_scan_stencil_closure( + nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( node, closure_sdfg.arrays, closure_domain, nsdfg_output_name ) results = [nsdfg_output_name] @@ -294,13 +296,13 @@ def visit_StencilClosure( output_name, tuple( f"i_{dim}" - if f"i_{dim}" in map_domain + if f"i_{dim}" in map_ranges else f"0:{output_descriptor.shape[scan_dim_index]}" for dim, _ in closure_domain ), ) else: - nsdfg, map_domain, results = self._visit_parallel_stencil_closure( + nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( node, closure_sdfg.arrays, closure_domain ) assert len(results) == 1 @@ -313,7 +315,7 @@ def visit_StencilClosure( transient=True, ) - output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_domain.keys())) + output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys())) input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, [output_memlet])} @@ -325,7 +327,7 @@ def visit_StencilClosure( nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( closure_state, sdfg=nsdfg, - map_ranges=map_domain or {"__dummy": "0"}, + map_ranges=map_ranges or {"__dummy": "0"}, inputs=array_mapping, outputs=output_mapping, symbol_mapping=symbol_mapping, @@ -341,10 +343,10 @@ def visit_StencilClosure( edge.src_conn, transient_access, None, - dace.Memlet(data=memlet.data, subset=output_subset), + dace.Memlet.simple(memlet.data, output_subset), ) - inner_memlet = dace.Memlet( - data=memlet.data, subset=output_subset, other_subset=memlet.subset + inner_memlet = dace.Memlet.simple( + memlet.data, output_subset, other_subset_str=memlet.subset ) closure_state.add_edge(transient_access, None, map_exit, 
edge.dst_conn, inner_memlet) closure_state.remove_edge(edge) @@ -360,7 +362,7 @@ def visit_StencilClosure( None, map_entry, b.value.data, - create_memlet_at(b.value.data, ("0",)), + dace.Memlet.simple(b.value.data, "0"), ) return closure_sdfg @@ -390,12 +392,12 @@ def _visit_scan_stencil_closure( connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_domain = {} + map_ranges = {} for dim, (lb, ub) in closure_domain: lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value if not dim == scan_dim: - map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" + map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" else: scan_lb_str = lb_str scan_ub_str = ub_str @@ -481,29 +483,28 @@ def _visit_scan_stencil_closure( "__result", carry_node1, None, - dace.Memlet(data=f"{scan_carry_name}", subset="0"), + dace.Memlet.simple(scan_carry_name, "0"), ) carry_node2 = lambda_state.add_access(scan_carry_name) lambda_state.add_memlet_path( carry_node2, scan_inner_node, - memlet=dace.Memlet(data=f"{scan_carry_name}", subset="0"), + memlet=dace.Memlet.simple(scan_carry_name, "0"), src_conn=None, dst_conn=lambda_carry_name, ) # connect access nodes to lambda inputs for (inner_name, _), data_name in zip(lambda_inputs[1:], input_names): - data_subset = ( - ", ".join([f"i_{dim}" for dim, _ in closure_domain]) - if isinstance(self.storage_types[data_name], ts.FieldType) - else "0" - ) + if isinstance(self.storage_types[data_name], ts.FieldType): + memlet = create_memlet_at(data_name, tuple(f"i_{dim}" for dim, _ in closure_domain)) + else: + memlet = dace.Memlet.simple(data_name, "0") lambda_state.add_memlet_path( lambda_state.add_access(data_name), scan_inner_node, - memlet=dace.Memlet(data=f"{data_name}", subset=data_subset), + memlet=memlet, src_conn=None, dst_conn=inner_name, ) @@ -532,7 +533,7 @@ def _visit_scan_stencil_closure( lambda_state.add_memlet_path( scan_inner_node, lambda_state.add_access(data_name), - memlet=dace.Memlet(data=data_name, subset=f"i_{scan_dim}"), + memlet=dace.Memlet.simple(data_name, f"i_{scan_dim}"), src_conn=lambda_connector.value.label, dst_conn=None, ) @@ -544,10 +545,10 @@ def _visit_scan_stencil_closure( lambda_update_state.add_memlet_path( result_node, carry_node3, - memlet=dace.Memlet(data=f"{output_names[0]}", subset=f"i_{scan_dim}", other_subset="0"), + memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), ) - return scan_sdfg, map_domain, scan_dim_index + return scan_sdfg, map_ranges, scan_dim_index def _visit_parallel_stencil_closure( self, @@ -562,11 +563,11 @@ def _visit_parallel_stencil_closure( conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_domain = {} + map_ranges = {} for dim, (lb, ub) in closure_domain: lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" + map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" # Create an SDFG for the tasklet that computes a single item of the output domain. 
index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} @@ -583,7 +584,7 @@ def _visit_parallel_stencil_closure( self.node_types, ) - return context.body, map_domain, [r.value.data for r in results] + return context.body, map_ranges, [r.value.data for r in results] def _visit_domain( self, node: itir.FunCall, context: Context diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2e7a598d9a..d3bfb5ff0e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -34,7 +34,6 @@ add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, - create_memlet_at, create_memlet_full, filter_neighbor_tables, map_nested_sdfg_symbols, @@ -595,7 +594,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ) # if dim is not found in iterator indices, we take the neighbor index over the reduction domain - array_index = [ + flat_index = [ f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name for dim in sorted(iterator.dimensions) ] @@ -608,7 +607,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: name="deref", inputs=set(internals), outputs={"__result"}, - code=f"__result = {args[0].value.data}_v[{', '.join(array_index)}]", + code=f"__result = {args[0].value.data}_v[{', '.join(flat_index)}]", ) for arg, internal in zip(args, internals): @@ -634,8 +633,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: flat_index = [ ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions ] - - args = [ValueExpr(iterator.field, int), *flat_index] + args = [ValueExpr(iterator.field, iterator.dtype), *flat_index] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") @@ -849,7 +847,7 @@ def _visit_reduce(self, node: itir.FunCall): p.apply_pass(lambda_context.body, {}) input_memlets = [ - create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args) + dace.Memlet.simple(expr.value.data, "__idx") for arg, expr in zip(node.args, args) ] output_memlet = dace.Memlet.simple(result_name, "0") @@ -928,7 +926,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = create_memlet_at(result_access.data, ("0",)) + memlet = dace.Memlet.simple(result_access.data, "0") self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 889a1ab150..7e6fe13ac7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -12,10 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any +from typing import Any, Sequence import dace +from gt4py.next import Dimension from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.type_system import type_specifications as ts @@ -49,7 +50,7 @@ def connectivity_identifier(name: str): def create_memlet_full(source_identifier: str, source_array: dace.data.Array): bounds = [(0, size) for size in 
source_array.shape] subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds) - return dace.Memlet(data=source_identifier, subset=subset) + return dace.Memlet.simple(source_identifier, subset) def create_memlet_at(source_identifier: str, index: tuple[str, ...]): @@ -57,6 +58,10 @@ def create_memlet_at(source_identifier: str, index: tuple[str, ...]): return dace.Memlet(data=source_identifier, subset=subset) +def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]: + return sorted(enumerate(dims), key=lambda v: v[1].value) + + def map_nested_sdfg_symbols( parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet] ) -> dict[str, str]: From 45a6e6d10939b14580eb9c1c243f5779ccd3dc44 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 16 Oct 2023 11:54:59 +0200 Subject: [PATCH 08/10] feat[next]: Add DaCe support for field arguments with domain offset (#1348) This PR adds support in DaCe backend for field arguments with domain offset. This feature is required by icon4py stencils. --- .../runners/dace_iterator/__init__.py | 32 +++++++++++++++---- .../runners/dace_iterator/itir_to_sdfg.py | 16 +++++++--- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 25609b1035..18e257d462 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -20,7 +20,7 @@ from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.iterator.ir as itir -from gt4py.next import common +from gt4py.next.common import Domain, UnitRange, is_field from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.otf.compilation import cache @@ -28,7 +28,12 @@ from gt4py.next.type_system import type_translation from .itir_to_sdfg import ItirToSDFG -from .utility import connectivity_identifier, filter_neighbor_tables +from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims + + +def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: + sorted_dims = get_sorted_dims(domain.dims) + return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] """ Default build configuration in DaCe backend """ @@ -40,10 +45,10 @@ def convert_arg(arg: Any): - if common.is_field(arg): - sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) + if is_field(arg): + sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) - dim_indices = [dim[0] for dim in sorted_dims] + dim_indices = [dim_index for dim_index, _ in sorted_dims] assert isinstance(arg.ndarray, np.ndarray) return np.moveaxis(arg.ndarray, range(ndim), dim_indices) return arg @@ -79,6 +84,17 @@ def get_shape_args( } +def get_offset_args( + arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] +) -> Mapping[str, int]: + return { + str(sym): -drange.start + for param, arg in zip(params, args) + if is_field(arg) + for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) + } + + def get_stride_args( arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] ) -> Mapping[str, int]: @@ -163,7 +179,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) 
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = get_stride_args(sdfg.arrays, dace_field_args) - dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args) + dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) + dace_offsets = get_offset_args(sdfg.arrays, program.params, args) all_args = { **dace_args, @@ -171,7 +188,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: **dace_shapes, **dace_conn_shapes, **dace_strides, - **dace_conn_stirdes, + **dace_conn_strides, + **dace_offsets, } expected_args = { key: value diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 2b4ad721b8..7017815688 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -107,12 +107,17 @@ def __init__( self.offset_provider = offset_provider self.storage_types = {} - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec): + def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + offset = ( + [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + if has_offset + else None + ) dtype = as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) + sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) else: @@ -136,7 +141,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): scalar_kind = type_translation.get_scalar_kind(table.table.dtype) local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL) type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind)) - self.add_storage(program_sdfg, connectivity_identifier(offset), type_) + self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False) # Create a nested SDFG for all stencil closures. for closure in node.closures: @@ -287,8 +292,8 @@ def visit_StencilClosure( closure_sdfg.add_array( nsdfg_output_name, dtype=output_descriptor.dtype, - shape=(array_table[output_name].shape[scan_dim_index],), - strides=(array_table[output_name].strides[scan_dim_index],), + shape=(output_descriptor.shape[scan_dim_index],), + strides=(output_descriptor.strides[scan_dim_index],), transient=True, ) @@ -528,6 +533,7 @@ def _visit_scan_stencil_closure( data_name, shape=(array_table[node.output.id].shape[scan_dim_index],), strides=(array_table[node.output.id].strides[scan_dim_index],), + offset=(array_table[node.output.id].offset[scan_dim_index],), dtype=array_table[node.output.id].dtype, ) lambda_state.add_memlet_path( From 90eea30a000d8b1780bb65457a73e9839e5b3e93 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 17 Oct 2023 10:11:49 +0200 Subject: [PATCH 09/10] feat[next]: DaCe support for neighbor strided offset (#1344) This PR adds support for neighbor strided offset in DaCe backend, another ITIR feature needed by icon4py stencils. The design choice has been to extract max_neighbors from offset_provider at compile-time and hard-code it in the SDFG. 
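In sketch form, a strided offset provider needs no connectivity table, since the flat neighbor index is pure arithmetic over the origin index and the compile-time `max_neighbors` (the function name here is hypothetical; the backend emits the equivalent expression into a tasklet):

```python
# Strided-neighbor addressing: neighbor lists are laid out contiguously with a
# fixed stride of max_neighbors, so no table lookup is required.
def strided_neighbor_index(origin_index: int, max_neighbors: int, element: int) -> int:
    return origin_index * max_neighbors + element


# e.g. the second neighbor (element=1) of origin cell 3, with 4 neighbors per cell
assert strided_neighbor_index(origin_index=3, max_neighbors=4, element=1) == 13
```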
Additionally, the hash function to check the SDFG binary cache is modified to use SHA256, in order to reduce collision risk. --- .../runners/dace_iterator/__init__.py | 46 ++++++++++++++----- .../runners/dace_iterator/itir_to_tasklet.py | 35 +++++++++----- tests/next_tests/exclusion_matrices.py | 12 +++-- 3 files changed, 67 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 18e257d462..1c1bed9c5e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -11,8 +11,8 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any, Mapping, Sequence +import hashlib +from typing import Any, Mapping, Optional, Sequence import dace import numpy as np @@ -20,12 +20,12 @@ from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.iterator.ir as itir -from gt4py.next.common import Domain, UnitRange, is_field -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider +from gt4py.next.common import Dimension, Domain, UnitRange, is_field +from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.otf.compilation import cache from gt4py.next.program_processors.processor_interface import program_executor -from gt4py.next.type_system import type_translation +from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims @@ -111,12 +111,34 @@ def get_stride_args( return stride_args -_build_cache_cpu: dict[int, CompiledSDFG] = {} -_build_cache_gpu: dict[int, CompiledSDFG] = {} - - -def get_cache_id(*cache_args) -> int: - return sum([hash(str(arg)) for arg in cache_args]) +_build_cache_cpu: dict[str, CompiledSDFG] = {} +_build_cache_gpu: dict[str, CompiledSDFG] = {} + + +def get_cache_id( + program: itir.FencilDefinition, + arg_types: Sequence[ts.TypeSpec], + column_axis: Optional[Dimension], + offset_provider: Mapping[str, Any], +) -> str: + max_neighbors = [ + (k, v.max_neighbors) + for k, v in offset_provider.items() + if isinstance(v, (NeighborTableOffsetProvider, StridedNeighborOffsetProvider)) + ] + cache_id_args = [ + str(arg) + for arg in ( + program, + *arg_types, + column_axis, + *max_neighbors, + ) + ] + m = hashlib.sha256() + for s in cache_id_args: + m.update(s.encode()) + return m.hexdigest() @program_executor @@ -133,7 +155,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: arg_types = [type_translation.from_value(arg) for arg in args] neighbor_tables = filter_neighbor_tables(offset_provider) - cache_id = get_cache_id(program, *arg_types, column_axis) + cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) if build_cache is not None and cache_id in build_cache: # retrieve SDFG program from build cache sdfg_program = build_cache[cache_id] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index d3bfb5ff0e..6acc39c50a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ 
b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -23,7 +23,7 @@ from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen -from gt4py.next import Dimension, type_inference as next_typing +from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.ir import FunCall, Lambda @@ -700,18 +700,31 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: element = tail[1].value assert isinstance(element, int) - table: NeighborTableOffsetProvider = self.offset_provider[offset] - shifted_dim = table.origin_axis.value - target_dim = table.neighbor_axis.value + if isinstance(self.offset_provider[offset], NeighborTableOffsetProvider): + table = self.offset_provider[offset] + shifted_dim = table.origin_axis.value + target_dim = table.neighbor_axis.value - conn = self.context.state.add_access(connectivity_identifier(offset)) + conn = self.context.state.add_access(connectivity_identifier(offset)) + + args = [ + ValueExpr(conn, table.table.dtype), + ValueExpr(iterator.indices[shifted_dim], dace.int64), + ] + + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]}[{internals[1]}, {element}]" + else: + offset_provider = self.offset_provider[offset] + assert isinstance(offset_provider, StridedNeighborOffsetProvider) + + shifted_dim = offset_provider.origin_axis.value + target_dim = offset_provider.neighbor_axis.value + offset_value = iterator.indices[shifted_dim] + args = [ValueExpr(offset_value, dace.int64)] + internals = [f"{offset_value.data}_v"] + expr = f"{internals[0]} * {offset_provider.max_neighbors} + {element}" - args = [ - ValueExpr(conn, table.table.dtype), - ValueExpr(iterator.indices[shifted_dim], dace.int64), - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{internals[1]}, {element}]" shifted_value = self.add_expr_tasklet( list(zip(args, internals)), expr, dace.dtypes.int64, "ind_addr" )[0].value diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 27ccb29095..98ac9352c3 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -61,7 +61,6 @@ (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ] #: Skip matrix, contains for each backend processor a list of tuples with following fields: @@ -81,11 +80,18 @@ (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), ], - GTFN_CPU: GTFN_SKIP_TEST_LIST, - GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST, + GTFN_CPU: GTFN_SKIP_TEST_LIST + + [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], + GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + + [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST + [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], GTFN_FORMAT_SOURCECODE: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), From 
f96ead5fbb0b7d9edfbe6c6e9c67a08ddb2ee50a Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 17 Oct 2023 14:19:33 +0200 Subject: [PATCH 10/10] fix[next]: DaCe field addressing in builtin_neighbors (#1349) Bugfix in DaCe backend to make field addressing in builtin_neighbors consistent with the canonical representation (field dimensions alphabetically sorted). --- .../runners/dace_iterator/itir_to_tasklet.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 6acc39c50a..610698646a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -243,10 +243,8 @@ def builtin_neighbors( ) # select full shape only in the neighbor-axis dimension field_subset = [ - f"0:{sdfg.arrays[iterator.field.data].shape[idx]}" - if dim == table.neighbor_axis.value - else f"i_{dim}" - for idx, dim in enumerate(iterator.dimensions) + f"0:{shape}" if dim == table.neighbor_axis.value else f"i_{dim}" + for dim, shape in zip(sorted(iterator.dimensions), sdfg.arrays[iterator.field.data].shape) ] state.add_memlet_path( iterator.field, @@ -575,6 +573,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: return iterator args: list[ValueExpr] + sorted_dims = sorted(iterator.dimensions) if self.context.reduce_limit: # we are visiting a child node of reduction, so the neighbor index can be used for indirect addressing result_name = unique_var_name() @@ -596,7 +595,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: # if dim is not found in iterator indices, we take the neighbor index over the reduction domain flat_index = [ f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name - for dim in sorted(iterator.dimensions) + for dim in sorted_dims ] args = [ValueExpr(iterator.field, iterator.dtype)] + [ ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices @@ -629,11 +628,9 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: return [ValueExpr(value=result_access, dtype=iterator.dtype)] else: - sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) - flat_index = [ - ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions + args = [ValueExpr(iterator.field, iterator.dtype)] + [ + ValueExpr(iterator.indices[dim], iterator.dtype) for dim in sorted_dims ] - args = [ValueExpr(iterator.field, iterator.dtype), *flat_index] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref")
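As a closing illustration of the canonical-addressing rule this last fix enforces, deref index expressions now follow the alphabetically sorted dimension names (a simplified sketch; the real code assembles tasklet code strings from `ValueExpr` connectors):

```python
# Sketch: indices are emitted in sorted dimension-name order, matching the
# canonical (alphabetical) array layout used throughout the DaCe backend.
def deref_expr(field: str, indices: dict[str, str]) -> str:
    sorted_dims = sorted(indices)  # canonical order
    return f"{field}[{', '.join(indices[d] for d in sorted_dims)}]"


assert deref_expr("inp_v", {"Vertex": "i_vertex", "Edge": "i_edge"}) == "inp_v[i_edge, i_vertex]"
```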