diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml
index 77ba39a361..8631390dbb 100644
--- a/.github/workflows/daily-ci.yml
+++ b/.github/workflows/daily-ci.yml
@@ -14,7 +14,7 @@ jobs:
daily-ci:
strategy:
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
tox-module-factor: ["cartesian", "eve", "next", "storage"]
os: ["ubuntu-latest"]
requirements-file: ["requirements-dev.txt", "min-requirements-test.txt", "min-extra-requirements-test.txt"]
diff --git a/.github/workflows/test-cartesian-fallback.yml b/.github/workflows/test-cartesian-fallback.yml
index b2eaead47a..7e9a948e9c 100644
--- a/.github/workflows/test-cartesian-fallback.yml
+++ b/.github/workflows/test-cartesian-fallback.yml
@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
backends: [internal-cpu, dace-cpu]
steps:
diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml
index 2c2b97aaa6..ebdc4ce749 100644
--- a/.github/workflows/test-cartesian.yml
+++ b/.github/workflows/test-cartesian.yml
@@ -23,7 +23,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
backends: [internal-cpu, dace-cpu]
steps:
- uses: actions/checkout@v2
diff --git a/.github/workflows/test-eve-fallback.yml b/.github/workflows/test-eve-fallback.yml
index 93dc308a53..fd7ab5452c 100644
--- a/.github/workflows/test-eve-fallback.yml
+++ b/.github/workflows/test-eve-fallback.yml
@@ -17,7 +17,7 @@ jobs:
test-eve:
strategy:
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
os: ["ubuntu-latest"]
runs-on: ${{ matrix.os }}
diff --git a/.github/workflows/test-eve.yml b/.github/workflows/test-eve.yml
index 1322c573db..222b825f38 100644
--- a/.github/workflows/test-eve.yml
+++ b/.github/workflows/test-eve.yml
@@ -20,7 +20,7 @@ jobs:
test-eve:
strategy:
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
os: ["ubuntu-latest"]
fail-fast: false
@@ -68,4 +68,3 @@ jobs:
# with:
# name: info-py${{ matrix.python-version }}-${{ matrix.os }}
# path: info.txt
-
diff --git a/.github/workflows/test-next-fallback.yml b/.github/workflows/test-next-fallback.yml
index 8490a3e393..bdcc061db0 100644
--- a/.github/workflows/test-next-fallback.yml
+++ b/.github/workflows/test-next-fallback.yml
@@ -15,7 +15,7 @@ jobs:
test-next:
strategy:
matrix:
- python-version: ["3.10"]
+ python-version: ["3.10", "3.11"]
tox-env-factor: ["nomesh", "atlas"]
os: ["ubuntu-latest"]
diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml
index 52f8c25386..4282a22da6 100644
--- a/.github/workflows/test-next.yml
+++ b/.github/workflows/test-next.yml
@@ -18,7 +18,7 @@ jobs:
test-next:
strategy:
matrix:
- python-version: ["3.10"]
+ python-version: ["3.10", "3.11"]
tox-env-factor: ["nomesh", "atlas"]
os: ["ubuntu-latest"]
fail-fast: false
diff --git a/.github/workflows/test-storage-fallback.yml b/.github/workflows/test-storage-fallback.yml
index 0cbc735564..99e4923de8 100644
--- a/.github/workflows/test-storage-fallback.yml
+++ b/.github/workflows/test-storage-fallback.yml
@@ -18,7 +18,7 @@ jobs:
test-storage:
strategy:
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
backends: [internal-cpu, dace-cpu]
os: ["ubuntu-latest"]
diff --git a/.github/workflows/test-storage.yml b/.github/workflows/test-storage.yml
index 1133353f30..34841ed71c 100644
--- a/.github/workflows/test-storage.yml
+++ b/.github/workflows/test-storage.yml
@@ -21,7 +21,7 @@ jobs:
test-storage:
strategy:
matrix:
- python-version: ["3.8", "3.9", "3.10"]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
backends: [internal-cpu, dace-cpu]
os: ["ubuntu-latest"]
fail-fast: false
@@ -70,4 +70,3 @@ jobs:
# with:
# name: info-py${{ matrix.python-version }}-${{ matrix.os }}
# path: info.txt
-
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b1092fafd0..d9cfa0ff48 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -62,7 +62,7 @@ repos:
## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1]
## print(f"rev: '{version}' # version from constraints.txt")
##]]]
- rev: '23.9.1' # version from constraints.txt
+ rev: '23.11.0' # version from constraints.txt
##[[[end]]]
hooks:
- id: black
@@ -73,7 +73,7 @@ repos:
## version = re.search('isort==([0-9\.]*)', open("constraints.txt").read())[1]
## print(f"rev: '{version}' # version from constraints.txt")
##]]]
- rev: '5.12.0' # version from constraints.txt
+ rev: '5.13.0' # version from constraints.txt
##[[[end]]]
hooks:
- id: isort
@@ -97,14 +97,14 @@ repos:
## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1]))
##]]]
- darglint==1.8.1
- - flake8-bugbear==23.9.16
- - flake8-builtins==2.1.0
+ - flake8-bugbear==23.12.2
+ - flake8-builtins==2.2.0
- flake8-debugger==4.1.2
- flake8-docstrings==1.7.0
- flake8-eradicate==1.5.0
- flake8-mutable==1.2.0
- flake8-pyproject==1.2.3
- - pygments==2.16.1
+ - pygments==2.17.2
##[[[end]]]
# - flake8-rst-docstrings # Disabled for now due to random false positives
exclude: |
@@ -146,9 +146,9 @@ repos:
## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1]
## print(f"#========= FROM constraints.txt: v{version} =========")
##]]]
- #========= FROM constraints.txt: v1.5.1 =========
+ #========= FROM constraints.txt: v1.7.1 =========
##[[[end]]]
- rev: v1.5.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date)
+ rev: v1.7.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date)
hooks:
- id: mypy
additional_dependencies: # versions from constraints.txt
@@ -162,26 +162,26 @@ repos:
##]]]
- astunparse==1.6.3
- attrs==23.1.0
- - black==23.9.1
- - boltons==23.0.0
+ - black==23.11.0
+ - boltons==23.1.1
- cached-property==1.5.2
- click==8.1.7
- - cmake==3.27.5
+ - cmake==3.27.9
- cytoolz==0.12.2
- - deepdiff==6.5.0
+ - deepdiff==6.7.1
- devtools==0.12.2
- - frozendict==2.3.8
+ - frozendict==2.3.10
- gridtools-cpp==2.3.1
- - importlib-resources==6.0.1
+ - importlib-resources==6.1.1
- jinja2==3.1.2
- - lark==1.1.7
- - mako==1.2.4
- - nanobind==1.5.2
- - ninja==1.11.1
+ - lark==1.1.8
+ - mako==1.3.0
+ - nanobind==1.8.0
+ - ninja==1.11.1.1
- numpy==1.24.4
- - packaging==23.1
+ - packaging==23.2
- pybind11==2.11.1
- - setuptools==68.2.2
+ - setuptools==69.0.2
- tabulate==0.9.0
- typing-extensions==4.5.0
- xxhash==3.0.0
diff --git a/AUTHORS.md b/AUTHORS.md
index 89aafb9971..6c76e5759e 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -9,6 +9,7 @@
- Deconinck, Florian. SSAI/NASA-GSFC
- Ehrengruber, Till. ETH Zurich - CSCS
- Elbert, Oliver D. NOAA-GFDL
+- Faghih-Naini, Sara. ECMWF
- Farabullini, Nicoletta. ETH Zurich - C2SM
- George, Rhea. Allen Institute for AI
- González Paredes, Enrique. ETH Zurich - CSCS
diff --git a/constraints.txt b/constraints.txt
index b334851af1..81abd64c6e 100644
--- a/constraints.txt
+++ b/constraints.txt
@@ -6,124 +6,136 @@
#
aenum==3.1.15 # via dace
alabaster==0.7.13 # via sphinx
-asttokens==2.4.0 # via devtools
+asttokens==2.4.1 # via devtools
astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml)
attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing
-babel==2.12.1 # via sphinx
-black==23.9.1 # via gt4py (pyproject.toml)
-blinker==1.6.2 # via flask
-boltons==23.0.0 # via gt4py (pyproject.toml)
+babel==2.13.1 # via sphinx
+black==23.11.0 # via gt4py (pyproject.toml)
+blinker==1.7.0 # via flask
+boltons==23.1.1 # via gt4py (pyproject.toml)
build==1.0.3 # via pip-tools
cached-property==1.5.2 # via gt4py (pyproject.toml)
-cachetools==5.3.1 # via tox
-certifi==2023.7.22 # via requests
-cffi==1.15.1 # via cryptography
+cachetools==5.3.2 # via tox
+cerberus==1.3.5 # via plette
+certifi==2023.11.17 # via requests
+cffi==1.16.0 # via cryptography
cfgv==3.4.0 # via pre-commit
chardet==5.2.0 # via tox
-charset-normalizer==3.2.0 # via requests
-clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml)
+charset-normalizer==3.3.2 # via requests
+clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml)
click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools
-cmake==3.27.5 # via gt4py (pyproject.toml)
+cmake==3.27.9 # via dace, gt4py (pyproject.toml)
cogapp==3.3.0 # via -r requirements-dev.in
colorama==0.4.6 # via tox
-coverage==7.3.1 # via -r requirements-dev.in, pytest-cov
-cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis
+coverage==7.3.2 # via -r requirements-dev.in, pytest-cov
+cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis
cytoolz==0.12.2 # via gt4py (pyproject.toml)
-dace==0.14.4 # via gt4py (pyproject.toml)
+dace==0.15.1 # via gt4py (pyproject.toml)
darglint==1.8.1 # via -r requirements-dev.in
-deepdiff==6.5.0 # via gt4py (pyproject.toml)
+deepdiff==6.7.1 # via gt4py (pyproject.toml)
devtools==0.12.2 # via gt4py (pyproject.toml)
dill==0.3.7 # via dace
-distlib==0.3.7 # via virtualenv
-docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme
+distlib==0.3.7 # via requirementslib, virtualenv
+distro==1.8.0 # via scikit-build
+docopt==0.6.2 # via pipreqs
+docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme
eradicate==2.3.0 # via flake8-eradicate
-exceptiongroup==1.1.3 # via hypothesis, pytest
+exceptiongroup==1.2.0 # via hypothesis, pytest
execnet==2.0.2 # via pytest-cache, pytest-xdist
-executing==1.2.0 # via devtools
+executing==2.0.1 # via devtools
factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy
-faker==19.6.1 # via factory-boy
-fastjsonschema==2.18.0 # via nbformat
-filelock==3.12.4 # via tox, virtualenv
+faker==20.1.0 # via factory-boy
+fastjsonschema==2.19.0 # via nbformat
+filelock==3.13.1 # via tox, virtualenv
flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings
-flake8-bugbear==23.9.16 # via -r requirements-dev.in
-flake8-builtins==2.1.0 # via -r requirements-dev.in
+flake8-bugbear==23.12.2 # via -r requirements-dev.in
+flake8-builtins==2.2.0 # via -r requirements-dev.in
flake8-debugger==4.1.2 # via -r requirements-dev.in
flake8-docstrings==1.7.0 # via -r requirements-dev.in
flake8-eradicate==1.5.0 # via -r requirements-dev.in
flake8-mutable==1.2.0 # via -r requirements-dev.in
flake8-pyproject==1.2.3 # via -r requirements-dev.in
flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in
-flask==2.3.3 # via dace
-frozendict==2.3.8 # via gt4py (pyproject.toml)
+flask==3.0.0 # via dace
+fparser==0.1.3 # via dace
+frozendict==2.3.10 # via gt4py (pyproject.toml)
gridtools-cpp==2.3.1 # via gt4py (pyproject.toml)
-hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml)
-identify==2.5.29 # via pre-commit
-idna==3.4 # via requests
+hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml)
+identify==2.5.33 # via pre-commit
+idna==3.6 # via requests
imagesize==1.4.1 # via sphinx
-importlib-metadata==6.8.0 # via build, flask, sphinx
-importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications
+importlib-metadata==7.0.0 # via build, flask, fparser, sphinx
+importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications
inflection==0.5.1 # via pytest-factoryboy
iniconfig==2.0.0 # via pytest
-isort==5.12.0 # via -r requirements-dev.in
+isort==5.13.0 # via -r requirements-dev.in
itsdangerous==2.1.2 # via flask
jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx
-jsonschema==4.19.0 # via nbformat
-jsonschema-specifications==2023.7.1 # via jsonschema
-jupyter-core==5.3.1 # via nbformat
-jupytext==1.15.2 # via -r requirements-dev.in
-lark==1.1.7 # via gt4py (pyproject.toml)
-mako==1.2.4 # via gt4py (pyproject.toml)
+jsonschema==4.20.0 # via nbformat
+jsonschema-specifications==2023.11.2 # via jsonschema
+jupyter-core==5.5.0 # via nbformat
+jupytext==1.16.0 # via -r requirements-dev.in
+lark==1.1.8 # via gt4py (pyproject.toml)
+mako==1.3.0 # via gt4py (pyproject.toml)
markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins
markupsafe==2.1.3 # via jinja2, mako, werkzeug
mccabe==0.7.0 # via flake8
mdit-py-plugins==0.4.0 # via jupytext
mdurl==0.1.2 # via markdown-it-py
mpmath==1.3.0 # via sympy
-mypy==1.5.1 # via -r requirements-dev.in
+mypy==1.7.1 # via -r requirements-dev.in
mypy-extensions==1.0.0 # via black, mypy
-nanobind==1.5.2 # via gt4py (pyproject.toml)
+nanobind==1.8.0 # via gt4py (pyproject.toml)
nbformat==5.9.2 # via jupytext
networkx==3.1 # via dace
-ninja==1.11.1 # via gt4py (pyproject.toml)
+ninja==1.11.1.1 # via gt4py (pyproject.toml)
nodeenv==1.8.0 # via pre-commit
numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client
ordered-set==4.1.0 # via deepdiff
-packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox
-pathspec==0.11.2 # via black
+packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox
+pathspec==0.12.1 # via black
+pep517==0.13.1 # via requirementslib
+pip-api==0.0.30 # via isort
pip-tools==7.3.0 # via -r requirements-dev.in
-pipdeptree==2.13.0 # via -r requirements-dev.in
+pipdeptree==2.13.1 # via -r requirements-dev.in
+pipreqs==0.4.13 # via isort
pkgutil-resolve-name==1.3.10 # via jsonschema
-platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv
+platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv
+plette==0.4.4 # via requirementslib
pluggy==1.3.0 # via pytest, tox
ply==3.11 # via dace
-pre-commit==3.4.0 # via -r requirements-dev.in
-psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist
+pre-commit==3.5.0 # via -r requirements-dev.in
+psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist
pybind11==2.11.1 # via gt4py (pyproject.toml)
-pycodestyle==2.11.0 # via flake8, flake8-debugger
+pycodestyle==2.11.1 # via flake8, flake8-debugger
pycparser==2.21 # via cffi
+pydantic==1.10.13 # via requirementslib
pydocstyle==6.3.0 # via flake8-docstrings
pyflakes==3.1.0 # via flake8
-pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx
+pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx
pyproject-api==1.6.1 # via tox
pyproject-hooks==1.0.0 # via build
-pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist
+pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist
pytest-cache==1.0 # via -r requirements-dev.in
pytest-cov==4.1.0 # via -r requirements-dev.in
-pytest-factoryboy==2.5.1 # via -r requirements-dev.in
-pytest-xdist==3.3.1 # via -r requirements-dev.in
+pytest-factoryboy==2.6.0 # via -r requirements-dev.in
+pytest-xdist==3.5.0 # via -r requirements-dev.in
python-dateutil==2.8.2 # via faker
pytz==2023.3.post1 # via babel
pyyaml==6.0.1 # via dace, jupytext, pre-commit
-referencing==0.30.2 # via jsonschema, jsonschema-specifications
-requests==2.31.0 # via dace, sphinx
+referencing==0.32.0 # via jsonschema, jsonschema-specifications
+requests==2.31.0 # via dace, requirementslib, sphinx, yarg
+requirementslib==3.0.0 # via isort
restructuredtext-lint==1.4.0 # via flake8-rst-docstrings
-rpds-py==0.10.3 # via jsonschema, referencing
-ruff==0.0.290 # via -r requirements-dev.in
+rpds-py==0.13.2 # via jsonschema, referencing
+ruff==0.1.7 # via -r requirements-dev.in
+scikit-build==0.17.6 # via dace
+setuptools-scm==8.0.4 # via fparser
six==1.16.0 # via asttokens, astunparse, python-dateutil
snowballstemmer==2.2.0 # via pydocstyle, sphinx
sortedcontainers==2.4.0 # via hypothesis
sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery
-sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in
+sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in
sphinxcontrib-applehelp==1.0.4 # via sphinx
sphinxcontrib-devhelp==1.0.2 # via sphinx
sphinxcontrib-htmlhelp==2.0.1 # via sphinx
@@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme
sphinxcontrib-jsmath==1.0.1 # via sphinx
sphinxcontrib-qthelp==1.0.3 # via sphinx
sphinxcontrib-serializinghtml==1.1.5 # via sphinx
-sympy==1.12 # via dace, gt4py (pyproject.toml)
+sympy==1.9 # via dace, gt4py (pyproject.toml)
tabulate==0.9.0 # via gt4py (pyproject.toml)
toml==0.10.2 # via jupytext
-tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox
+tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox
+tomlkit==0.12.3 # via plette, requirementslib
toolz==0.12.0 # via cytoolz
-tox==4.11.3 # via -r requirements-dev.in
-traitlets==5.10.0 # via jupyter-core, nbformat
+tox==4.11.4 # via -r requirements-dev.in
+traitlets==5.14.0 # via jupyter-core, nbformat
types-aiofiles==23.2.0.0 # via types-all
types-all==1.0.0 # via -r requirements-dev.in
types-annoy==1.17.8.4 # via types-all
types-atomicwrites==1.4.5.1 # via types-all
types-backports==0.1.3 # via types-all
types-backports-abc==0.5.2 # via types-all
-types-bleach==6.0.0.4 # via types-all
+types-bleach==6.1.0.1 # via types-all
types-boto==2.49.18.9 # via types-all
-types-cachetools==5.3.0.6 # via types-all
+types-cachetools==5.3.0.7 # via types-all
types-certifi==2021.10.8.3 # via types-all
-types-cffi==1.15.1.15 # via types-jack-client
+types-cffi==1.16.0.0 # via types-jack-client
types-characteristic==14.3.7 # via types-all
types-chardet==5.0.4.6 # via types-all
types-click==7.1.8 # via types-all, types-flask
-types-click-spinner==0.1.13.5 # via types-all
+types-click-spinner==0.1.13.6 # via types-all
types-colorama==0.4.15.12 # via types-all
types-contextvars==2.4.7.3 # via types-all
-types-croniter==1.4.0.1 # via types-all
+types-croniter==2.0.0.0 # via types-all
types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt
types-dataclasses==0.6.6 # via types-all
types-dateparser==1.1.4.10 # via types-all
@@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all
types-geoip2==3.0.0 # via types-all
types-ipaddress==1.0.8 # via types-all, types-maxminddb
types-itsdangerous==1.1.6 # via types-all
-types-jack-client==0.5.10.9 # via types-all
+types-jack-client==0.5.10.10 # via types-all
types-jinja2==2.11.9 # via types-all, types-flask
types-kazoo==0.1.3 # via types-all
-types-markdown==3.4.2.10 # via types-all
+types-markdown==3.5.0.3 # via types-all
types-markupsafe==1.1.10 # via types-all, types-jinja2
types-maxminddb==1.5.0 # via types-all, types-geoip2
-types-mock==5.1.0.2 # via types-all
+types-mock==5.1.0.3 # via types-all
types-mypy-extensions==1.0.0.5 # via types-all
types-nmap==0.1.6 # via types-all
types-openssl-python==0.1.3 # via types-all
types-orjson==3.6.2 # via types-all
-types-paramiko==3.3.0.0 # via types-all, types-pysftp
+types-paramiko==3.3.0.2 # via types-all, types-pysftp
types-pathlib2==2.3.0 # via types-all
-types-pillow==10.0.0.3 # via types-all
+types-pillow==10.1.0.2 # via types-all
types-pkg-resources==0.1.3 # via types-all
types-polib==1.2.0.1 # via types-all
-types-protobuf==4.24.0.1 # via types-all
+types-protobuf==4.24.0.4 # via types-all
types-pyaudio==0.2.16.7 # via types-all
types-pycurl==7.45.2.5 # via types-all
types-pyfarmhash==0.3.1.2 # via types-all
types-pyjwt==1.7.1 # via types-all
types-pymssql==2.1.0 # via types-all
types-pymysql==1.1.0.1 # via types-all
-types-pyopenssl==23.2.0.2 # via types-redis
+types-pyopenssl==23.3.0.0 # via types-redis
types-pyrfc3339==1.1.1.5 # via types-all
types-pysftp==0.2.17.6 # via types-all
types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange
types-python-gflags==3.1.7.3 # via types-all
types-python-slugify==8.0.0.3 # via types-all
-types-pytz==2023.3.1.0 # via types-all, types-tzlocal
+types-pytz==2023.3.1.1 # via types-all, types-tzlocal
types-pyvmomi==8.0.0.6 # via types-all
-types-pyyaml==6.0.12.11 # via types-all
-types-redis==4.6.0.6 # via types-all
-types-requests==2.31.0.2 # via types-all
+types-pyyaml==6.0.12.12 # via types-all
+types-redis==4.6.0.11 # via types-all
+types-requests==2.31.0.10 # via types-all
types-retry==0.9.9.4 # via types-all
types-routes==2.5.0 # via types-all
types-scribe==2.0.0 # via types-all
-types-setuptools==68.2.0.0 # via types-cffi
+types-setuptools==69.0.0.0 # via types-cffi
types-simplejson==3.19.0.2 # via types-all
types-singledispatch==4.1.0.0 # via types-all
types-six==1.16.21.9 # via types-all
@@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all
types-toml==0.10.8.7 # via types-all
types-tornado==5.1.1 # via types-all
types-typed-ast==1.5.8.7 # via types-all
-types-tzlocal==5.0.1.1 # via types-all
+types-tzlocal==5.1.0.1 # via types-all
types-ujson==5.8.0.1 # via types-all
-types-urllib3==1.26.25.14 # via types-requests
types-waitress==2.1.4.9 # via types-all
types-werkzeug==1.0.9 # via types-all, types-flask
types-xxhash==3.0.5.2 # via types-all
-typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy
-urllib3==2.0.4 # via requests
-virtualenv==20.24.5 # via pre-commit, tox
-websockets==11.0.3 # via dace
-werkzeug==2.3.7 # via flask
-wheel==0.41.2 # via astunparse, pip-tools
+typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm
+urllib3==2.1.0 # via requests, types-requests
+virtualenv==20.25.0 # via pre-commit, tox
+websockets==12.0 # via dace
+werkzeug==3.0.1 # via flask
+wheel==0.42.0 # via astunparse, pip-tools, scikit-build
xxhash==3.0.0 # via gt4py (pyproject.toml)
-zipp==3.16.2 # via importlib-metadata, importlib-resources
+yarg==0.1.9 # via pipreqs
+zipp==3.17.0 # via importlib-metadata, importlib-resources
# The following packages are considered to be unsafe in a requirements file:
-pip==23.2.1 # via pip-tools
-setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools
+pip==23.3.1 # via pip-api, pip-tools, requirementslib
+setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm
diff --git a/examples/lap_cartesian_vs_next.ipynb b/examples/lap_cartesian_vs_next.ipynb
new file mode 100644
index 0000000000..cb80122570
--- /dev/null
+++ b/examples/lap_cartesian_vs_next.ipynb
@@ -0,0 +1,189 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "GT4Py - GridTools for Python\n",
+ "\n",
+ "Copyright (c) 2014-2023, ETH Zurich\n",
+ "All rights reserved.\n",
+ "\n",
+ "This file is part the GT4Py project and the GridTools framework.\n",
+ "GT4Py is free software: you can redistribute it and/or modify it under\n",
+ "the terms of the GNU General Public License as published by the\n",
+ "Free Software Foundation, either version 3 of the License, or any later\n",
+ "version. See the LICENSE.txt file at the top-level directory of this\n",
+ "distribution for a copy of the license or check .\n",
+ "\n",
+ "SPDX-License-Identifier: GPL-3.0-or-later"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Demonstrates gt4py.cartesian with gt4py.next compatibility"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "nx = 32\n",
+ "ny = 32\n",
+ "nz = 1\n",
+ "dtype = np.float64"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Storages\n",
+ "--\n",
+ "\n",
+ "We create fields using the gt4py.next constructors. These fields are compatible with gt4py.cartesian when we use \"I\", \"J\", \"K\" as the dimension names."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import gt4py.next as gtx\n",
+ "\n",
+ "allocator = gtx.itir_embedded # should match the executor\n",
+ "# allocator = gtx.gtfn_cpu\n",
+ "# allocator = gtx.gtfn_gpu\n",
+ "\n",
+ "# Note: for gt4py.next, names don't matter, for gt4py.cartesian they have to be \"I\", \"J\", \"K\"\n",
+ "I = gtx.Dimension(\"I\")\n",
+ "J = gtx.Dimension(\"J\")\n",
+ "K = gtx.Dimension(\"K\", kind=gtx.DimensionKind.VERTICAL)\n",
+ "\n",
+ "domain = gtx.domain({I: nx, J: ny, K: nz})\n",
+ "\n",
+ "inp = gtx.as_field(domain, np.fromfunction(lambda x, y, z: x**2+y**2, shape=(nx, ny, nz)), dtype, allocator=allocator)\n",
+ "out_cartesian = gtx.zeros(domain, dtype, allocator=allocator)\n",
+ "out_next = gtx.zeros(domain, dtype, allocator=allocator)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "gt4py.cartesian\n",
+ "--"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import gt4py.cartesian.gtscript as gtscript\n",
+ "\n",
+ "cartesian_backend = \"numpy\"\n",
+ "# cartesian_backend = \"gt:cpu_ifirst\"\n",
+ "# cartesian_backend = \"gt:gpu\"\n",
+ "\n",
+ "@gtscript.stencil(backend=cartesian_backend)\n",
+ "def lap_cartesian(\n",
+ " inp: gtscript.Field[dtype],\n",
+ " out: gtscript.Field[dtype],\n",
+ "):\n",
+ " with computation(PARALLEL), interval(...):\n",
+ " out = -4.0 * inp[0, 0, 0] + inp[-1, 0, 0] + inp[1, 0, 0] + inp[0, -1, 0] + inp[0, 1, 0]\n",
+ "\n",
+ "lap_cartesian(inp=inp, out=out_cartesian, origin=(1, 1, 0), domain=(nx-2, ny-2, nz))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from gt4py.next import Field\n",
+ "\n",
+ "next_backend = gtx.itir_embedded\n",
+ "# next_backend = gtx.gtfn_cpu\n",
+ "# next_backend = gtx.gtfn_gpu\n",
+ "\n",
+ "Ioff = gtx.FieldOffset(\"I\", source=I, target=(I,))\n",
+ "Joff = gtx.FieldOffset(\"J\", source=J, target=(J,))\n",
+ "\n",
+ "@gtx.field_operator\n",
+ "def lap_next(inp: Field[[I, J, K], dtype]) -> Field[[I, J, K], dtype]:\n",
+ " return -4.0 * inp + inp(Ioff[-1]) + inp(Ioff[1]) + inp(Joff[-1]) + inp(Joff[1])\n",
+ "\n",
+ "@gtx.program(backend=next_backend)\n",
+ "def lap_next_program(inp: Field[[I, J, K], dtype], out: Field[[I, J, K], dtype]):\n",
+ " lap_next(inp, out=out[1:-1, 1:-1, :])\n",
+ "\n",
+ "lap_next_program(inp, out_next, offset_provider={\"Ioff\": I, \"Joff\": J})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "assert np.allclose(out_cartesian.asnumpy(), out_next.asnumpy())"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt
index 17709206a0..fd7724bac9 100644
--- a/min-extra-requirements-test.txt
+++ b/min-extra-requirements-test.txt
@@ -25,7 +25,7 @@ cmake==3.22
cogapp==3.3
coverage[toml]==5.0
cytoolz==0.12.0
-dace==0.14.2
+dace==0.15.1
darglint==1.6
deepdiff==5.6.0
devtools==0.6
@@ -70,7 +70,7 @@ scipy==1.7.2
setuptools==65.5.0
sphinx==4.4
sphinx_rtd_theme==1.0
-sympy==1.7
+sympy==1.9
tabulate==0.8.10
tomli==2.0.1
tox==3.2.0
diff --git a/pyproject.toml b/pyproject.toml
index 5d7a2f2cb6..675bdae9d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,15 +69,15 @@ requires-python = '>=3.8'
cuda = ['cupy>=12.0']
cuda11x = ['cupy-cuda11x>=12.0']
cuda12x = ['cupy-cuda12x>=12.0']
-dace = ['dace>=0.14.2,<0.15', 'sympy>=1.7']
+dace = ['dace>=0.15.1,<0.16', 'sympy>=1.9']
formatting = ['clang-format>=9.0']
# Always add all extra packages to 'full' for a simple full gt4py installation
full = [
'clang-format>=9.0',
- 'dace>=0.14.2,<0.15',
+ 'dace>=0.15.1,<0.16',
'hypothesis>=6.0.0',
'pytest>=7.0',
- 'sympy>=1.7',
+ 'sympy>=1.9',
'scipy>=1.7.2',
'jax[cpu]>=0.4.13'
]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index d6dcc12d21..0fa523866f 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -6,124 +6,136 @@
#
aenum==3.1.15 # via dace
alabaster==0.7.13 # via sphinx
-asttokens==2.4.0 # via devtools
+asttokens==2.4.1 # via devtools
astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml)
attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing
-babel==2.12.1 # via sphinx
-black==23.9.1 # via gt4py (pyproject.toml)
-blinker==1.6.2 # via flask
-boltons==23.0.0 # via gt4py (pyproject.toml)
+babel==2.13.1 # via sphinx
+black==23.11.0 # via gt4py (pyproject.toml)
+blinker==1.7.0 # via flask
+boltons==23.1.1 # via gt4py (pyproject.toml)
build==1.0.3 # via pip-tools
cached-property==1.5.2 # via gt4py (pyproject.toml)
-cachetools==5.3.1 # via tox
-certifi==2023.7.22 # via requests
-cffi==1.15.1 # via cryptography
+cachetools==5.3.2 # via tox
+cerberus==1.3.5 # via plette
+certifi==2023.11.17 # via requests
+cffi==1.16.0 # via cryptography
cfgv==3.4.0 # via pre-commit
chardet==5.2.0 # via tox
-charset-normalizer==3.2.0 # via requests
-clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml)
+charset-normalizer==3.3.2 # via requests
+clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml)
click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools
-cmake==3.27.5 # via gt4py (pyproject.toml)
+cmake==3.27.9 # via dace, gt4py (pyproject.toml)
cogapp==3.3.0 # via -r requirements-dev.in
colorama==0.4.6 # via tox
-coverage[toml]==7.3.1 # via -r requirements-dev.in, pytest-cov
-cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis
+coverage[toml]==7.3.2 # via -r requirements-dev.in, pytest-cov
+cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis
cytoolz==0.12.2 # via gt4py (pyproject.toml)
-dace==0.14.4 # via gt4py (pyproject.toml)
+dace==0.15.1 # via gt4py (pyproject.toml)
darglint==1.8.1 # via -r requirements-dev.in
-deepdiff==6.5.0 # via gt4py (pyproject.toml)
+deepdiff==6.7.1 # via gt4py (pyproject.toml)
devtools==0.12.2 # via gt4py (pyproject.toml)
dill==0.3.7 # via dace
-distlib==0.3.7 # via virtualenv
-docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme
+distlib==0.3.7 # via requirementslib, virtualenv
+distro==1.8.0 # via scikit-build
+docopt==0.6.2 # via pipreqs
+docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme
eradicate==2.3.0 # via flake8-eradicate
-exceptiongroup==1.1.3 # via hypothesis, pytest
+exceptiongroup==1.2.0 # via hypothesis, pytest
execnet==2.0.2 # via pytest-cache, pytest-xdist
-executing==1.2.0 # via devtools
+executing==2.0.1 # via devtools
factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy
-faker==19.6.1 # via factory-boy
-fastjsonschema==2.18.0 # via nbformat
-filelock==3.12.4 # via tox, virtualenv
+faker==20.1.0 # via factory-boy
+fastjsonschema==2.19.0 # via nbformat
+filelock==3.13.1 # via tox, virtualenv
flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings
-flake8-bugbear==23.9.16 # via -r requirements-dev.in
-flake8-builtins==2.1.0 # via -r requirements-dev.in
+flake8-bugbear==23.12.2 # via -r requirements-dev.in
+flake8-builtins==2.2.0 # via -r requirements-dev.in
flake8-debugger==4.1.2 # via -r requirements-dev.in
flake8-docstrings==1.7.0 # via -r requirements-dev.in
flake8-eradicate==1.5.0 # via -r requirements-dev.in
flake8-mutable==1.2.0 # via -r requirements-dev.in
flake8-pyproject==1.2.3 # via -r requirements-dev.in
flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in
-flask==2.3.3 # via dace
-frozendict==2.3.8 # via gt4py (pyproject.toml)
+flask==3.0.0 # via dace
+fparser==0.1.3 # via dace
+frozendict==2.3.10 # via gt4py (pyproject.toml)
gridtools-cpp==2.3.1 # via gt4py (pyproject.toml)
-hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml)
-identify==2.5.29 # via pre-commit
-idna==3.4 # via requests
+hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml)
+identify==2.5.33 # via pre-commit
+idna==3.6 # via requests
imagesize==1.4.1 # via sphinx
-importlib-metadata==6.8.0 # via build, flask, sphinx
-importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications
+importlib-metadata==7.0.0 # via build, flask, fparser, sphinx
+importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications
inflection==0.5.1 # via pytest-factoryboy
iniconfig==2.0.0 # via pytest
-isort==5.12.0 # via -r requirements-dev.in
+isort==5.13.0 # via -r requirements-dev.in
itsdangerous==2.1.2 # via flask
jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx
-jsonschema==4.19.0 # via nbformat
-jsonschema-specifications==2023.7.1 # via jsonschema
-jupyter-core==5.3.1 # via nbformat
-jupytext==1.15.2 # via -r requirements-dev.in
-lark==1.1.7 # via gt4py (pyproject.toml)
-mako==1.2.4 # via gt4py (pyproject.toml)
+jsonschema==4.20.0 # via nbformat
+jsonschema-specifications==2023.11.2 # via jsonschema
+jupyter-core==5.5.0 # via nbformat
+jupytext==1.16.0 # via -r requirements-dev.in
+lark==1.1.8 # via gt4py (pyproject.toml)
+mako==1.3.0 # via gt4py (pyproject.toml)
markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins
markupsafe==2.1.3 # via jinja2, mako, werkzeug
mccabe==0.7.0 # via flake8
mdit-py-plugins==0.4.0 # via jupytext
mdurl==0.1.2 # via markdown-it-py
mpmath==1.3.0 # via sympy
-mypy==1.5.1 # via -r requirements-dev.in
+mypy==1.7.1 # via -r requirements-dev.in
mypy-extensions==1.0.0 # via black, mypy
-nanobind==1.5.2 # via gt4py (pyproject.toml)
+nanobind==1.8.0 # via gt4py (pyproject.toml)
nbformat==5.9.2 # via jupytext
networkx==3.1 # via dace
-ninja==1.11.1 # via gt4py (pyproject.toml)
+ninja==1.11.1.1 # via gt4py (pyproject.toml)
nodeenv==1.8.0 # via pre-commit
numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client
ordered-set==4.1.0 # via deepdiff
-packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox
-pathspec==0.11.2 # via black
+packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox
+pathspec==0.12.1 # via black
+pep517==0.13.1 # via requirementslib
+pip-api==0.0.30 # via isort
pip-tools==7.3.0 # via -r requirements-dev.in
-pipdeptree==2.13.0 # via -r requirements-dev.in
+pipdeptree==2.13.1 # via -r requirements-dev.in
+pipreqs==0.4.13 # via isort
pkgutil-resolve-name==1.3.10 # via jsonschema
-platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv
+platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv
+plette[validation]==0.4.4 # via requirementslib
pluggy==1.3.0 # via pytest, tox
ply==3.11 # via dace
-pre-commit==3.4.0 # via -r requirements-dev.in
-psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist
+pre-commit==3.5.0 # via -r requirements-dev.in
+psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist
pybind11==2.11.1 # via gt4py (pyproject.toml)
-pycodestyle==2.11.0 # via flake8, flake8-debugger
+pycodestyle==2.11.1 # via flake8, flake8-debugger
pycparser==2.21 # via cffi
+pydantic==1.10.13 # via requirementslib
pydocstyle==6.3.0 # via flake8-docstrings
pyflakes==3.1.0 # via flake8
-pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx
+pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx
pyproject-api==1.6.1 # via tox
pyproject-hooks==1.0.0 # via build
-pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist
+pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist
pytest-cache==1.0 # via -r requirements-dev.in
pytest-cov==4.1.0 # via -r requirements-dev.in
-pytest-factoryboy==2.5.1 # via -r requirements-dev.in
-pytest-xdist[psutil]==3.3.1 # via -r requirements-dev.in
+pytest-factoryboy==2.6.0 # via -r requirements-dev.in
+pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in
python-dateutil==2.8.2 # via faker
pytz==2023.3.post1 # via babel
pyyaml==6.0.1 # via dace, jupytext, pre-commit
-referencing==0.30.2 # via jsonschema, jsonschema-specifications
-requests==2.31.0 # via dace, sphinx
+referencing==0.32.0 # via jsonschema, jsonschema-specifications
+requests==2.31.0 # via dace, requirementslib, sphinx, yarg
+requirementslib==3.0.0 # via isort
restructuredtext-lint==1.4.0 # via flake8-rst-docstrings
-rpds-py==0.10.3 # via jsonschema, referencing
-ruff==0.0.290 # via -r requirements-dev.in
+rpds-py==0.13.2 # via jsonschema, referencing
+ruff==0.1.7 # via -r requirements-dev.in
+scikit-build==0.17.6 # via dace
+setuptools-scm==8.0.4 # via fparser
six==1.16.0 # via asttokens, astunparse, python-dateutil
snowballstemmer==2.2.0 # via pydocstyle, sphinx
sortedcontainers==2.4.0 # via hypothesis
sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery
-sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in
+sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in
sphinxcontrib-applehelp==1.0.4 # via sphinx
sphinxcontrib-devhelp==1.0.2 # via sphinx
sphinxcontrib-htmlhelp==2.0.1 # via sphinx
@@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme
sphinxcontrib-jsmath==1.0.1 # via sphinx
sphinxcontrib-qthelp==1.0.3 # via sphinx
sphinxcontrib-serializinghtml==1.1.5 # via sphinx
-sympy==1.12 # via dace, gt4py (pyproject.toml)
+sympy==1.9 # via dace, gt4py (pyproject.toml)
tabulate==0.9.0 # via gt4py (pyproject.toml)
toml==0.10.2 # via jupytext
-tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox
+tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox
+tomlkit==0.12.3 # via plette, requirementslib
toolz==0.12.0 # via cytoolz
-tox==4.11.3 # via -r requirements-dev.in
-traitlets==5.10.0 # via jupyter-core, nbformat
+tox==4.11.4 # via -r requirements-dev.in
+traitlets==5.14.0 # via jupyter-core, nbformat
types-aiofiles==23.2.0.0 # via types-all
types-all==1.0.0 # via -r requirements-dev.in
types-annoy==1.17.8.4 # via types-all
types-atomicwrites==1.4.5.1 # via types-all
types-backports==0.1.3 # via types-all
types-backports-abc==0.5.2 # via types-all
-types-bleach==6.0.0.4 # via types-all
+types-bleach==6.1.0.1 # via types-all
types-boto==2.49.18.9 # via types-all
-types-cachetools==5.3.0.6 # via types-all
+types-cachetools==5.3.0.7 # via types-all
types-certifi==2021.10.8.3 # via types-all
-types-cffi==1.15.1.15 # via types-jack-client
+types-cffi==1.16.0.0 # via types-jack-client
types-characteristic==14.3.7 # via types-all
types-chardet==5.0.4.6 # via types-all
types-click==7.1.8 # via types-all, types-flask
-types-click-spinner==0.1.13.5 # via types-all
+types-click-spinner==0.1.13.6 # via types-all
types-colorama==0.4.15.12 # via types-all
types-contextvars==2.4.7.3 # via types-all
-types-croniter==1.4.0.1 # via types-all
+types-croniter==2.0.0.0 # via types-all
types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt
types-dataclasses==0.6.6 # via types-all
types-dateparser==1.1.4.10 # via types-all
@@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all
types-geoip2==3.0.0 # via types-all
types-ipaddress==1.0.8 # via types-all, types-maxminddb
types-itsdangerous==1.1.6 # via types-all
-types-jack-client==0.5.10.9 # via types-all
+types-jack-client==0.5.10.10 # via types-all
types-jinja2==2.11.9 # via types-all, types-flask
types-kazoo==0.1.3 # via types-all
-types-markdown==3.4.2.10 # via types-all
+types-markdown==3.5.0.3 # via types-all
types-markupsafe==1.1.10 # via types-all, types-jinja2
types-maxminddb==1.5.0 # via types-all, types-geoip2
-types-mock==5.1.0.2 # via types-all
+types-mock==5.1.0.3 # via types-all
types-mypy-extensions==1.0.0.5 # via types-all
types-nmap==0.1.6 # via types-all
types-openssl-python==0.1.3 # via types-all
types-orjson==3.6.2 # via types-all
-types-paramiko==3.3.0.0 # via types-all, types-pysftp
+types-paramiko==3.3.0.2 # via types-all, types-pysftp
types-pathlib2==2.3.0 # via types-all
-types-pillow==10.0.0.3 # via types-all
+types-pillow==10.1.0.2 # via types-all
types-pkg-resources==0.1.3 # via types-all
types-polib==1.2.0.1 # via types-all
-types-protobuf==4.24.0.1 # via types-all
+types-protobuf==4.24.0.4 # via types-all
types-pyaudio==0.2.16.7 # via types-all
types-pycurl==7.45.2.5 # via types-all
types-pyfarmhash==0.3.1.2 # via types-all
types-pyjwt==1.7.1 # via types-all
types-pymssql==2.1.0 # via types-all
types-pymysql==1.1.0.1 # via types-all
-types-pyopenssl==23.2.0.2 # via types-redis
+types-pyopenssl==23.3.0.0 # via types-redis
types-pyrfc3339==1.1.1.5 # via types-all
types-pysftp==0.2.17.6 # via types-all
types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange
types-python-gflags==3.1.7.3 # via types-all
types-python-slugify==8.0.0.3 # via types-all
-types-pytz==2023.3.1.0 # via types-all, types-tzlocal
+types-pytz==2023.3.1.1 # via types-all, types-tzlocal
types-pyvmomi==8.0.0.6 # via types-all
-types-pyyaml==6.0.12.11 # via types-all
-types-redis==4.6.0.6 # via types-all
-types-requests==2.31.0.2 # via types-all
+types-pyyaml==6.0.12.12 # via types-all
+types-redis==4.6.0.11 # via types-all
+types-requests==2.31.0.10 # via types-all
types-retry==0.9.9.4 # via types-all
types-routes==2.5.0 # via types-all
types-scribe==2.0.0 # via types-all
-types-setuptools==68.2.0.0 # via types-cffi
+types-setuptools==69.0.0.0 # via types-cffi
types-simplejson==3.19.0.2 # via types-all
types-singledispatch==4.1.0.0 # via types-all
types-six==1.16.21.9 # via types-all
@@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all
types-toml==0.10.8.7 # via types-all
types-tornado==5.1.1 # via types-all
types-typed-ast==1.5.8.7 # via types-all
-types-tzlocal==5.0.1.1 # via types-all
+types-tzlocal==5.1.0.1 # via types-all
types-ujson==5.8.0.1 # via types-all
-types-urllib3==1.26.25.14 # via types-requests
types-waitress==2.1.4.9 # via types-all
types-werkzeug==1.0.9 # via types-all, types-flask
types-xxhash==3.0.5.2 # via types-all
-typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy
-urllib3==2.0.4 # via requests
-virtualenv==20.24.5 # via pre-commit, tox
-websockets==11.0.3 # via dace
-werkzeug==2.3.7 # via flask
-wheel==0.41.2 # via astunparse, pip-tools
+typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm
+urllib3==2.1.0 # via requests, types-requests
+virtualenv==20.25.0 # via pre-commit, tox
+websockets==12.0 # via dace
+werkzeug==3.0.1 # via flask
+wheel==0.42.0 # via astunparse, pip-tools, scikit-build
xxhash==3.0.0 # via gt4py (pyproject.toml)
-zipp==3.16.2 # via importlib-metadata, importlib-resources
+yarg==0.1.9 # via pipreqs
+zipp==3.17.0 # via importlib-metadata, importlib-resources
# The following packages are considered to be unsafe in a requirements file:
-pip==23.2.1 # via pip-tools
-setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools
+pip==23.3.1 # via pip-api, pip-tools, requirementslib
+setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm
diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py
index 7d255de142..c28c5cf2d6 100644
--- a/src/gt4py/__init__.py
+++ b/src/gt4py/__init__.py
@@ -33,6 +33,6 @@
if _sys.version_info >= (3, 10):
- from . import next
+ from . import next # noqa: A004
__all__ += ["next"]
diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py
index b1e559a41e..5dae025acb 100644
--- a/src/gt4py/cartesian/backend/dace_backend.py
+++ b/src/gt4py/cartesian/backend/dace_backend.py
@@ -562,12 +562,6 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S
omp_threads = ""
omp_header = ""
- # Backward compatible state struct name change in DaCe >=0.15.x
- try:
- dace_state_suffix = dace.Config.get("compiler.codegen_state_struct_suffix")
- except (KeyError, TypeError):
- dace_state_suffix = "_t" # old structure name
-
interface = cls.template.definition.render(
name=sdfg.name,
backend_specifics=omp_threads,
@@ -575,7 +569,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S
functor_args=self.generate_functor_args(sdfg),
tmp_allocs=self.generate_tmp_allocs(sdfg),
allocator="gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique",
- state_suffix=dace_state_suffix,
+ state_suffix=dace.Config.get("compiler.codegen_state_struct_suffix"),
)
generated_code = textwrap.dedent(
f"""#include
diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
index db276a48b9..48b129fa87 100644
--- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
+++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
@@ -30,6 +30,7 @@
compute_dcir_access_infos,
flatten_list,
get_tasklet_symbol,
+ make_dace_subset,
union_inout_memlets,
union_node_grid_subsets,
untile_memlets,
@@ -458,6 +459,40 @@ def visit_HorizontalExecution(
write_memlets=write_memlets,
)
+ for memlet in [*read_memlets, *write_memlets]:
+ """
+ This loop handles the special case of a tasklet performing array access.
+ The memlet should pass the full array shape (no tiling) and
+ the tasklet expression for array access should use all explicit indexes.
+ """
+ array_ndims = len(global_ctx.arrays[memlet.field].shape)
+ field_decl = global_ctx.library_node.field_decls[memlet.field]
+ # calculate array subset on original memlet
+ memlet_subset = make_dace_subset(
+ global_ctx.library_node.access_infos[memlet.field],
+ memlet.access_info,
+ field_decl.data_dims,
+ )
+ # select index values for single-point grid access
+ memlet_data_index = [
+ dcir.Literal(value=str(dim_range[0]), dtype=common.DataType.INT32)
+ for dim_range, dim_size in zip(memlet_subset, memlet_subset.size())
+ if dim_size == 1
+ ]
+ if len(memlet_data_index) < array_ndims:
+ reshape_memlet = False
+ for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess):
+ if access_node.data_index and access_node.name == memlet.connector:
+ access_node.data_index = memlet_data_index + access_node.data_index
+ assert len(access_node.data_index) == array_ndims
+ reshape_memlet = True
+ if reshape_memlet:
+ # ensure that memlet symbols used for array indexing are defined in context
+ for sym in memlet.access_info.grid_subset.free_symbols:
+ symbol_collector.add_symbol(sym)
+ # set full shape on memlet
+ memlet.access_info = global_ctx.library_node.access_infos[memlet.field]
+
for item in reversed(expansion_items):
iteration_ctx = iteration_ctx.pop()
dcir_node = self._process_iteration_item(
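
A toy sketch of the size-1 index selection performed in the new loop, under the simplifying assumption that a subset is a sequence of `(start, size)` pairs (a hypothetical stand-in for the DaCe subset API):

```python
# Dimensions whose subset size is 1 collapse into explicit literal indexes
# for the tasklet access; larger dimensions keep their range.
subset = [(3, 1), (0, 8), (5, 1)]  # (start, size) per array dimension

memlet_data_index = [str(start) for start, size in subset if size == 1]
print(memlet_data_index)  # ['3', '5'] -> the size-8 dimension stays ranged
```
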
diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py
index ddcb719b5f..bd8c08034c 100644
--- a/src/gt4py/cartesian/gtc/dace/nodes.py
+++ b/src/gt4py/cartesian/gtc/dace/nodes.py
@@ -121,7 +121,7 @@ def __init__(
*args,
**kwargs,
):
- super().__init__(name=name, *args, **kwargs)
+ super().__init__(*args, name=name, **kwargs)
from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos
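
The two call forms bind identically at runtime; the reorder is a style fix, likely to satisfy flake8-bugbear's B026 check (bugbear is bumped in this PR), since keyword-before-`*args` reads as if `name` bound before the unpacked positionals. A quick demonstration of the equivalence:

```python
# Both spellings produce the same call; the second reads in the order
# Python actually applies the arguments (cf. flake8-bugbear B026).
def init(name, flag=False):
    return name, flag

args = ("n",)
assert init(flag=True, *args) == init(*args, flag=True) == ("n", True)
```
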
diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py
index 28ebc8cd8e..0366317360 100644
--- a/src/gt4py/cartesian/gtc/daceir.py
+++ b/src/gt4py/cartesian/gtc/daceir.py
@@ -536,7 +536,7 @@ def union(self, other):
else:
assert (
isinstance(interval2, (TileInterval, DomainInterval))
- and isinstance(interval1, IndexWithExtent)
+ and isinstance(interval1, (IndexWithExtent, DomainInterval))
) or (
isinstance(interval1, (TileInterval, DomainInterval))
and isinstance(interval2, IndexWithExtent)
@@ -573,7 +573,7 @@ def overapproximated_shape(self):
def apply_iteration(self, grid_subset: GridSubset):
res_intervals = dict(self.grid_subset.intervals)
for axis, field_interval in self.grid_subset.intervals.items():
- if axis in grid_subset.intervals:
+ if axis in grid_subset.intervals and not isinstance(field_interval, DomainInterval):
grid_interval = grid_subset.intervals[axis]
assert isinstance(field_interval, IndexWithExtent)
extent = field_interval.extent
diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py
index fcd53d1312..bc744b3ccc 100644
--- a/src/gt4py/eve/datamodels/core.py
+++ b/src/gt4py/eve/datamodels/core.py
@@ -814,7 +814,7 @@ def concretize(
""" # noqa: RST301 # doctest conventions confuse RST validator
concrete_cls: Type[DataModelT] = _make_concrete_with_cache(
- datamodel_cls, *type_args, class_name=class_name, module=module
+ datamodel_cls, *type_args, class_name=class_name, module=module # type: ignore[arg-type]
)
assert isinstance(concrete_cls, type) and is_datamodel(concrete_cls)
@@ -883,17 +883,6 @@ def _substitute_typevars(
return type_params_map[type_hint], True
elif getattr(type_hint, "__parameters__", []):
return type_hint[tuple(type_params_map[tp] for tp in type_hint.__parameters__)], True
- # TODO(egparedes): WIP fix for partial specialization
- # # Type hint is a generic model: replace all the concretized type vars
- # noqa: e800 replaced = False
- # noqa: e800 new_args = []
- # noqa: e800 for tp in type_hint.__parameters__:
- # noqa: e800 if tp in type_params_map:
- # noqa: e800 new_args.append(type_params_map[tp])
- # noqa: e800 replaced = True
- # noqa: e800 else:
- # noqa: e800 new_args.append(type_params_map[tp])
- # noqa: e800 return type_hint[tuple(new_args)], replaced
else:
return type_hint, False
@@ -981,21 +970,14 @@ def __class_getitem__(
"""
type_args: Tuple[Type] = args if isinstance(args, tuple) else (args,)
concrete_cls: Type[DataModelT] = concretize(cls, *type_args)
- res = xtyping.StdGenericAliasType(concrete_cls, type_args)
- if sys.version_info < (3, 9):
- # in Python 3.8, xtyping.StdGenericAliasType (aka typing._GenericAlias)
- # does not copy all required `__dict__` entries, so do it manually
- for k, v in concrete_cls.__dict__.items():
- if k not in res.__dict__:
- res.__dict__[k] = v
- return res
+ return concrete_cls
return classmethod(__class_getitem__)
def _make_type_converter(type_annotation: TypeAnnotation, name: str) -> TypeConverter[_T]:
- # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code as a tree traversal.
- #
+ # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code
+ # as a tree traversal.
if xtyping.is_actual_type(type_annotation) and not isinstance(None, type_annotation):
assert not xtyping.get_args(type_annotation)
assert isinstance(type_annotation, type)
@@ -1316,11 +1298,7 @@ def _make_concrete_with_cache(
# Replace field definitions with the new actual types for generic fields
type_params_map = dict(zip(datamodel_cls.__parameters__, type_args))
model_fields = getattr(datamodel_cls, MODEL_FIELD_DEFINITIONS_ATTR)
- new_annotations = {
- # TODO(egparedes): ?
- # noqa: e800 "__args__": "ClassVar[Tuple[Union[Type, TypeVar], ...]]",
- # noqa: e800 "__parameters__": "ClassVar[Tuple[TypeVar, ...]]",
- }
+ new_annotations = {}
new_field_c_attrs = {}
for field_name, field_type in xtyping.get_type_hints(datamodel_cls).items():
@@ -1353,8 +1331,16 @@ def _make_concrete_with_cache(
"__module__": module if module else datamodel_cls.__module__,
**new_field_c_attrs,
}
-
concrete_cls = type(class_name, (datamodel_cls,), namespace)
+
+ # Update the tuple of generic parameters in the new class, in case
+ # this is a partial concretization
+ assert hasattr(concrete_cls, "__parameters__")
+ concrete_cls.__parameters__ = tuple(
+ type_params_map[tp_var]
+ for tp_var in datamodel_cls.__parameters__
+ if isinstance(type_params_map[tp_var], typing.TypeVar)
+ )
assert concrete_cls.__module__ == module or not module
if MODEL_FIELD_DEFINITIONS_ATTR not in concrete_cls.__dict__:
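
The new `__parameters__` bookkeeping makes partial concretizations advertise only the still-unbound type variables, matching how the standard `typing` machinery behaves:

```python
# Standard-library analogue of the `__parameters__` update above:
# partially specializing a generic keeps only the unbound TypeVars.
from typing import Generic, TypeVar

T = TypeVar("T")
U = TypeVar("U")

class Pair(Generic[T, U]):
    pass

print(Pair.__parameters__)          # (~T, ~U)
print(Pair[int, U].__parameters__)  # (~U,) -> partial specialization
```
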
diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py
index 17462a37ff..3ee447ca6c 100644
--- a/src/gt4py/eve/extended_typing.py
+++ b/src/gt4py/eve/extended_typing.py
@@ -493,7 +493,7 @@ def _patched_proto_hook(other): # type: ignore[no-untyped-def]
if isinstance(_typing.Any, type): # Python >= 3.11
_ArtefactTypes = (*_ArtefactTypes, _typing.Any)
-# `Any` is a class since typing_extensions >= 4.4
+# `Any` is a class since typing_extensions >= 4.4 and Python 3.11
if (typing_exts_any := getattr(_typing_extensions, "Any", None)) is not _typing.Any and isinstance(
typing_exts_any, type
):
@@ -504,11 +504,13 @@ def is_actual_type(obj: Any) -> TypeGuard[Type]:
"""Check if an object has an actual type and instead of a typing artefact like ``GenericAlias`` or ``Any``.
This is needed because since Python 3.9:
- ``isinstance(types.GenericAlias(), type) is True``
+ ``isinstance(types.GenericAlias(), type) is True``
and since Python 3.11:
- ``isinstance(typing.Any, type) is True``
+ ``isinstance(typing.Any, type) is True``
"""
- return isinstance(obj, type) and type(obj) not in _ArtefactTypes
+ return (
+ isinstance(obj, type) and (obj not in _ArtefactTypes) and (type(obj) not in _ArtefactTypes)
+ )
if hasattr(_typing_extensions, "Any") and _typing.Any is not _typing_extensions.Any: # type: ignore[attr-defined] # _typing_extensions.Any only from >= 4.4
@@ -641,9 +643,12 @@ def get_partial_type_hints(
resolved_hints = get_type_hints( # type: ignore[call-arg] # Python 3.8 does not define `include-extras`
obj, globalns=globalns, localns=localns, include_extras=include_extras
)
- hints.update(resolved_hints)
+ hints[name] = resolved_hints[name]
except NameError as error:
if isinstance(hint, str):
+ # This conversion could be probably skipped in Python versions containing
+ # the fix applied in bpo-41370. Check:
+ # https://github.com/python/cpython/commit/b465b606049f6f7dd0711cb031fdaa251818741a#diff-ddb987fca5f5df0c9a2f5521ed687919d70bb3d64eaeb8021f98833a2a716887R344
hints[name] = ForwardRef(hint)
elif isinstance(hint, (ForwardRef, _typing.ForwardRef)):
hints[name] = hint
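
Replacing `hints.update(resolved_hints)` with a per-name assignment keeps one unresolvable forward reference from clobbering the fallbacks of the other names. A sketch of the failure mode being guarded against (hypothetical class, illustrative only):

```python
# One bad forward reference makes typing.get_type_hints() fail wholesale;
# a partial resolver keeps x -> int and stores y as a ForwardRef instead.
import typing

class C:
    x: int
    y: "NotDefinedAnywhere"

try:
    typing.get_type_hints(C)
except NameError as e:
    print(e)  # name 'NotDefinedAnywhere' is not defined

print(typing.ForwardRef("NotDefinedAnywhere"))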
diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py
index cd7e71588f..74c5bd41bb 100644
--- a/src/gt4py/eve/trees.py
+++ b/src/gt4py/eve/trees.py
@@ -133,7 +133,7 @@ def _pre_walk_items(
yield from _pre_walk_items(child, __key__=key)
-def _pre_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]:
+def _pre_walk_values(node: TreeLike) -> Iterable:
"""Create a pre-order tree traversal iterator of values."""
yield node
for child in iter_children_values(node):
@@ -153,7 +153,7 @@ def _post_walk_items(
yield __key__, node
-def _post_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]:
+def _post_walk_values(node: TreeLike) -> Iterable:
"""Create a post-order tree traversal iterator of values."""
if (iter_children_values := getattr(node, "iter_children_values", None)) is not None:
for child in iter_children_values():
diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py
index 7104f7658f..624407f319 100644
--- a/src/gt4py/eve/utils.py
+++ b/src/gt4py/eve/utils.py
@@ -1225,7 +1225,7 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]:
[('a', 'b', 'c'), (1, 2, 3)]
"""
- return XIterable(zip(*self.iterator)) # type: ignore # mypy gets confused with *args
+ return XIterable(zip(*self.iterator))
@typing.overload
def islice(self, __stop: int) -> XIterable[T]:
@@ -1536,7 +1536,7 @@ def reduceby(
) -> Dict[K, S]:
...
- def reduceby( # type: ignore[misc] # signatures 2 and 4 are not satified due to inconsistencies with type variables
+ def reduceby(
self,
bin_op_func: Callable[[S, T], S],
key: Union[str, List[K], Callable[[T], K]],
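
The `unzip` docstring above relies on the standard transpose idiom; in isolation:

```python
# The idiom behind XIterable.unzip(): zip(*iterable) transposes an
# iterable of tuples.
pairs = [("a", 1), ("b", 2), ("c", 3)]
keys, values = zip(*pairs)
print(keys, values)  # ('a', 'b', 'c') (1, 2, 3)
```
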
diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py
index cbd5735949..1398af5f03 100644
--- a/src/gt4py/next/__init__.py
+++ b/src/gt4py/next/__init__.py
@@ -39,6 +39,11 @@
index_field,
np_as_located_field,
)
+from .program_processors.runners.gtfn import (
+ run_gtfn_cached as gtfn_cpu,
+ run_gtfn_gpu_cached as gtfn_gpu,
+)
+from .program_processors.runners.roundtrip import backend as itir_python
__all__ = [
@@ -74,5 +79,9 @@
"field_operator",
"program",
"scan_operator",
+ # from program_processor
+ "gtfn_cpu",
+ "gtfn_gpu",
+ "itir_python",
*fbuiltins.__all__,
]
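A hedged sketch of what the new top-level re-exports enable for users (the dimension and operator names are illustrative):

    import gt4py.next as gtx

    IDim = gtx.Dimension("IDim")

    @gtx.field_operator(backend=gtx.gtfn_cpu)  # or gtx.gtfn_gpu / gtx.itir_python
    def copy(a: gtx.Field[[IDim], float]) -> gtx.Field[[IDim], float]:
        return a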
diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py
index 29d606ccc0..949f4b461a 100644
--- a/src/gt4py/next/common.py
+++ b/src/gt4py/next/common.py
@@ -189,11 +189,12 @@ def __and__(self, other: UnitRange) -> UnitRange:
return UnitRange(max(self.start, other.start), min(self.stop, other.stop))
def __contains__(self, value: Any) -> bool:
- return (
- isinstance(value, core_defs.INTEGRAL_TYPES)
- and value >= self.start
- and value < self.stop
- )
+ # TODO(egparedes): use core_defs.IntegralScalar for `isinstance()` checks (see PEP 604)
+ # and remove int cast, once the related mypy bug (#16358) gets fixed
+ if isinstance(value, core_defs.INTEGRAL_TYPES):
+ return self.start <= cast(int, value) < self.stop
+ else:
+ return False
def __le__(self, other: UnitRange) -> bool:
return self.start >= other.start and self.stop <= other.stop
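Behavioural sketch of the new membership check (assuming `UnitRange` from this module):

    r = UnitRange(0, 10)
    assert 5 in r        # integral and within [start, stop)
    assert 10 not in r   # stop is exclusive
    assert 2.5 not in r  # non-integral values never match
    assert "a" not in r  # non-numeric values never match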
@@ -574,38 +575,39 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _
...
-# TODO(havogt): replace this protocol with the new `GTFieldInterface` protocol
-class NextGTDimsInterface(Protocol):
+# TODO(havogt): we need to describe when this interface should be used instead of the `Field` protocol.
+class GTFieldInterface(core_defs.GTDimsInterface, core_defs.GTOriginInterface, Protocol):
"""
- Protocol for objects providing the `__gt_dims__` property, naming :class:`Field` dimensions.
+ Protocol for objects providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`.
- The dimension names are objects of type :class:`Dimension`, in contrast to
- :mod:`gt4py.cartesian`, where the labels are `str` s with implied semantics,
- see :class:`~gt4py._core.definitions.GTDimsInterface` .
+ Note:
+ - A default implementation of the `__gt_dims__` interface from `gt4py.cartesian` is provided.
+ - No implementation of `__gt_origin__` is provided because of infinite fields.
"""
@property
- def __gt_dims__(self) -> tuple[Dimension, ...]:
+ def __gt_domain__(self) -> Domain:
+ # TODO: this should probably be changed to `DomainLike` (with a new concept `DimensionLike`)
+ # to allow implementations without having to import gtx.Domain.
...
-
-# TODO(egparedes): add support for this new protocol in the cartesian module
-class GTFieldInterface(Protocol):
- """Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`."""
-
@property
- def __gt_domain__(self) -> Domain:
- ...
+ def __gt_dims__(self) -> tuple[str, ...]:
+ return tuple(d.value for d in self.__gt_domain__.dims)
@extended_runtime_checkable
-class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]):
+class Field(GTFieldInterface, Protocol[DimsT, core_defs.ScalarT]):
__gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher]
@property
def domain(self) -> Domain:
...
+ @property
+ def __gt_domain__(self) -> Domain:
+ return self.domain
+
@property
def codomain(self) -> type[core_defs.ScalarT] | Dimension:
...
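A minimal sketch of a custom field-like type written against the merged protocol: only `__gt_domain__` needs to be supplied, while the string-label `__gt_dims__` expected by `gt4py.cartesian` falls out of the default implementation above (the class name is illustrative):

    class MyDomainField(GTFieldInterface):
        def __init__(self, domain: Domain) -> None:
            self._domain = domain

        @property
        def __gt_domain__(self) -> Domain:
            return self._domain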
@@ -923,10 +925,6 @@ def asnumpy(self) -> Never:
def domain(self) -> Domain:
return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),))
- @property
- def __gt_dims__(self) -> tuple[Dimension, ...]:
- return self.domain.dims
-
@property
def __gt_origin__(self) -> Never:
raise TypeError("'CartesianConnectivity' does not support this operation.")
diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py
index 8bd2673db9..9fc1b42038 100644
--- a/src/gt4py/next/embedded/nd_array_field.py
+++ b/src/gt4py/next/embedded/nd_array_field.py
@@ -107,10 +107,6 @@ def domain(self) -> common.Domain:
def shape(self) -> tuple[int, ...]:
return self._ndarray.shape
- @property
- def __gt_dims__(self) -> tuple[common.Dimension, ...]:
- return self._domain.dims
-
@property
def __gt_origin__(self) -> tuple[int, ...]:
assert common.Domain.is_finite(self._domain)
diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py
index 76a0ddcde0..9f8537f59b 100644
--- a/src/gt4py/next/ffront/decorator.py
+++ b/src/gt4py/next/ffront/decorator.py
@@ -344,7 +344,9 @@ def _validate_args(self, *args, **kwargs) -> None:
raise_exception=True,
)
except ValueError as err:
- raise TypeError(f"Invalid argument types in call to '{self.past_node.id}'.") from err
+ raise errors.DSLError(
+ None, f"Invalid argument types in call to '{self.past_node.id}'.\n{err}"
+ ) from err
def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]:
self._validate_args(*args, **kwargs)
@@ -453,27 +455,32 @@ def _process_args(self, args: tuple, kwargs: dict):
) from err
full_args = [*args]
+ full_kwargs = {**kwargs}
for index, param in enumerate(self.past_node.params):
if param.id in self.bound_args.keys():
- full_args.insert(index, self.bound_args[param.id])
+ if index < len(full_args):
+ full_args.insert(index, self.bound_args[param.id])
+ else:
+ full_kwargs[str(param.id)] = self.bound_args[param.id]
- return super()._process_args(tuple(full_args), kwargs)
+ return super()._process_args(tuple(full_args), full_kwargs)
@functools.cached_property
def itir(self):
new_itir = super().itir
for new_clos in new_itir.closures:
- for key in self.bound_args.keys():
+ new_args = [ref(inp.id) for inp in new_clos.inputs]
+ for key, value in self.bound_args.items():
index = next(
index
for index, closure_input in enumerate(new_clos.inputs)
if closure_input.id == key
)
+ new_args[new_args.index(new_clos.inputs[index])] = promote_to_const_iterator(
+ literal_from_value(value)
+ )
new_clos.inputs.pop(index)
- new_args = [ref(inp.id) for inp in new_clos.inputs]
params = [sym(inp.id) for inp in new_clos.inputs]
- for value in self.bound_args.values():
- new_args.append(promote_to_const_iterator(literal_from_value(value)))
expr = itir.FunCall(
fun=new_clos.stencil,
args=new_args,
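The positional bookkeeping in `_process_args` above, reduced to plain Python to show the intended behaviour (the helper name is illustrative, not the actual API):

    def insert_bound_args(args: tuple, params: list, bound: dict) -> tuple:
        # Re-insert bound values at their declared parameter positions; values
        # whose slot lies beyond the positionally supplied args become kwargs.
        full_args, full_kwargs = list(args), {}
        for index, param in enumerate(params):
            if param in bound:
                if index < len(full_args):
                    full_args.insert(index, bound[param])
                else:
                    full_kwargs[param] = bound[param]
        return tuple(full_args), full_kwargs

    assert insert_bound_args(("a", "c"), ["a", "s", "c"], {"s": 1}) == (("a", 1, "c"), {})
    assert insert_bound_args(("a",), ["a", "s"], {"s": 1}) == (("a",), {"s": 1})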
@@ -847,6 +854,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator:
return FieldOperator.from_function(
definition,
DEFAULT_BACKEND if backend is eve.NOTHING else backend,
+ grid_type,
operator_node_cls=foast.ScanOperator,
operator_attributes={"axis": axis, "forward": forward, "init": init},
)
diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py
index 278dde9180..cd75538da7 100644
--- a/src/gt4py/next/ffront/fbuiltins.py
+++ b/src/gt4py/next/ffront/fbuiltins.py
@@ -15,7 +15,7 @@
import dataclasses
import functools
import inspect
-from builtins import bool, float, int, tuple
+from builtins import bool, float, int, tuple # noqa: A004
from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast
import numpy as np
diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py
index 639e5ff009..5e289af664 100644
--- a/src/gt4py/next/ffront/foast_passes/type_deduction.py
+++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py
@@ -694,7 +694,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:
)
except ValueError as err:
raise errors.DSLError(
- node.location, f"Invalid argument types in call to '{new_func}'."
+ node.location, f"Invalid argument types in call to '{new_func}'.\n{err}"
) from err
return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types)
diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py
index fc353d64e4..af8f5e8368 100644
--- a/src/gt4py/next/ffront/past_passes/type_deduction.py
+++ b/src/gt4py/next/ffront/past_passes/type_deduction.py
@@ -229,7 +229,7 @@ def visit_Call(self, node: past.Call, **kwargs):
)
except ValueError as ex:
- raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.") from ex
+ raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.\n{ex}") from ex
return past.Call(
func=new_func,
diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py
index ef70a2e645..390bec4312 100644
--- a/src/gt4py/next/iterator/embedded.py
+++ b/src/gt4py/next/iterator/embedded.py
@@ -172,7 +172,7 @@ class LocatedField(Protocol):
@property
@abc.abstractmethod
- def __gt_dims__(self) -> tuple[common.Dimension, ...]:
+ def __gt_domain__(self) -> common.Domain:
...
# TODO(havogt): define generic Protocol to provide a concrete return type
@@ -182,7 +182,7 @@ def field_getitem(self, indices: NamedFieldIndices) -> Any:
@property
def __gt_origin__(self) -> tuple[int, ...]:
- return tuple([0] * len(self.__gt_dims__))
+ return tuple([0] * len(self.__gt_domain__.dims))
@runtime_checkable
@@ -675,12 +675,18 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]:
def _get_axes(
field_or_tuple: LocatedField | tuple,
) -> Sequence[common.Dimension]: # arbitrary nesting of tuples of LocatedField
+ return _get_domain(field_or_tuple).dims
+
+
+def _get_domain(
+ field_or_tuple: LocatedField | tuple,
+) -> common.Domain: # arbitrary nesting of tuples of LocatedField
if isinstance(field_or_tuple, tuple):
- first = _get_axes(field_or_tuple[0])
- assert all(first == _get_axes(f) for f in field_or_tuple)
+ first = _get_domain(field_or_tuple[0])
+ assert all(first == _get_domain(f) for f in field_or_tuple)
return first
else:
- return field_or_tuple.__gt_dims__
+ return field_or_tuple.__gt_domain__
def _single_vertical_idx(
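The tuple recursion in `_get_domain` above, as a self-contained sketch with a stand-in leaf object instead of a real `LocatedField`:

    class _Leaf:
        __gt_domain__ = "D"  # stand-in for a shared common.Domain

    def get_domain(obj):
        if isinstance(obj, tuple):
            first = get_domain(obj[0])
            assert all(first == get_domain(f) for f in obj)  # all leaves must agree
            return first
        return obj.__gt_domain__

    assert get_domain((_Leaf(), (_Leaf(), _Leaf()))) == "D"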
@@ -894,14 +900,14 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField):
_ndarrayfield: common.Field
@property
- def __gt_dims__(self) -> tuple[common.Dimension, ...]:
- return self._ndarrayfield.__gt_dims__
+ def __gt_domain__(self) -> common.Domain:
+ return self._ndarrayfield.__gt_domain__
def _translate_named_indices(
self, _named_indices: NamedFieldIndices
) -> common.AbsoluteIndexSequence:
named_indices: Mapping[common.Dimension, FieldIndex | SparsePositionEntry] = {
- d: _named_indices[d.value] for d in self._ndarrayfield.__gt_dims__
+ d: _named_indices[d.value] for d in self._ndarrayfield.__gt_domain__.dims
}
domain_slice: list[common.NamedRange | common.NamedIndex] = []
for d, v in named_indices.items():
@@ -1046,8 +1052,8 @@ class IndexField(common.Field):
_dimension: common.Dimension
@property
- def __gt_dims__(self) -> tuple[common.Dimension, ...]:
- return (self._dimension,)
+ def __gt_domain__(self) -> common.Domain:
+ return self.domain
@property
def __gt_origin__(self) -> tuple[int, ...]:
@@ -1165,8 +1171,8 @@ class ConstantField(common.Field[Any, core_defs.ScalarT]):
_value: core_defs.ScalarT
@property
- def __gt_dims__(self) -> tuple[common.Dimension, ...]:
- return tuple()
+ def __gt_domain__(self) -> common.Domain:
+ return self.domain
@property
def __gt_origin__(self) -> tuple[int, ...]:
@@ -1452,7 +1458,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices:
class TupleOfFields(TupleField):
def __init__(self, data):
self.data = data
- self.__gt_dims__ = _get_axes(data)
+ self.__gt_domain__ = _get_domain(data)
def field_getitem(self, named_indices: NamedFieldIndices) -> Any:
return _build_tuple_result(self.data, named_indices)
diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py
index 30fec1f9fd..05ebd02352 100644
--- a/src/gt4py/next/iterator/tracing.py
+++ b/src/gt4py/next/iterator/tracing.py
@@ -254,7 +254,7 @@ def _contains_tuple_dtype_field(arg):
# other `np.int32`). We just ignore the error here and postpone fixing this to when
# the new storages land (The implementation here works for LocatedFieldImpl).
- return common.is_field(arg) and any(dim is None for dim in arg.__gt_dims__)
+ return common.is_field(arg) and any(dim is None for dim in arg.domain.dims)
def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]:
diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py
index 2609e35735..7ad55d0a87 100644
--- a/src/gt4py/next/iterator/transforms/global_tmps.py
+++ b/src/gt4py/next/iterator/transforms/global_tmps.py
@@ -23,6 +23,7 @@
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.utils import UIDGenerator
from gt4py.eve.visitors import PreserveLocationVisitor
+from gt4py.next import common
from gt4py.next.iterator import ir, type_inference
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
@@ -442,9 +443,12 @@ def _group_offsets(
return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly
-def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]):
+def update_domains(
+ node: FencilWithTemporaries,
+ offset_provider: Mapping[str, Any],
+ symbolic_sizes: Optional[dict[str, str]],
+):
horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider)
-
closures: list[ir.StencilClosure] = []
domains = dict[str, ir.FunCall]()
for closure in reversed(node.fencil.closures):
@@ -485,16 +489,29 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An
# cartesian shift
dim = offset_provider[offset_name].value
consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset)
- elif isinstance(offset_provider[offset_name], gtx.NeighborTableOffsetProvider):
+ elif isinstance(offset_provider[offset_name], common.Connectivity):
# unstructured shift
nbt_provider = offset_provider[offset_name]
old_axis = nbt_provider.origin_axis.value
new_axis = nbt_provider.neighbor_axis.value
- consumed_domain.ranges.pop(old_axis)
- assert new_axis not in consumed_domain.ranges
- consumed_domain.ranges[new_axis] = SymbolicRange(
- im.literal("0", ir.INTEGER_INDEX_BUILTIN),
- im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN),
+
+ assert new_axis not in consumed_domain.ranges or old_axis == new_axis
+
+ if symbolic_sizes is None:
+ new_range = SymbolicRange(
+ im.literal("0", ir.INTEGER_INDEX_BUILTIN),
+ im.literal(
+ str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN
+ ),
+ )
+ else:
+ new_range = SymbolicRange(
+ im.literal("0", ir.INTEGER_INDEX_BUILTIN),
+ im.ref(symbolic_sizes[new_axis]),
+ )
+ consumed_domain.ranges = dict(
+ (axis, range_) if axis != old_axis else (new_axis, new_range)
+ for axis, range_ in consumed_domain.ranges.items()
)
else:
raise NotImplementedError
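Two things happen in this branch: the upper bound of the new range is either a concrete size taken from the offset provider or a symbolic reference to be supplied at runtime, and the old axis is renamed in place. The order-preserving rename, shown on plain dicts (axis and range values are illustrative):

    ranges = {"Vertex": "0:18", "K": "0:80"}
    old_axis, new_axis, new_range = "Vertex", "Edge", "0:num_edges"
    renamed = dict(
        (axis, r) if axis != old_axis else (new_axis, new_range)
        for axis, r in ranges.items()
    )
    assert renamed == {"Edge": "0:num_edges", "K": "0:80"}
    assert list(renamed) == ["Edge", "K"]  # the axis keeps its positional slot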
@@ -577,7 +594,11 @@ class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator):
"""
def visit_FencilDefinition(
- self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any]
+ self,
+ node: ir.FencilDefinition,
+ *,
+ offset_provider: Mapping[str, Any],
+ symbolic_sizes: Optional[dict[str, str]],
) -> FencilWithTemporaries:
# Split closures on lifted function calls and introduce temporaries
res = split_closures(node, offset_provider=offset_provider)
@@ -588,6 +609,6 @@ def visit_FencilDefinition(
# Perform an eta-reduction which should put all calls at the highest level of a closure
res = EtaReduction().visit(res)
# Perform a naive extent analysis to compute domain sizes of closures and temporaries
- res = update_domains(res, offset_provider)
+ res = update_domains(res, offset_provider, symbolic_sizes)
# Use type inference to determine the data type of the temporaries
return collect_tmps_info(res, offset_provider=offset_provider)
diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py
index 2e05391634..08897861c2 100644
--- a/src/gt4py/next/iterator/transforms/pass_manager.py
+++ b/src/gt4py/next/iterator/transforms/pass_manager.py
@@ -13,6 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import enum
+from typing import Optional
from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms import simple_inline_heuristic
@@ -81,6 +82,7 @@ def apply_common_transforms(
common_subexpression_elimination=True,
force_inline_lambda_args=False,
unconditionally_collapse_tuples=False,
+ symbolic_domain_sizes: Optional[dict[str, str]] = None,
):
if lift_mode is None:
lift_mode = LiftMode.FORCE_INLINE
@@ -147,7 +149,9 @@ def apply_common_transforms(
if lift_mode != LiftMode.FORCE_INLINE:
assert offset_provider is not None
- ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider)
+ ir = CreateGlobalTmps().visit(
+ ir, offset_provider=offset_provider, symbolic_sizes=symbolic_domain_sizes
+ )
ir = InlineLifts().visit(ir)
# If after creating temporaries, the scan is not at the top, we inline.
# The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it.
diff --git a/src/gt4py/next/iterator/transforms/power_unrolling.py b/src/gt4py/next/iterator/transforms/power_unrolling.py
new file mode 100644
index 0000000000..ac71f2747d
--- /dev/null
+++ b/src/gt4py/next/iterator/transforms/power_unrolling.py
@@ -0,0 +1,84 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+import dataclasses
+import math
+
+from gt4py.eve import NodeTranslator
+from gt4py.next.iterator import ir
+from gt4py.next.iterator.ir_utils import ir_makers as im
+from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas
+
+
+def _is_power_call(
+ node: ir.FunCall,
+) -> bool:
+ """Match expressions of the form `power(base, integral_literal)`."""
+ return (
+ isinstance(node.fun, ir.SymRef)
+ and node.fun.id == "power"
+ and isinstance(node.args[1], ir.Literal)
+ and float(node.args[1].value) == int(node.args[1].value)
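+ # note: `ir.Literal` values are strings, so this lexical comparison with "0"
+ # rejects negative exponents (a leading '-' sorts below '0')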
+ and node.args[1].value >= im.literal_from_value(0).value
+ )
+
+
+def _compute_integer_power_of_two(exp: int) -> int:
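+ """Return the exponent of the largest power of two not exceeding ``exp``, i.e. ``floor(log2(exp))``."""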
+ return math.floor(math.log2(exp))
+
+
+@dataclasses.dataclass
+class PowerUnrolling(NodeTranslator):
+ max_unroll: int
+
+ @classmethod
+ def apply(cls, node: ir.Node, max_unroll: int = 5) -> ir.Node:
+ return cls(max_unroll=max_unroll).visit(node)
+
+ def visit_FunCall(self, node: ir.FunCall):
+ new_node = self.generic_visit(node)
+
+ if _is_power_call(new_node):
+ assert len(new_node.args) == 2
+ # Check if unroll should be performed or if exponent is too large
+ base, exponent = new_node.args[0], int(new_node.args[1].value)
+ if 1 <= exponent <= self.max_unroll:
+ # Compute repeated squares of the base (power_2, power_4, ...) up to the largest
+ # power of two not exceeding the exponent, then cover the remainder by multiplying
+ # the stored squares together.
+ pow_cur = _compute_integer_power_of_two(exponent)
+ pow_max = pow_cur
+ remainder = exponent
+
+ # Build target expression
+ ret = im.ref(f"power_{2 ** pow_max}")
+ remainder -= 2**pow_cur
+ while remainder > 0:
+ pow_cur = _compute_integer_power_of_two(remainder)
+ remainder -= 2**pow_cur
+
+ ret = im.multiplies_(ret, f"power_{2 ** pow_cur}")
+
+ # Nest target expression to avoid multiple redundant evaluations
+ for i in range(pow_max, 0, -1):
+ ret = im.let(
+ f"power_{2 ** i}",
+ im.multiplies_(f"power_{2**(i-1)}", f"power_{2**(i-1)}"),
+ )(ret)
+ ret = im.let("power_1", base)(ret)
+
+ # Simplify expression in case of SymRef by resolving let statements
+ if isinstance(base, ir.SymRef):
+ return InlineLambdas.apply(ret, opcount_preserving=True)
+ else:
+ return ret
+ return new_node
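A worked example of the unrolling scheme in plain Python rather than ITIR: `power(base, 5)` becomes `power_4 * power_1`, with `power_2` and `power_4` each computed exactly once through the nested lets built above.

    def unrolled_power(base: float, exponent: int) -> float:
        # Square up to the largest power of two <= exponent, then cover the
        # remainder by multiplying the precomputed squares together.
        assert exponent >= 1
        powers = {1: base}  # power_1
        p = 1
        while 2 * p <= exponent:
            powers[2 * p] = powers[p] * powers[p]  # power_2, power_4, ...
            p *= 2
        result, remainder = powers[p], exponent - p
        while remainder > 0:
            p = 1 << (remainder.bit_length() - 1)  # largest power of two <= remainder
            result *= powers[p]
            remainder -= p
        return result

    assert unrolled_power(3.0, 5) == 3.0**5  # evaluates power_4 * power_1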
diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py
index 68627cfd89..d65f67b266 100644
--- a/src/gt4py/next/iterator/type_inference.py
+++ b/src/gt4py/next/iterator/type_inference.py
@@ -567,9 +567,7 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints):
axis = offset_provider[offset]
if isinstance(axis, gtx.Dimension):
continue # Cartesian shifts don’t change the location type
- elif isinstance(
- axis, (gtx.NeighborTableOffsetProvider, gtx.StridedNeighborOffsetProvider)
- ):
+ elif isinstance(axis, Connectivity):
assert (
axis.origin_axis.kind
== axis.neighbor_axis.kind
@@ -964,7 +962,7 @@ def visit_FencilDefinition(
def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None:
for child_node in node.pre_walk_values().if_isinstance(*TYPED_IR_NODES):
try:
- child_node.annex.type = types[id(child_node)] # type: ignore[attr-defined]
+ child_node.annex.type = types[id(child_node)]
except KeyError:
if not (
isinstance(child_node, ir.SymRef)
diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py
index ed8b768972..3a82f9c738 100644
--- a/src/gt4py/next/otf/workflow.py
+++ b/src/gt4py/next/otf/workflow.py
@@ -82,7 +82,7 @@ def replace(self, **kwargs: Any) -> Self:
if not dataclasses.is_dataclass(self):
raise TypeError(f"'{self.__class__}' is not a dataclass.")
assert not isinstance(self, type)
- return dataclasses.replace(self, **kwargs) # type: ignore[misc] # `self` is guaranteed to be a dataclass (is_dataclass) should be a `TypeGuard`?
+ return dataclasses.replace(self, **kwargs)
class ChainableWorkflowMixin(Workflow[StartT, EndT]):
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py
deleted file mode 100644
index 4183f52550..0000000000
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# GT4Py - GridTools Framework
-#
-# Copyright (c) 2014-2023, ETH Zurich
-# All rights reserved.
-#
-# This file is part of the GT4Py project and the GridTools framework.
-# GT4Py is free software: you can redistribute it and/or modify it under
-# the terms of the GNU General Public License as published by the
-# Free Software Foundation, either version 3 of the License, or any later
-# version. See the LICENSE.txt file at the top-level directory of this
-# distribution for a copy of the license or check .
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-from typing import Any
-
-import gt4py.next.iterator.ir as itir
-from gt4py.eve import codegen
-from gt4py.eve.exceptions import EveValueError
-from gt4py.next.iterator.transforms.pass_manager import apply_common_transforms
-from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen
-from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering
-from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering
-
-
-def _lower(
- program: itir.FencilDefinition, enable_itir_transforms: bool, do_unroll: bool, **kwargs: Any
-):
- offset_provider = kwargs.get("offset_provider")
- assert isinstance(offset_provider, dict)
- if enable_itir_transforms:
- program = apply_common_transforms(
- program,
- lift_mode=kwargs.get("lift_mode"),
- offset_provider=offset_provider,
- unroll_reduce=do_unroll,
- unconditionally_collapse_tuples=True, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements
- )
- gtfn_ir = GTFN_lowering.apply(
- program,
- offset_provider=offset_provider,
- column_axis=kwargs.get("column_axis"),
- )
- return gtfn_ir
-
-
-def generate(
- program: itir.FencilDefinition, enable_itir_transforms: bool = True, **kwargs: Any
-) -> str:
- if kwargs.get("imperative", False):
- try:
- gtfn_ir = _lower(
- program=program,
- enable_itir_transforms=enable_itir_transforms,
- do_unroll=False,
- **kwargs,
- )
- except EveValueError:
- # if we don't unroll, there may be lifts left in the itir which can't be lowered to
- # gtfn. In this case, just retry with unrolled reductions.
- gtfn_ir = _lower(
- program=program,
- enable_itir_transforms=enable_itir_transforms,
- do_unroll=True,
- **kwargs,
- )
- gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir, **kwargs)
- generated_code = GTFNIMCodegen.apply(gtfn_im_ir, **kwargs)
- else:
- gtfn_ir = _lower(
- program=program,
- enable_itir_transforms=enable_itir_transforms,
- do_unroll=True,
- **kwargs,
- )
- generated_code = GTFNCodegen.apply(gtfn_ir, **kwargs)
- return codegen.format_source("cpp", generated_code, style="LLVM")
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
index 4abdaa6eea..718fef72af 100644
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
+++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
@@ -15,21 +15,24 @@
from __future__ import annotations
import dataclasses
+import functools
import warnings
from typing import Any, Final, Optional
import numpy as np
from gt4py._core import definitions as core_defs
-from gt4py.eve import trees, utils
+from gt4py.eve import codegen, trees, utils
from gt4py.next import common
from gt4py.next.common import Connectivity, Dimension
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import ir as itir
-from gt4py.next.iterator.transforms import LiftMode
+from gt4py.next.iterator.transforms import LiftMode, pass_manager
from gt4py.next.otf import languages, stages, step_types, workflow
from gt4py.next.otf.binding import cpp_interface, interface
-from gt4py.next.program_processors.codegens.gtfn import gtfn_backend
+from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen
+from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering
+from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering
from gt4py.next.type_system import type_specifications as ts, type_translation
@@ -54,6 +57,7 @@ class GTFNTranslationStep(
use_imperative_backend: bool = False
lift_mode: Optional[LiftMode] = None
device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
+ symbolic_domain_sizes: Optional[dict[str, str]] = None
def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings:
match self.device_type:
@@ -171,6 +175,70 @@ def _process_connectivity_args(
return parameters, arg_exprs
+ def _preprocess_program(
+ self,
+ program: itir.FencilDefinition,
+ offset_provider: dict[str, Connectivity | Dimension],
+ runtime_lift_mode: Optional[LiftMode] = None,
+ ) -> itir.FencilDefinition:
+ # TODO(tehrengruber): Remove `lift_mode` from the call interface. It has been implicitly
+ # added to the interface of all backends of concern, but should instead be configured in
+ # the backend itself (as it is here). Until then we respect the runtime argument and warn
+ # the user if it differs from the configured one.
+ lift_mode = runtime_lift_mode or self.lift_mode
+ if lift_mode != self.lift_mode:
+ warnings.warn(
+ f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but "
+ f"overriden to be {str(runtime_lift_mode)} at runtime."
+ )
+
+ if not self.enable_itir_transforms:
+ return program
+
+ apply_common_transforms = functools.partial(
+ pass_manager.apply_common_transforms,
+ lift_mode=lift_mode,
+ offset_provider=offset_provider,
+ # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements
+ unconditionally_collapse_tuples=True,
+ symbolic_domain_sizes=self.symbolic_domain_sizes,
+ )
+
+ new_program = apply_common_transforms(
+ program, unroll_reduce=not self.use_imperative_backend
+ )
+
+ if self.use_imperative_backend and any(
+ node.id == "neighbors"
+ for node in new_program.pre_walk_values().if_isinstance(itir.SymRef)
+ ):
+ # if we don't unroll, there may be lifts left in the itir which can't be lowered to
+ # gtfn. In this case, just retry with unrolled reductions.
+ new_program = apply_common_transforms(program, unroll_reduce=True)
+
+ return new_program
+
+ def generate_stencil_source(
+ self,
+ program: itir.FencilDefinition,
+ offset_provider: dict[str, Connectivity | Dimension],
+ column_axis: Optional[common.Dimension],
+ runtime_lift_mode: Optional[LiftMode] = None,
+ ) -> str:
+ new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode)
+ gtfn_ir = GTFN_lowering.apply(
+ new_program,
+ offset_provider=offset_provider,
+ column_axis=column_axis,
+ )
+
+ if self.use_imperative_backend:
+ gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir)
+ generated_code = GTFNIMCodegen.apply(gtfn_im_ir)
+ else:
+ generated_code = GTFNCodegen.apply(gtfn_ir)
+ return codegen.format_source("cpp", generated_code, style="LLVM")
+
def __call__(
self,
inp: stages.ProgramCall,
@@ -190,18 +258,6 @@ def __call__(
inp.kwargs["offset_provider"]
)
- # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added
- # to the interface of all (or at least all of concern) backends, but instead should be
- # configured in the backend itself (like it is here), until then we respect the argument
- # here and warn the user if it differs from the one configured.
- runtime_lift_mode = inp.kwargs.pop("lift_mode", None)
- lift_mode = runtime_lift_mode or self.lift_mode
- if runtime_lift_mode != self.lift_mode:
- warnings.warn(
- f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but "
- "overriden to be {str(runtime_lift_mode)} at runtime."
- )
-
# combine into a format that is aligned with what the backend expects
parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters
backend_arg = self._backend_type()
@@ -213,12 +269,11 @@ def __call__(
f"{', '.join(connectivity_args_expr)})({', '.join(args_expr)});"
)
decl_src = cpp_interface.render_function_declaration(function, body=decl_body)
- stencil_src = gtfn_backend.generate(
+ stencil_src = self.generate_stencil_source(
program,
- enable_itir_transforms=self.enable_itir_transforms,
- lift_mode=lift_mode,
- imperative=self.use_imperative_backend,
- **inp.kwargs,
+ inp.kwargs["offset_provider"],
+ inp.kwargs.get("column_axis", None),
+ inp.kwargs.get("lift_mode", None),
)
source_code = interface.format_source(
self._language_settings(),
diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py
index f9fa154641..27dec77ed1 100644
--- a/src/gt4py/next/program_processors/formatters/gtfn.py
+++ b/src/gt4py/next/program_processors/formatters/gtfn.py
@@ -15,10 +15,19 @@
from typing import Any
from gt4py.next.iterator import ir as itir
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.codegens.gtfn.gtfn_module import GTFNTranslationStep
from gt4py.next.program_processors.processor_interface import program_formatter
+from gt4py.next.program_processors.runners.gtfn import gtfn_executor
@program_formatter
def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str:
- return generate(program, **kwargs)
+ # TODO(tehrengruber): This is a little ugly. Revisit.
+ gtfn_translation = gtfn_executor.otf_workflow.translation
+ assert isinstance(gtfn_translation, GTFNTranslationStep)
+ return gtfn_translation.generate_stencil_source(
+ program,
+ offset_provider=kwargs.get("offset_provider", None),
+ column_axis=kwargs.get("column_axis", None),
+ runtime_lift_mode=kwargs.get("lift_mode", None),
+ )
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index cd70f8b588..54ca08fe6e 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -261,10 +261,11 @@ def build_sdfg_from_itir(
# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider, lift_mode)
- # TODO: According to Lex one should build the SDFG first in a general mannor.
- # Generalisation to a particular device should happen only at the end.
- sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu)
+ sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
sdfg = sdfg_genenerator.visit(program)
+ if sdfg is None:
+ raise RuntimeError(f"Visit failed for program {program.id}.")
+
for nested_sdfg in sdfg.all_sdfgs_recursive():
if not nested_sdfg.debuginfo:
_, frameinfo = warnings.warn(
@@ -277,6 +278,8 @@ def build_sdfg_from_itir(
end_line=frameinfo.lineno,
filename=frameinfo.filename,
)
+
+ # run DaCe transformations to simplify the SDFG
sdfg.simplify()
# run DaCe auto-optimization heuristics
@@ -287,6 +290,9 @@ def build_sdfg_from_itir(
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)
+ if on_gpu:
+ sdfg.apply_gpu_transformations()
+
return sdfg
@@ -296,7 +302,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
compiler_args = kwargs.get("compiler_args", None) # `None` will take default.
build_type = kwargs.get("build_type", "RelWithDebInfo")
on_gpu = kwargs.get("on_gpu", False)
- auto_optimize = kwargs.get("auto_optimize", False)
+ auto_optimize = kwargs.get("auto_optimize", True)
lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE)
# ITIR parameters
column_axis = kwargs.get("column_axis", None)
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
index e03e3654a2..dc194c0436 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
@@ -100,20 +100,17 @@ class ItirToSDFG(eve.NodeVisitor):
offset_provider: dict[str, Any]
node_types: dict[int, next_typing.Type]
unique_id: int
- use_gpu_storage: bool
def __init__(
self,
param_types: list[ts.TypeSpec],
offset_provider: dict[str, NeighborTableOffsetProvider],
column_axis: Optional[Dimension] = None,
- use_gpu_storage: bool = False,
):
self.param_types = param_types
self.column_axis = column_axis
self.offset_provider = offset_provider
self.storage_types = {}
- self.use_gpu_storage = use_gpu_storage
def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True):
if isinstance(type_, ts.FieldType):
@@ -124,14 +121,7 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset
else None
)
dtype = as_dace_type(type_.dtype)
- storage = (
- dace.dtypes.StorageType.GPU_Global
- if self.use_gpu_storage
- else dace.dtypes.StorageType.Default
- )
- sdfg.add_array(
- name, shape=shape, strides=strides, offset=offset, dtype=dtype, storage=storage
- )
+ sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype)
elif isinstance(type_, ts.ScalarType):
sdfg.add_symbol(name, as_dace_type(type_))
@@ -250,7 +240,6 @@ def visit_StencilClosure(
shape=array_table[name].shape,
strides=array_table[name].strides,
dtype=array_table[name].dtype,
- storage=array_table[name].storage,
transient=True,
)
closure_init_state.add_nedge(
@@ -265,7 +254,6 @@ def visit_StencilClosure(
shape=array_table[name].shape,
strides=array_table[name].strides,
dtype=array_table[name].dtype,
- storage=array_table[name].storage,
)
else:
assert isinstance(self.storage_types[name], ts.ScalarType)
diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py
index 88a8347fe4..12649bf620 100644
--- a/src/gt4py/next/type_system/type_translation.py
+++ b/src/gt4py/next/type_system/type_translation.py
@@ -184,7 +184,7 @@ def from_value(value: Any) -> ts.TypeSpec:
elif isinstance(value, common.Dimension):
symbol_type = ts.DimensionType(dim=value)
elif common.is_field(value):
- dims = list(value.__gt_dims__)
+ dims = list(value.domain.dims)
dtype = from_type_hint(value.dtype.scalar_type)
symbol_type = ts.FieldType(dims=dims, dtype=dtype)
elif isinstance(value, tuple):
diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py
index 0f7cf5d0ab..4e7ebb0c21 100644
--- a/src/gt4py/storage/cartesian/utils.py
+++ b/src/gt4py/storage/cartesian/utils.py
@@ -192,6 +192,10 @@ def cpu_copy(array: Union[np.ndarray, "cp.ndarray"]) -> np.ndarray:
def asarray(
array: FieldLike, *, device: Literal["cpu", "gpu", None] = None
) -> np.ndarray | cp.ndarray:
+ if hasattr(array, "ndarray"):
+ # extract the buffer from a gt4py.next.Field
+ # TODO(havogt): probably `Field` should provide the array interface methods when applicable
+ array = array.ndarray
if device == "gpu" or (not device and hasattr(array, "__cuda_array_interface__")):
return cp.asarray(array)
if device == "cpu" or (
diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py
index e580333bc8..8cfff12df4 100644
--- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py
+++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py
@@ -312,7 +312,7 @@ def test_symbolref_validation_for_valid_tree():
SymbolTableRootNode(
nodes=[SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo")],
)
- SymbolTableRootNode(
+ SymbolTableRootNode( # noqa: B018
nodes=[
SymbolChildNode(name="foo"),
SymbolRefChildNode(name="foo"),
diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py
index 8fa9e02cb6..0abb893dd4 100644
--- a/tests/eve_tests/unit_tests/test_datamodels.py
+++ b/tests/eve_tests/unit_tests/test_datamodels.py
@@ -15,6 +15,7 @@
from __future__ import annotations
import enum
+import numbers
import types
import typing
from typing import Set # noqa: F401 # imported but unused (used in exec() context)
@@ -1150,66 +1151,80 @@ class PartialGenericModel(datamodels.GenericDataModel, Generic[T]):
with pytest.raises(TypeError, match="'PartialGenericModel__int.value'"):
PartialGenericModel__int(value=["1"])
- def test_partial_specialization(self):
- class PartialGenericModel(datamodels.GenericDataModel, Generic[T, U]):
+ def test_partial_concretization(self):
+ class BaseGenericModel(datamodels.GenericDataModel, Generic[T, U]):
value: List[Tuple[T, U]]
- PartialGenericModel(value=[])
- PartialGenericModel(value=[("value", 3)])
- PartialGenericModel(value=[(1, "value")])
- PartialGenericModel(value=[(-1.0, "value")])
- with pytest.raises(TypeError, match="'PartialGenericModel.value'"):
- PartialGenericModel(value=1)
- with pytest.raises(TypeError, match="'PartialGenericModel.value'"):
- PartialGenericModel(value=(1, 2))
- with pytest.raises(TypeError, match="'PartialGenericModel.value'"):
- PartialGenericModel(value=[()])
- with pytest.raises(TypeError, match="'PartialGenericModel.value'"):
- PartialGenericModel(value=[(1,)])
+ assert len(BaseGenericModel.__parameters__) == 2
+
+ BaseGenericModel(value=[])
+ BaseGenericModel(value=[("value", 3)])
+ BaseGenericModel(value=[(1, "value")])
+ BaseGenericModel(value=[(-1.0, "value")])
+ with pytest.raises(TypeError, match="'BaseGenericModel.value'"):
+ BaseGenericModel(value=1)
+ with pytest.raises(TypeError, match="'BaseGenericModel.value'"):
+ BaseGenericModel(value=(1, 2))
+ with pytest.raises(TypeError, match="'BaseGenericModel.value'"):
+ BaseGenericModel(value=[()])
+ with pytest.raises(TypeError, match="'BaseGenericModel.value'"):
+ BaseGenericModel(value=[(1,)])
+
+ PartiallyConcretizedGenericModel = BaseGenericModel[int, U]
+
+ assert len(PartiallyConcretizedGenericModel.__parameters__) == 1
+
+ PartiallyConcretizedGenericModel(value=[])
+ PartiallyConcretizedGenericModel(value=[(1, 2)])
+ PartiallyConcretizedGenericModel(value=[(1, "value")])
+ PartiallyConcretizedGenericModel(value=[(1, (11, 12))])
+ with pytest.raises(TypeError, match=".value'"):
+ PartiallyConcretizedGenericModel(value=1)
+ with pytest.raises(TypeError, match=".value'"):
+ PartiallyConcretizedGenericModel(value=(1, 2))
+ with pytest.raises(TypeError, match=".value'"):
+ PartiallyConcretizedGenericModel(value=[1.0])
+ with pytest.raises(TypeError, match=".value'"):
+ PartiallyConcretizedGenericModel(value=["1"])
- print(f"{PartialGenericModel.__parameters__=}")
- print(f"{hasattr(PartialGenericModel ,'__args__')=}")
+ FullyConcretizedGenericModel = PartiallyConcretizedGenericModel[str]
- PartiallySpecializedGenericModel = PartialGenericModel[int, U]
- print(f"{PartiallySpecializedGenericModel.__datamodel_fields__=}")
- print(f"{PartiallySpecializedGenericModel.__parameters__=}")
- print(f"{PartiallySpecializedGenericModel.__args__=}")
+ assert len(FullyConcretizedGenericModel.__parameters__) == 0
- PartiallySpecializedGenericModel(value=[])
- PartiallySpecializedGenericModel(value=[(1, 2)])
- PartiallySpecializedGenericModel(value=[(1, "value")])
- PartiallySpecializedGenericModel(value=[(1, (11, 12))])
+ FullyConcretizedGenericModel(value=[])
+ FullyConcretizedGenericModel(value=[(1, "value")])
+ with pytest.raises(TypeError, match=".value'"):
+ FullyConcretizedGenericModel(value=1)
+ with pytest.raises(TypeError, match=".value'"):
+ FullyConcretizedGenericModel(value=(1, 2))
with pytest.raises(TypeError, match=".value'"):
- PartiallySpecializedGenericModel(value=1)
+ FullyConcretizedGenericModel(value=[1.0])
with pytest.raises(TypeError, match=".value'"):
- PartiallySpecializedGenericModel(value=(1, 2))
+ FullyConcretizedGenericModel(value=["1"])
with pytest.raises(TypeError, match=".value'"):
- PartiallySpecializedGenericModel(value=[1.0])
+ FullyConcretizedGenericModel(value=1)
with pytest.raises(TypeError, match=".value'"):
- PartiallySpecializedGenericModel(value=["1"])
-
- # TODO(egparedes): after fixing partial nested datamodel specialization
- # noqa: e800 FullySpecializedGenericModel = PartiallySpecializedGenericModel[str]
- # noqa: e800 print(f"{FullySpecializedGenericModel.__datamodel_fields__=}")
- # noqa: e800 print(f"{FullySpecializedGenericModel.__parameters__=}")
- # noqa: e800 print(f"{FullySpecializedGenericModel.__args__=}")
-
- # noqa: e800 FullySpecializedGenericModel(value=[])
- # noqa: e800 FullySpecializedGenericModel(value=[(1, "value")])
- # noqa: e800 with pytest.raises(TypeError, match=".value'"):
- # noqa: e800 FullySpecializedGenericModel(value=1)
- # noqa: e800 with pytest.raises(TypeError, match=".value'"):
- # noqa: e800 FullySpecializedGenericModel(value=(1, 2))
- # noqa: e800 with pytest.raises(TypeError, match=".value'"):
- # noqa: e800 FullySpecializedGenericModel(value=[1.0])
- # noqa: e800 with pytest.raises(TypeError, match=".value'"):
- # noqa: e800 FullySpecializedGenericModel(value=["1"])
- # noqa: e800 with pytest.raises(TypeError, match=".value'"):
- # noqa: e800 FullySpecializedGenericModel(value=1)
- # noqa: e800 with pytest.raises(TypeError, match=".value'"):
- # noqa: e800 FullySpecializedGenericModel(value=[(1, 2)])
- # noqa: e800 with pytest.raises(TypeError, match=".value'"):
- # noqa: e800 FullySpecializedGenericModel(value=[(1, (11, 12))])
+ FullyConcretizedGenericModel(value=[(1, 2)])
+ with pytest.raises(TypeError, match=".value'"):
+ FullyConcretizedGenericModel(value=[(1, (11, 12))])
+
+ def test_partial_concretization_with_typevar(self):
+ class PartialGenericModel(datamodels.GenericDataModel, Generic[T]):
+ a: T
+ values: List[T]
+
+ B = TypeVar("B", bound=numbers.Number)
+ PartiallyConcretizedGenericModel = PartialGenericModel[B]
+
+ PartiallyConcretizedGenericModel(a=1, values=[2, 3])
+ PartiallyConcretizedGenericModel(a=-1.32, values=[2.2, 3j])
+
+ with pytest.raises(TypeError, match=".a'"):
+ PartiallyConcretizedGenericModel(a="1", values=[2, 3])
+ with pytest.raises(TypeError, match=".values'"):
+ PartiallyConcretizedGenericModel(a=1, values=[1, "2"])
+ with pytest.raises(TypeError, match=".values'"):
+ PartiallyConcretizedGenericModel(a=1, values=(1, 2))
# Reuse sample_type_data from test_field_type_hint
@pytest.mark.parametrize(["type_hint", "valid_values", "wrong_values"], SAMPLE_TYPE_DATA)
diff --git a/tests/eve_tests/unit_tests/test_type_validation.py b/tests/eve_tests/unit_tests/test_type_validation.py
index 70ef033ff0..d9977f0d3a 100644
--- a/tests/eve_tests/unit_tests/test_type_validation.py
+++ b/tests/eve_tests/unit_tests/test_type_validation.py
@@ -28,6 +28,7 @@
)
from gt4py.eve.extended_typing import (
Any,
+ Callable,
Dict,
Final,
ForwardRef,
@@ -41,8 +42,8 @@
)
-VALIDATORS: Final = [type_val.simple_type_validator]
-FACTORIES: Final = [type_val.simple_type_validator_factory]
+VALIDATORS: Final[list[Callable]] = [type_val.simple_type_validator]
+FACTORIES: Final[list[Callable]] = [type_val.simple_type_validator_factory]
class SampleEnum(enum.Enum):
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py
new file mode 100644
index 0000000000..0de953d85f
--- /dev/null
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import numpy as np
+
+import gt4py.next as gtx
+from gt4py.next import int32
+
+from next_tests.integration_tests import cases
+from next_tests.integration_tests.cases import cartesian_case
+from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
+ fieldview_backend,
+ reduction_setup,
+)
+
+
+def test_with_bound_args(cartesian_case):
+ @gtx.field_operator
+ def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
+ if not condition:
+ scalar = 0
+ return a + scalar
+
+ @gtx.program
+ def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
+ fieldop_bound_args(a, scalar, condition, out=out)
+
+ a = cases.allocate(cartesian_case, program_bound_args, "a")()
+ scalar = int32(1)
+ ref = a + scalar
+ out = cases.allocate(cartesian_case, program_bound_args, "out")()
+
+ prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
+ cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)
+
+
+def test_with_bound_args_order_args(cartesian_case):
+ @gtx.field_operator
+ def fieldop_args(a: cases.IField, condition: bool, scalar: int32) -> cases.IField:
+ scalar = 0 if not condition else scalar
+ return a + scalar
+
+ @gtx.program(backend=cartesian_case.backend)
+ def program_args(a: cases.IField, condition: bool, scalar: int32, out: cases.IField):
+ fieldop_args(a, condition, scalar, out=out)
+
+ a = cases.allocate(cartesian_case, program_args, "a")()
+ out = cases.allocate(cartesian_case, program_args, "out")()
+
+ prog_bounds = program_args.with_bound_args(condition=True)
+ prog_bounds(a=a, scalar=int32(1), out=out, offset_provider={})
+ assert np.allclose(out.asnumpy(), a.asnumpy() + int32(1))
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
index a08931628b..70c79d7b6c 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
@@ -898,26 +898,6 @@ def test_docstring(a: cases.IField):
cases.verify(cartesian_case, test_docstring, a, inout=a, ref=a)
-def test_with_bound_args(cartesian_case):
- @gtx.field_operator
- def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
- if not condition:
- scalar = 0
- return a + a + scalar
-
- @gtx.program
- def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
- fieldop_bound_args(a, scalar, condition, out=out)
-
- a = cases.allocate(cartesian_case, program_bound_args, "a")()
- scalar = int32(1)
- ref = a + a + 1
- out = cases.allocate(cartesian_case, program_bound_args, "out")()
-
- prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
- cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)
-
-
def test_domain(cartesian_case):
@gtx.field_operator
def fieldop_domain(a: cases.IField) -> cases.IField:
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
index 698dce2b5c..d100cd380c 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
@@ -30,16 +30,6 @@
def test_external_local_field(unstructured_case):
- # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
- try:
- from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
- if unstructured_case.backend == run_dace_gpu:
- # see https://github.com/spcl/dace/pull/1442
- pytest.xfail("requires fix in dace module for cuda codegen")
- except ImportError:
- pass
-
@gtx.field_operator
def testee(
inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32]
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
index e8d0c8b163..e2434d860a 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
@@ -46,16 +46,6 @@
ids=["positive_values", "negative_values"],
)
def test_maxover_execution_(unstructured_case, strategy):
- # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
- try:
- from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
- if unstructured_case.backend == run_dace_gpu:
- # see https://github.com/spcl/dace/pull/1442
- pytest.xfail("requires fix in dace module for cuda codegen")
- except ImportError:
- pass
-
if unstructured_case.backend in [
gtfn.run_gtfn,
gtfn.run_gtfn_gpu,
@@ -79,16 +69,6 @@ def testee(edge_f: cases.EField) -> cases.VField:
@pytest.mark.uses_unstructured_shift
def test_minover_execution(unstructured_case):
- # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
- try:
- from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
- if unstructured_case.backend == run_dace_gpu:
- # see https://github.com/spcl/dace/pull/1442
- pytest.xfail("requires fix in dace module for cuda codegen")
- except ImportError:
- pass
-
@gtx.field_operator
def minover(edge_f: cases.EField) -> cases.VField:
out = min_over(edge_f(V2E), axis=V2EDim)
@@ -102,16 +82,6 @@ def minover(edge_f: cases.EField) -> cases.VField:
@pytest.mark.uses_unstructured_shift
def test_reduction_execution(unstructured_case):
- # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
- try:
- from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
- if unstructured_case.backend == run_dace_gpu:
- # see https://github.com/spcl/dace/pull/1442
- pytest.xfail("requires fix in dace module for cuda codegen")
- except ImportError:
- pass
-
@gtx.field_operator
def reduction(edge_f: cases.EField) -> cases.VField:
return neighbor_sum(edge_f(V2E), axis=V2EDim)
@@ -150,16 +120,6 @@ def fencil(edge_f: cases.EField, out: cases.VField):
@pytest.mark.uses_unstructured_shift
def test_reduction_with_common_expression(unstructured_case):
- # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
- try:
- from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
- if unstructured_case.backend == run_dace_gpu:
- # see https://github.com/spcl/dace/pull/1442
- pytest.xfail("requires fix in dace module for cuda codegen")
- except ImportError:
- pass
-
@gtx.field_operator
def testee(flux: cases.EField) -> cases.VField:
return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim)
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py
index c86881ab7c..938c69fb52 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py
@@ -20,6 +20,7 @@
import pytest
import gt4py.next as gtx
+from gt4py.next import errors
from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import IDim, Ioff, JDim, cartesian_case, fieldview_backend
@@ -222,7 +223,7 @@ def test_wrong_argument_type(cartesian_case, copy_program_def):
inp = cartesian_case.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],)))
out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))()
- with pytest.raises(TypeError) as exc_info:
+ with pytest.raises(errors.DSLError) as exc_info:
# program is defined on Field[[IDim], ...], but we call with
# Field[[JDim], ...]
copy_program(inp, out, offset_provider={})
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
new file mode 100644
index 0000000000..788081b81e
--- /dev/null
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
@@ -0,0 +1,119 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pytest
+from numpy import int32, int64
+
+from gt4py import next as gtx
+from gt4py.next import common
+from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
+from gt4py.next.program_processors import otf_compile_executor
+from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries
+
+from next_tests.integration_tests import cases
+from next_tests.integration_tests.cases import (
+ E2V,
+ Case,
+ KDim,
+ Vertex,
+ cartesian_case,
+ unstructured_case,
+)
+from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
+ reduction_setup,
+)
+from next_tests.toy_connectivity import Cell, Edge
+
+
+@pytest.fixture
+def run_gtfn_with_temporaries_and_symbolic_sizes():
+ return otf_compile_executor.OTFBackend(
+ executor=otf_compile_executor.OTFCompileExecutor(
+ name="run_gtfn_with_temporaries_and_sizes",
+ otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace(
+ translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace(
+ symbolic_domain_sizes={
+ "Cell": "num_cells",
+ "Edge": "num_edges",
+ "Vertex": "num_vertices",
+ },
+ ),
+ ),
+ ),
+ allocator=run_gtfn_with_temporaries.allocator,
+ )
+
+
+@pytest.fixture
+def testee():
+ @gtx.field_operator
+ def testee_op(a: cases.VField) -> cases.EField:
+ amul = a * 2
+ return amul(E2V[0]) + amul(E2V[1])
+
+ @gtx.program
+ def prog(
+ a: cases.VField,
+ out: cases.EField,
+ num_vertices: int32,
+ num_edges: int64,
+ num_cells: int32,
+ ):
+ testee_op(a, out=out)
+
+ return prog
+
+
+def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, reduction_setup):
+ unstructured_case = Case(
+ run_gtfn_with_temporaries_and_symbolic_sizes,
+ offset_provider=reduction_setup.offset_provider,
+ default_sizes={
+ Vertex: reduction_setup.num_vertices,
+ Edge: reduction_setup.num_edges,
+ Cell: reduction_setup.num_cells,
+ KDim: reduction_setup.k_levels,
+ },
+ grid_type=common.GridType.UNSTRUCTURED,
+ )
+
+ a = cases.allocate(unstructured_case, testee, "a")()
+ out = cases.allocate(unstructured_case, testee, "out")()
+
+ first_nbs, second_nbs = (reduction_setup.offset_provider["E2V"].table[:, i] for i in [0, 1])
+ ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs]
+
+ cases.verify(
+ unstructured_case,
+ testee,
+ a,
+ out,
+ reduction_setup.num_vertices,
+ reduction_setup.num_edges,
+ reduction_setup.num_cells,
+ inout=out,
+ ref=ref,
+ )
+
+
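+ # When temporaries are forced, the size symbols must appear as parameters of
+ # the transformed fencil.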
+def test_temporary_symbols(testee, reduction_setup):
+ itir_with_tmp = apply_common_transforms(
+ testee.itir,
+ lift_mode=LiftMode.FORCE_TEMPORARIES,
+ offset_provider=reduction_setup.offset_provider,
+ )
+
+ params = ["num_vertices", "num_edges", "num_cells"]
+ for param in params:
+ assert any(param == str(p) for p in itir_with_tmp.fencil.params)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
index e851e7b130..5af4605988 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
@@ -18,7 +18,7 @@
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fundef, offset
from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
@fundef
@@ -69,7 +69,9 @@ def lap_fencil(i_size, j_size, k_size, i_off, j_off, k_off, out, inp):
output_file = sys.argv[1]
prog = trace_fencil_definition(lap_fencil, [None] * 8, use_arg_types=False)
- generated_code = generate(prog, offset_provider={"i": IDim, "j": JDim})
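+ # generate the stencil source through the translation step of the gtfn
+ # workflow (replacing the former standalone gtfn_backend.generate helper)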
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider={"i": IDim, "j": JDim}, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
index 33c7d5baa7..3e8b88ac66 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
@@ -18,7 +18,7 @@
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fundef
from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
IDim = gtx.Dimension("IDim")
@@ -48,7 +48,9 @@ def copy_fencil(isize, jsize, ksize, inp, out):
output_file = sys.argv[1]
prog = trace_fencil_definition(copy_fencil, [None] * 5, use_arg_types=False)
- generated_code = generate(prog, offset_provider={})
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider={}, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
index f7472d4ac3..fdc57449ee 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
@@ -18,7 +18,7 @@
import gt4py.next as gtx
from gt4py.next import Field, field_operator, program
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
IDim = gtx.Dimension("IDim")
@@ -47,7 +47,9 @@ def copy_program(
output_file = sys.argv[1]
prog = copy_program.itir
- generated_code = generate(prog, offset_provider={})
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider={}, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
index 1dfd74baca..abc3755dca 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
@@ -19,7 +19,7 @@
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fundef, offset
from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn, run_gtfn_imperative
E2V = offset("E2V")
@@ -92,13 +92,20 @@ def mapped_index(_, __) -> int:
output_file = sys.argv[1]
imperative = sys.argv[2].lower() == "true"
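+ # pick the gtfn variant that emits the requested code style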
+ if imperative:
+ backend = run_gtfn_imperative
+ else:
+ backend = run_gtfn
+
# prog = trace(zavgS_fencil, [None] * 4) # TODO allow generating of 2 fencils
prog = trace_fencil_definition(nabla_fencil, [None] * 7, use_arg_types=False)
offset_provider = {
"V2E": DummyConnectivity(max_neighbors=6, has_skip_values=True),
"E2V": DummyConnectivity(max_neighbors=2, has_skip_values=False),
}
- generated_code = generate(prog, offset_provider=offset_provider, imperative=imperative)
+ generated_code = backend.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider=offset_provider, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
index 578a19faab..9755774fd0 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
@@ -19,7 +19,7 @@
from gt4py.next.iterator.runtime import closure, fundef
from gt4py.next.iterator.tracing import trace_fencil_definition
from gt4py.next.iterator.transforms import LiftMode
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
IDim = gtx.Dimension("IDim")
@@ -67,10 +67,10 @@ def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x):
prog = trace_fencil_definition(tridiagonal_solve_fencil, [None] * 8, use_arg_types=False)
offset_provider = {"I": gtx.Dimension("IDim"), "J": gtx.Dimension("JDim")}
- generated_code = generate(
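+ # the per-call lift-mode override is now spelled runtime_lift_mode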
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
prog,
offset_provider=offset_provider,
- lift_mode=LiftMode.SIMPLE_HEURISTIC,
+ runtime_lift_mode=LiftMode.SIMPLE_HEURISTIC,
column_axis=KDim,
)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
index 130f6bd29c..5bd255f80f 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
@@ -12,7 +12,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
-from dataclasses import dataclass
+import dataclasses
import numpy as np
import pytest
@@ -201,22 +201,26 @@ def test_setup(fieldview_backend):
grid_type=common.GridType.UNSTRUCTURED,
)
- @dataclass(frozen=True)
+ @dataclasses.dataclass(frozen=True)
class setup:
- case: cases.Case = test_case
- cell_size = case.default_sizes[Cell]
- k_size = case.default_sizes[KDim]
- z_alpha = case.as_field(
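+ # Python 3.11 rejects unhashable values as plain dataclass field defaults,
+ # hence (presumably) the default_factory indirection for `case`.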
+ case: cases.Case = dataclasses.field(default_factory=lambda: test_case)
+ cell_size = test_case.default_sizes[Cell]
+ k_size = test_case.default_sizes[KDim]
+ z_alpha = test_case.as_field(
[Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1))
)
- z_beta = case.as_field(
+ z_beta = test_case.as_field(
+ [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))
+ )
+ z_q = test_case.as_field(
+ [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))
+ )
+ w = test_case.as_field(
[Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))
)
- z_q = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)))
- w = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)))
z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray)
- dummy = case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool))
- z_q_out = case.as_field([Cell, KDim], np.zeros((cell_size, k_size)))
+ dummy = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool))
+ z_q_out = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size)))
return setup()
diff --git a/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py b/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py
index 052f272d22..ea1cdb82a6 100644
--- a/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py
+++ b/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py
@@ -108,7 +108,10 @@ def test_unpacking_swap():
lines = ast.unparse(ssa_ast).split("\n")
assert lines[0] == f"a{SEP}0 = 5"
assert lines[1] == f"b{SEP}0 = 1"
- assert lines[2] == f"(b{SEP}1, a{SEP}1) = (a{SEP}0, b{SEP}0)"
+ assert lines[2] in [
+ f"(b{SEP}1, a{SEP}1) = (a{SEP}0, b{SEP}0)",
+ f"b{SEP}1, a{SEP}1 = (a{SEP}0, b{SEP}0)",
+ ] # unparse produces different parentheses in different Python versions
def test_annotated_assign():
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
index 86c3c98c62..5c2802f90c 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
@@ -323,7 +323,7 @@ def test_update_cartesian_domains():
for a, s in (("JDim", "j"), ("KDim", "k"))
],
)
- actual = update_domains(testee, {"I": gtx.Dimension("IDim")})
+ actual = update_domains(testee, {"I": gtx.Dimension("IDim")}, symbolic_sizes=None)
assert actual == expected
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py
new file mode 100644
index 0000000000..ae23becb4c
--- /dev/null
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py
@@ -0,0 +1,161 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pytest
+
+from gt4py.eve import SymbolRef
+from gt4py.next.iterator import ir
+from gt4py.next.iterator.ir_utils import ir_makers as im
+from gt4py.next.iterator.transforms.power_unrolling import PowerUnrolling
+
+
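+ # PowerUnrolling rewrites integer `power` calls into multiplication chains,
+ # squaring along the exponent's binary digits and binding reused
+ # subexpressions to let-variables named power_<n>.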
+def test_power_unrolling_zero():
+ pytest.xfail(
+ "Not implementeds we don't have an easy way to determine the type of the one literal (type inference is to expensive)."
+ )
+ testee = im.call("power")("x", 0)
+ expected = im.literal_from_value(1)
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_one():
+ testee = im.call("power")("x", 1)
+ expected = ir.SymRef(id=SymbolRef("x"))
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_two():
+ testee = im.call("power")("x", 2)
+ expected = im.multiplies_("x", "x")
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_two_x_plus_two():
+ testee = im.call("power")(im.plus("x", 2), 2)
+ expected = im.let("power_1", im.plus("x", 2))(
+ im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2")
+ )
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_two_x_plus_one_times_three():
+ testee = im.call("power")(im.multiplies_(im.plus("x", 1), 3), 2)
+ expected = im.let("power_1", im.multiplies_(im.plus("x", 1), 3))(
+ im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2")
+ )
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_three():
+ testee = im.call("power")("x", 3)
+ expected = im.multiplies_(im.multiplies_("x", "x"), "x")
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_four():
+ testee = im.call("power")("x", 4)
+ expected = im.let("power_2", im.multiplies_("x", "x"))(im.multiplies_("power_2", "power_2"))
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
+def test_power_unrolling_five():
+ testee = im.call("power")("x", 5)
+ expected = im.let("power_2", im.multiplies_("x", "x"))(
+ im.multiplies_(im.multiplies_("power_2", "power_2"), "x")
+ )
+
+ actual = PowerUnrolling.apply(testee)
+ assert actual == expected
+
+
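+ # Exponents above max_unroll are left as `power` calls.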
+def test_power_unrolling_seven():
+ testee = im.call("power")("x", 7)
+ expected = im.call("power")("x", 7)
+
+ actual = PowerUnrolling.apply(testee, max_unroll=5)
+ assert actual == expected
+
+
+def test_power_unrolling_seven_unrolled():
+ testee = im.call("power")("x", 7)
+ expected = im.let("power_2", im.multiplies_("x", "x"))(
+ im.multiplies_(im.multiplies_(im.multiplies_("power_2", "power_2"), "power_2"), "x")
+ )
+
+ actual = PowerUnrolling.apply(testee, max_unroll=7)
+ assert actual == expected
+
+
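+ # 7 = 4 + 2 + 1: the non-trivial base is bound once, squared twice, and the
+ # three partial powers multiplied together.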
+def test_power_unrolling_seven_x_plus_one_unrolled():
+ testee = im.call("power")(im.plus("x", 1), 7)
+ expected = im.let("power_1", im.plus("x", 1))(
+ im.let("power_2", im.multiplies_("power_1", "power_1"))(
+ im.let("power_4", im.multiplies_("power_2", "power_2"))(
+ im.multiplies_(im.multiplies_("power_4", "power_2"), "power_1")
+ )
+ )
+ )
+
+ actual = PowerUnrolling.apply(testee, max_unroll=7)
+ assert actual == expected
+
+
+def test_power_unrolling_eight():
+ testee = im.call("power")("x", 8)
+ expected = im.call("power")("x", 8)
+
+ actual = PowerUnrolling.apply(testee, max_unroll=5)
+ assert actual == expected
+
+
+def test_power_unrolling_eight_unrolled():
+ testee = im.call("power")("x", 8)
+ expected = im.let("power_2", im.multiplies_("x", "x"))(
+ im.let("power_4", im.multiplies_("power_2", "power_2"))(
+ im.multiplies_("power_4", "power_4")
+ )
+ )
+
+ actual = PowerUnrolling.apply(testee, max_unroll=8)
+ assert actual == expected
+
+
+def test_power_unrolling_eight_x_plus_one_unrolled():
+ testee = im.call("power")(im.plus("x", 1), 8)
+ expected = im.let("power_1", im.plus("x", 1))(
+ im.let("power_2", im.multiplies_("power_1", "power_1"))(
+ im.let("power_4", im.multiplies_("power_2", "power_2"))(
+ im.let("power_8", im.multiplies_("power_4", "power_4"))("power_8")
+ )
+ )
+ )
+
+ actual = PowerUnrolling.apply(testee, max_unroll=8)
+ assert actual == expected
diff --git a/tox.ini b/tox.ini
index 44dc912c8a..817f721f71 100644
--- a/tox.ini
+++ b/tox.ini
@@ -11,21 +11,24 @@ envlist =
# docs
labels =
test-cartesian-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, \
- cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu
+ cartesian-py311-internal-cpu, cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \
+ cartesian-py311-dace-cpu
- test-eve-cpu = eve-py38, eve-py39, eve-py310
+ test-eve-cpu = eve-py38, eve-py39, eve-py310, eve-py311
- test-next-cpu = next-py310-nomesh, next-py310-atlas
+ test-next-cpu = next-py310-nomesh, next-py311-nomesh, next-py310-atlas, next-py311-atlas
test-storage-cpu = storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, \
- storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu
+ storage-py311-internal-cpu, storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, \
+ storage-py311-dace-cpu
test-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, \
- cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \
- eve-py38, eve-py39, eve-py310, \
- next-py310-nomesh, next-py310-atlas, \
- storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, \
- storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu
+ cartesian-py311-internal-cpu, cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \
+ cartesian-py311-dace-cpu, \
+ eve-py38, eve-py39, eve-py310, eve-py311, \
+ next-py310-nomesh, next-py311-nomesh, next-py310-atlas, next-py311-atlas, \
+ storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, storage-py311-internal-cpu, \
+ storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu
[testenv]
deps = -r {tox_root}{/}{env:ENV_REQUIREMENTS_FILE:requirements-dev.txt}
@@ -44,7 +47,7 @@ pass_env = NUM_PROCESSES
set_env =
PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning}
-[testenv:cartesian-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}]
+[testenv:cartesian-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}]
description = Run 'gt4py.cartesian' tests
pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH, CXX, CC, OPENMP_CPPFLAGS, OPENMP_LDFLAGS, PIP_USER, PYTHONUSERBASE
allowlist_externals =
@@ -65,13 +68,13 @@ commands =
; coverage json --rcfile=setup.cfg
; coverage html --rcfile=setup.cfg --show-contexts
-[testenv:eve-py{38,39,310}]
+[testenv:eve-py{38,39,310,311}]
description = Run 'gt4py.eve' tests
commands =
python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} {posargs} tests{/}eve_tests
python -m pytest --doctest-modules src{/}gt4py{/}eve
-[testenv:next-py{310}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}]
+[testenv:next-py{310,311}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}]
description = Run 'gt4py.next' tests
pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH
deps =
@@ -87,14 +90,14 @@ commands =
# atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist
pytest --doctest-modules src{/}gt4py{/}next
-[testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}]
+[testenv:storage-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}]
description = Run 'gt4py.storage' tests
commands =
cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_gpu" {posargs} tests{/}storage_tests
{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_gpu" {posargs} tests{/}storage_tests
#pytest doctest-modules {posargs} src{/}gt4py{/}storage
-[testenv:linters-py{38,39,310}]
+[testenv:linters-py{38,39,310,311}]
description = Run linters
commands =
flake8 .{/}src
@@ -134,11 +137,13 @@ description =
py38: Update requirements for testing a specific python version
py39: Update requirements for testing a specific python version
py310: Update requirements for testing a specific python version
+ py311: Update requirements for testing a specific python version
base_python =
common: py38
py38: py38
py39: py39
py310: py310
+ py311: py311
deps =
cogapp>=3.3
pip-tools>=6.10
@@ -178,7 +183,7 @@ commands =
# Run cog to update .pre-commit-config.yaml with new versions
common: cog -r -P .pre-commit-config.yaml
-[testenv:dev-py{38,39,310}{-atlas,}]
+[testenv:dev-py{38,39,310,311}{-atlas,}]
description = Initialize development environment for gt4py
deps =
-r {tox_root}{/}requirements-dev.txt