diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml index 77ba39a361..8631390dbb 100644 --- a/.github/workflows/daily-ci.yml +++ b/.github/workflows/daily-ci.yml @@ -14,7 +14,7 @@ jobs: daily-ci: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] tox-module-factor: ["cartesian", "eve", "next", "storage"] os: ["ubuntu-latest"] requirements-file: ["requirements-dev.txt", "min-requirements-test.txt", "min-extra-requirements-test.txt"] diff --git a/.github/workflows/test-cartesian-fallback.yml b/.github/workflows/test-cartesian-fallback.yml index b2eaead47a..7e9a948e9c 100644 --- a/.github/workflows/test-cartesian-fallback.yml +++ b/.github/workflows/test-cartesian-fallback.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] steps: diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml index 2c2b97aaa6..ebdc4ce749 100644 --- a/.github/workflows/test-cartesian.yml +++ b/.github/workflows/test-cartesian.yml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/test-eve-fallback.yml b/.github/workflows/test-eve-fallback.yml index 93dc308a53..fd7ab5452c 100644 --- a/.github/workflows/test-eve-fallback.yml +++ b/.github/workflows/test-eve-fallback.yml @@ -17,7 +17,7 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] os: ["ubuntu-latest"] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test-eve.yml b/.github/workflows/test-eve.yml index 1322c573db..222b825f38 100644 --- a/.github/workflows/test-eve.yml +++ b/.github/workflows/test-eve.yml 
@@ -20,7 +20,7 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] os: ["ubuntu-latest"] fail-fast: false @@ -68,4 +68,3 @@ jobs: # with: # name: info-py${{ matrix.python-version }}-${{ matrix.os }} # path: info.txt - diff --git a/.github/workflows/test-next-fallback.yml b/.github/workflows/test-next-fallback.yml index 8490a3e393..bdcc061db0 100644 --- a/.github/workflows/test-next-fallback.yml +++ b/.github/workflows/test-next-fallback.yml @@ -15,7 +15,7 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.10", "3.11"] tox-env-factor: ["nomesh", "atlas"] os: ["ubuntu-latest"] diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml index 52f8c25386..4282a22da6 100644 --- a/.github/workflows/test-next.yml +++ b/.github/workflows/test-next.yml @@ -18,7 +18,7 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.10", "3.11"] tox-env-factor: ["nomesh", "atlas"] os: ["ubuntu-latest"] fail-fast: false diff --git a/.github/workflows/test-storage-fallback.yml b/.github/workflows/test-storage-fallback.yml index 0cbc735564..99e4923de8 100644 --- a/.github/workflows/test-storage-fallback.yml +++ b/.github/workflows/test-storage-fallback.yml @@ -18,7 +18,7 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] os: ["ubuntu-latest"] diff --git a/.github/workflows/test-storage.yml b/.github/workflows/test-storage.yml index 1133353f30..34841ed71c 100644 --- a/.github/workflows/test-storage.yml +++ b/.github/workflows/test-storage.yml @@ -21,7 +21,7 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] os: ["ubuntu-latest"] fail-fast: false @@ -70,4 +70,3 @@ jobs: # with: # name: info-py${{ 
matrix.python-version }}-${{ matrix.os }} # path: info.txt - diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1092fafd0..d9cfa0ff48 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: ## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '23.9.1' # version from constraints.txt + rev: '23.11.0' # version from constraints.txt ##[[[end]]] hooks: - id: black @@ -73,7 +73,7 @@ repos: ## version = re.search('isort==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '5.12.0' # version from constraints.txt + rev: '5.13.0' # version from constraints.txt ##[[[end]]] hooks: - id: isort @@ -97,14 +97,14 @@ repos: ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) ##]]] - darglint==1.8.1 - - flake8-bugbear==23.9.16 - - flake8-builtins==2.1.0 + - flake8-bugbear==23.12.2 + - flake8-builtins==2.2.0 - flake8-debugger==4.1.2 - flake8-docstrings==1.7.0 - flake8-eradicate==1.5.0 - flake8-mutable==1.2.0 - flake8-pyproject==1.2.3 - - pygments==2.16.1 + - pygments==2.17.2 ##[[[end]]] # - flake8-rst-docstrings # Disabled for now due to random false positives exclude: | @@ -146,9 +146,9 @@ repos: ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"#========= FROM constraints.txt: v{version} =========") ##]]] - #========= FROM constraints.txt: v1.5.1 ========= + #========= FROM constraints.txt: v1.7.1 ========= ##[[[end]]] - rev: v1.5.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) + rev: v1.7.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) hooks: - id: mypy additional_dependencies: # versions from constraints.txt @@ -162,26 +162,26 @@ repos: ##]]] - astunparse==1.6.3 - attrs==23.1.0 - - black==23.9.1 - - boltons==23.0.0 + - black==23.11.0 + 
- boltons==23.1.1 - cached-property==1.5.2 - click==8.1.7 - - cmake==3.27.5 + - cmake==3.27.9 - cytoolz==0.12.2 - - deepdiff==6.5.0 + - deepdiff==6.7.1 - devtools==0.12.2 - - frozendict==2.3.8 + - frozendict==2.3.10 - gridtools-cpp==2.3.1 - - importlib-resources==6.0.1 + - importlib-resources==6.1.1 - jinja2==3.1.2 - - lark==1.1.7 - - mako==1.2.4 - - nanobind==1.5.2 - - ninja==1.11.1 + - lark==1.1.8 + - mako==1.3.0 + - nanobind==1.8.0 + - ninja==1.11.1.1 - numpy==1.24.4 - - packaging==23.1 + - packaging==23.2 - pybind11==2.11.1 - - setuptools==68.2.2 + - setuptools==69.0.2 - tabulate==0.9.0 - typing-extensions==4.5.0 - xxhash==3.0.0 diff --git a/AUTHORS.md b/AUTHORS.md index 89aafb9971..6c76e5759e 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -9,6 +9,7 @@ - Deconinck, Florian. SSAI/NASA-GSFC - Ehrengruber, Till. ETH Zurich - CSCS - Elbert, Oliver D. NOAA-GFDL +- Faghih-Naini, Sara. ECMWF - Farabullini, Nicoletta. ETH Zurich - C2SM - George, Rhea. Allen Institute for AI - González Paredes, Enrique. 
ETH Zurich - CSCS diff --git a/constraints.txt b/constraints.txt index b334851af1..81abd64c6e 100644 --- a/constraints.txt +++ b/constraints.txt @@ -6,124 +6,136 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.0 # via devtools +asttokens==2.4.1 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.12.1 # via sphinx -black==23.9.1 # via gt4py (pyproject.toml) -blinker==1.6.2 # via flask -boltons==23.0.0 # via gt4py (pyproject.toml) +babel==2.13.1 # via sphinx +black==23.11.0 # via gt4py (pyproject.toml) +blinker==1.7.0 # via flask +boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.3.1 # via tox -certifi==2023.7.22 # via requests -cffi==1.15.1 # via cryptography +cachetools==5.3.2 # via tox +cerberus==1.3.5 # via plette +certifi==2023.11.17 # via requests +cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.2.0 # via requests -clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.3.2 # via requests +clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.5 # via gt4py (pyproject.toml) +cmake==3.27.9 # via dace, gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage==7.3.1 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis +coverage==7.3.2 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) -dace==0.14.4 # via gt4py (pyproject.toml) +dace==0.15.1 # via gt4py (pyproject.toml) 
darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.5.0 # via gt4py (pyproject.toml) +deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via virtualenv -docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme +distlib==0.3.7 # via requirementslib, virtualenv +distro==1.8.0 # via scikit-build +docopt==0.6.2 # via pipreqs +docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.1.3 # via hypothesis, pytest +exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==1.2.0 # via devtools +executing==2.0.1 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.6.1 # via factory-boy -fastjsonschema==2.18.0 # via nbformat -filelock==3.12.4 # via tox, virtualenv +faker==20.1.0 # via factory-boy +fastjsonschema==2.19.0 # via nbformat +filelock==3.13.1 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.9.16 # via -r requirements-dev.in -flake8-builtins==2.1.0 # via -r requirements-dev.in +flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.3 # via dace -frozendict==2.3.8 # via gt4py (pyproject.toml) +flask==3.0.0 # via dace +fparser==0.1.3 # via dace +frozendict==2.3.10 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) 
-hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.29 # via pre-commit -idna==3.4 # via requests +hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.33 # via pre-commit +idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via build, flask, sphinx -importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.0 # via build, flask, fparser, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.12.0 # via -r requirements-dev.in +isort==5.13.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.19.0 # via nbformat -jsonschema-specifications==2023.7.1 # via jsonschema -jupyter-core==5.3.1 # via nbformat -jupytext==1.15.2 # via -r requirements-dev.in -lark==1.1.7 # via gt4py (pyproject.toml) -mako==1.2.4 # via gt4py (pyproject.toml) +jsonschema==4.20.0 # via nbformat +jsonschema-specifications==2023.11.2 # via jsonschema +jupyter-core==5.5.0 # via nbformat +jupytext==1.16.0 # via -r requirements-dev.in +lark==1.1.8 # via gt4py (pyproject.toml) +mako==1.3.0 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.3 # via jinja2, mako, werkzeug mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.1 # via -r requirements-dev.in +mypy==1.7.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.2 # via gt4py (pyproject.toml) +nanobind==1.8.0 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace -ninja==1.11.1 # via gt4py (pyproject.toml) +ninja==1.11.1.1 # via gt4py 
(pyproject.toml) nodeenv==1.8.0 # via pre-commit numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox -pathspec==0.11.2 # via black +packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +pathspec==0.12.1 # via black +pep517==0.13.1 # via requirementslib +pip-api==0.0.30 # via isort pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.0 # via -r requirements-dev.in +pipdeptree==2.13.1 # via -r requirements-dev.in +pipreqs==0.4.13 # via isort pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv +plette==0.4.4 # via requirementslib pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.4.0 # via -r requirements-dev.in -psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist +pre-commit==3.5.0 # via -r requirements-dev.in +psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) -pycodestyle==2.11.0 # via flake8, flake8-debugger +pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi +pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r 
requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in -pytest-factoryboy==2.5.1 # via -r requirements-dev.in -pytest-xdist==3.3.1 # via -r requirements-dev.in +pytest-factoryboy==2.6.0 # via -r requirements-dev.in +pytest-xdist==3.5.0 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.2 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, sphinx +referencing==0.32.0 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, requirementslib, sphinx, yarg +requirementslib==3.0.0 # via isort restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.10.3 # via jsonschema, referencing -ruff==0.0.290 # via -r requirements-dev.in +rpds-py==0.13.2 # via jsonschema, referencing +ruff==0.1.7 # via -r requirements-dev.in +scikit-build==0.17.6 # via dace +setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in +sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sympy==1.12 # via dace, gt4py (pyproject.toml) +sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox +tomli==2.0.1 # via -r requirements-dev.in, black, 
build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox +tomlkit==0.12.3 # via plette, requirementslib toolz==0.12.0 # via cytoolz -tox==4.11.3 # via -r requirements-dev.in -traitlets==5.10.0 # via jupyter-core, nbformat +tox==4.11.4 # via -r requirements-dev.in +traitlets==5.14.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all types-backports==0.1.3 # via types-all types-backports-abc==0.5.2 # via types-all -types-bleach==6.0.0.4 # via types-all +types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all -types-cachetools==5.3.0.6 # via types-all +types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.15 # via types-jack-client +types-cffi==1.16.0.0 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.5 # via types-all +types-click-spinner==0.1.13.6 # via types-all types-colorama==0.4.15.12 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==1.4.0.1 # via types-all +types-croniter==2.0.0.0 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all types-dateparser==1.1.4.10 # via types-all @@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.9 # via types-all +types-jack-client==0.5.10.10 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.4.2.10 # via types-all +types-markdown==3.5.0.3 # via types-all 
types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.2 # via types-all +types-mock==5.1.0.3 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.0 # via types-all, types-pysftp +types-paramiko==3.3.0.2 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.3 # via types-all +types-pillow==10.1.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.1 # via types-all +types-protobuf==4.24.0.4 # via types-all types-pyaudio==0.2.16.7 # via types-all types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.2.0.2 # via types-redis +types-pyopenssl==23.3.0.0 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.0 # via types-all, types-tzlocal +types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all -types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.6 # via types-all -types-requests==2.31.0.2 # via types-all +types-pyyaml==6.0.12.12 # via types-all +types-redis==4.6.0.11 # via types-all +types-requests==2.31.0.10 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.2.0.0 # via types-cffi +types-setuptools==69.0.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 
# via types-all @@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all -types-tzlocal==5.0.1.1 # via types-all +types-tzlocal==5.1.0.1 # via types-all types-ujson==5.8.0.1 # via types-all -types-urllib3==1.26.25.14 # via types-requests types-waitress==2.1.4.9 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy -urllib3==2.0.4 # via requests -virtualenv==20.24.5 # via pre-commit, tox -websockets==11.0.3 # via dace -werkzeug==2.3.7 # via flask -wheel==0.41.2 # via astunparse, pip-tools +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +urllib3==2.1.0 # via requests, types-requests +virtualenv==20.25.0 # via pre-commit, tox +websockets==12.0 # via dace +werkzeug==3.0.1 # via flask +wheel==0.42.0 # via astunparse, pip-tools, scikit-build xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.16.2 # via importlib-metadata, importlib-resources +yarg==0.1.9 # via pipreqs +zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.2.1 # via pip-tools -setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools +pip==23.3.1 # via pip-api, pip-tools, requirementslib +setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm diff --git a/examples/lap_cartesian_vs_next.ipynb b/examples/lap_cartesian_vs_next.ipynb new file mode 100644 index 0000000000..cb80122570 --- /dev/null +++ b/examples/lap_cartesian_vs_next.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "GT4Py - GridTools for Python\n", + "\n", + "Copyright (c) 2014-2023, ETH Zurich\n", 
+ "All rights reserved.\n", + "\n", + "This file is part of the GT4Py project and the GridTools framework.\n", + "GT4Py is free software: you can redistribute it and/or modify it under\n", + "the terms of the GNU General Public License as published by the\n", + "Free Software Foundation, either version 3 of the License, or any later\n", + "version. See the LICENSE.txt file at the top-level directory of this\n", + "distribution for a copy of the license or check <https://www.gnu.org/licenses/>.\n", + "\n", + "SPDX-License-Identifier: GPL-3.0-or-later" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Demonstrates gt4py.cartesian with gt4py.next compatibility" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "nx = 32\n", + "ny = 32\n", + "nz = 1\n", + "dtype = np.float64" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Storages\n", + "--\n", + "\n", + "We create fields using the gt4py.next constructors. These fields are compatible with gt4py.cartesian when we use \"I\", \"J\", \"K\" as the dimension names." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import gt4py.next as gtx\n", + "\n", + "allocator = gtx.itir_embedded # should match the executor\n", + "# allocator = gtx.gtfn_cpu\n", + "# allocator = gtx.gtfn_gpu\n", + "\n", + "# Note: for gt4py.next, names don't matter, for gt4py.cartesian they have to be \"I\", \"J\", \"K\"\n", + "I = gtx.Dimension(\"I\")\n", + "J = gtx.Dimension(\"J\")\n", + "K = gtx.Dimension(\"K\", kind=gtx.DimensionKind.VERTICAL)\n", + "\n", + "domain = gtx.domain({I: nx, J: ny, K: nz})\n", + "\n", + "inp = gtx.as_field(domain, np.fromfunction(lambda x, y, z: x**2+y**2, shape=(nx, ny, nz)), dtype, allocator=allocator)\n", + "out_cartesian = gtx.zeros(domain, dtype, allocator=allocator)\n", + "out_next = gtx.zeros(domain, dtype, allocator=allocator)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "gt4py.cartesian\n", + "--" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import gt4py.cartesian.gtscript as gtscript\n", + "\n", + "cartesian_backend = \"numpy\"\n", + "# cartesian_backend = \"gt:cpu_ifirst\"\n", + "# cartesian_backend = \"gt:gpu\"\n", + "\n", + "@gtscript.stencil(backend=cartesian_backend)\n", + "def lap_cartesian(\n", + " inp: gtscript.Field[dtype],\n", + " out: gtscript.Field[dtype],\n", + "):\n", + " with computation(PARALLEL), interval(...):\n", + " out = -4.0 * inp[0, 0, 0] + inp[-1, 0, 0] + inp[1, 0, 0] + inp[0, -1, 0] + inp[0, 1, 0]\n", + "\n", + "lap_cartesian(inp=inp, out=out_cartesian, origin=(1, 1, 0), domain=(nx-2, ny-2, nz))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from gt4py.next import Field\n", + "\n", + "next_backend = 
gtx.itir_embedded\n", + "# next_backend = gtx.gtfn_cpu\n", + "# next_backend = gtx.gtfn_gpu\n", + "\n", + "Ioff = gtx.FieldOffset(\"I\", source=I, target=(I,))\n", + "Joff = gtx.FieldOffset(\"J\", source=J, target=(J,))\n", + "\n", + "@gtx.field_operator\n", + "def lap_next(inp: Field[[I, J, K], dtype]) -> Field[[I, J, K], dtype]:\n", + " return -4.0 * inp + inp(Ioff[-1]) + inp(Ioff[1]) + inp(Joff[-1]) + inp(Joff[1])\n", + "\n", + "@gtx.program(backend=next_backend)\n", + "def lap_next_program(inp: Field[[I, J, K], dtype], out: Field[[I, J, K], dtype]):\n", + " lap_next(inp, out=out[1:-1, 1:-1, :])\n", + "\n", + "lap_next_program(inp, out_next, offset_provider={\"Ioff\": I, \"Joff\": J})" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "assert np.allclose(out_cartesian.asnumpy(), out_next.asnumpy())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 17709206a0..fd7724bac9 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -25,7 +25,7 @@ cmake==3.22 cogapp==3.3 coverage[toml]==5.0 cytoolz==0.12.0 -dace==0.14.2 +dace==0.15.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 @@ -70,7 +70,7 @@ scipy==1.7.2 setuptools==65.5.0 sphinx==4.4 sphinx_rtd_theme==1.0 -sympy==1.7 +sympy==1.9 tabulate==0.8.10 tomli==2.0.1 tox==3.2.0 diff --git a/pyproject.toml b/pyproject.toml index 5d7a2f2cb6..675bdae9d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,15 +69,15 @@ requires-python = '>=3.8' cuda = ['cupy>=12.0'] 
cuda11x = ['cupy-cuda11x>=12.0'] cuda12x = ['cupy-cuda12x>=12.0'] -dace = ['dace>=0.14.2,<0.15', 'sympy>=1.7'] +dace = ['dace>=0.15.1,<0.16', 'sympy>=1.9'] formatting = ['clang-format>=9.0'] # Always add all extra packages to 'full' for a simple full gt4py installation full = [ 'clang-format>=9.0', - 'dace>=0.14.2,<0.15', + 'dace>=0.15.1,<0.16', 'hypothesis>=6.0.0', 'pytest>=7.0', - 'sympy>=1.7', + 'sympy>=1.9', 'scipy>=1.7.2', 'jax[cpu]>=0.4.13' ] diff --git a/requirements-dev.txt b/requirements-dev.txt index d6dcc12d21..0fa523866f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,124 +6,136 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.0 # via devtools +asttokens==2.4.1 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.12.1 # via sphinx -black==23.9.1 # via gt4py (pyproject.toml) -blinker==1.6.2 # via flask -boltons==23.0.0 # via gt4py (pyproject.toml) +babel==2.13.1 # via sphinx +black==23.11.0 # via gt4py (pyproject.toml) +blinker==1.7.0 # via flask +boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.3.1 # via tox -certifi==2023.7.22 # via requests -cffi==1.15.1 # via cryptography +cachetools==5.3.2 # via tox +cerberus==1.3.5 # via plette +certifi==2023.11.17 # via requests +cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.2.0 # via requests -clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.3.2 # via requests +clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.5 # via gt4py (pyproject.toml) +cmake==3.27.9 # via dace, gt4py (pyproject.toml) cogapp==3.3.0 # 
via -r requirements-dev.in colorama==0.4.6 # via tox -coverage[toml]==7.3.1 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis +coverage[toml]==7.3.2 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) -dace==0.14.4 # via gt4py (pyproject.toml) +dace==0.15.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.5.0 # via gt4py (pyproject.toml) +deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via virtualenv -docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme +distlib==0.3.7 # via requirementslib, virtualenv +distro==1.8.0 # via scikit-build +docopt==0.6.2 # via pipreqs +docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.1.3 # via hypothesis, pytest +exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==1.2.0 # via devtools +executing==2.0.1 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.6.1 # via factory-boy -fastjsonschema==2.18.0 # via nbformat -filelock==3.12.4 # via tox, virtualenv +faker==20.1.0 # via factory-boy +fastjsonschema==2.19.0 # via nbformat +filelock==3.13.1 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.9.16 # via -r requirements-dev.in -flake8-builtins==2.1.0 # via -r requirements-dev.in +flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r 
requirements-dev.in flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.3 # via dace -frozendict==2.3.8 # via gt4py (pyproject.toml) +flask==3.0.0 # via dace +fparser==0.1.3 # via dace +frozendict==2.3.10 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.29 # via pre-commit -idna==3.4 # via requests +hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.33 # via pre-commit +idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via build, flask, sphinx -importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.0 # via build, flask, fparser, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.12.0 # via -r requirements-dev.in +isort==5.13.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.19.0 # via nbformat -jsonschema-specifications==2023.7.1 # via jsonschema -jupyter-core==5.3.1 # via nbformat -jupytext==1.15.2 # via -r requirements-dev.in -lark==1.1.7 # via gt4py (pyproject.toml) -mako==1.2.4 # via gt4py (pyproject.toml) +jsonschema==4.20.0 # via nbformat +jsonschema-specifications==2023.11.2 # via jsonschema +jupyter-core==5.5.0 # via nbformat +jupytext==1.16.0 # via -r requirements-dev.in +lark==1.1.8 # via gt4py (pyproject.toml) +mako==1.3.0 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.3 # via jinja2, mako, werkzeug mccabe==0.7.0 # via 
flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.1 # via -r requirements-dev.in +mypy==1.7.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.2 # via gt4py (pyproject.toml) +nanobind==1.8.0 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace -ninja==1.11.1 # via gt4py (pyproject.toml) +ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox -pathspec==0.11.2 # via black +packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +pathspec==0.12.1 # via black +pep517==0.13.1 # via requirementslib +pip-api==0.0.30 # via isort pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.0 # via -r requirements-dev.in +pipdeptree==2.13.1 # via -r requirements-dev.in +pipreqs==0.4.13 # via isort pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv +plette[validation]==0.4.4 # via requirementslib pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.4.0 # via -r requirements-dev.in -psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist +pre-commit==3.5.0 # via -r requirements-dev.in +psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) -pycodestyle==2.11.0 # via flake8, flake8-debugger +pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi +pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, devtools, 
flake8-rst-docstrings, sphinx +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in -pytest-factoryboy==2.5.1 # via -r requirements-dev.in -pytest-xdist[psutil]==3.3.1 # via -r requirements-dev.in +pytest-factoryboy==2.6.0 # via -r requirements-dev.in +pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.2 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, sphinx +referencing==0.32.0 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, requirementslib, sphinx, yarg +requirementslib==3.0.0 # via isort restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.10.3 # via jsonschema, referencing -ruff==0.0.290 # via -r requirements-dev.in +rpds-py==0.13.2 # via jsonschema, referencing +ruff==0.1.7 # via -r requirements-dev.in +scikit-build==0.17.6 # via dace +setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in +sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme 
sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sympy==1.12 # via dace, gt4py (pyproject.toml) +sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox +tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox +tomlkit==0.12.3 # via plette, requirementslib toolz==0.12.0 # via cytoolz -tox==4.11.3 # via -r requirements-dev.in -traitlets==5.10.0 # via jupyter-core, nbformat +tox==4.11.4 # via -r requirements-dev.in +traitlets==5.14.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all types-backports==0.1.3 # via types-all types-backports-abc==0.5.2 # via types-all -types-bleach==6.0.0.4 # via types-all +types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all -types-cachetools==5.3.0.6 # via types-all +types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.15 # via types-jack-client +types-cffi==1.16.0.0 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.5 # via types-all +types-click-spinner==0.1.13.6 # via types-all types-colorama==0.4.15.12 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==1.4.0.1 # via types-all +types-croniter==2.0.0.0 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all 
types-dateparser==1.1.4.10 # via types-all @@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.9 # via types-all +types-jack-client==0.5.10.10 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.4.2.10 # via types-all +types-markdown==3.5.0.3 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.2 # via types-all +types-mock==5.1.0.3 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.0 # via types-all, types-pysftp +types-paramiko==3.3.0.2 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.3 # via types-all +types-pillow==10.1.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.1 # via types-all +types-protobuf==4.24.0.4 # via types-all types-pyaudio==0.2.16.7 # via types-all types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.2.0.2 # via types-redis +types-pyopenssl==23.3.0.0 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.0 # via types-all, types-tzlocal +types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all -types-pyyaml==6.0.12.11 # via types-all 
-types-redis==4.6.0.6 # via types-all -types-requests==2.31.0.2 # via types-all +types-pyyaml==6.0.12.12 # via types-all +types-redis==4.6.0.11 # via types-all +types-requests==2.31.0.10 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.2.0.0 # via types-cffi +types-setuptools==69.0.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all @@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all -types-tzlocal==5.0.1.1 # via types-all +types-tzlocal==5.1.0.1 # via types-all types-ujson==5.8.0.1 # via types-all -types-urllib3==1.26.25.14 # via types-requests types-waitress==2.1.4.9 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy -urllib3==2.0.4 # via requests -virtualenv==20.24.5 # via pre-commit, tox -websockets==11.0.3 # via dace -werkzeug==2.3.7 # via flask -wheel==0.41.2 # via astunparse, pip-tools +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +urllib3==2.1.0 # via requests, types-requests +virtualenv==20.25.0 # via pre-commit, tox +websockets==12.0 # via dace +werkzeug==3.0.1 # via flask +wheel==0.42.0 # via astunparse, pip-tools, scikit-build xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.16.2 # via importlib-metadata, importlib-resources +yarg==0.1.9 # via pipreqs +zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.2.1 # via pip-tools -setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools +pip==23.3.1 # via pip-api, pip-tools, 
requirementslib +setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 7d255de142..c28c5cf2d6 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -33,6 +33,6 @@ if _sys.version_info >= (3, 10): - from . import next + from . import next # noqa: A004 __all__ += ["next"] diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index b1e559a41e..5dae025acb 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -562,12 +562,6 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S omp_threads = "" omp_header = "" - # Backward compatible state struct name change in DaCe >=0.15.x - try: - dace_state_suffix = dace.Config.get("compiler.codegen_state_struct_suffix") - except (KeyError, TypeError): - dace_state_suffix = "_t" # old structure name - interface = cls.template.definition.render( name=sdfg.name, backend_specifics=omp_threads, @@ -575,7 +569,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S functor_args=self.generate_functor_args(sdfg), tmp_allocs=self.generate_tmp_allocs(sdfg), allocator="gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique", - state_suffix=dace_state_suffix, + state_suffix=dace.Config.get("compiler.codegen_state_struct_suffix"), ) generated_code = textwrap.dedent( f"""#include diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index db276a48b9..48b129fa87 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -30,6 +30,7 @@ compute_dcir_access_infos, flatten_list, get_tasklet_symbol, + make_dace_subset, union_inout_memlets, union_node_grid_subsets, untile_memlets, @@ -458,6 +459,40 @@ 
def visit_HorizontalExecution( write_memlets=write_memlets, ) + for memlet in [*read_memlets, *write_memlets]: + """ + This loop handles the special case of a tasklet performing array access. + The memlet should pass the full array shape (no tiling) and + the tasklet expression for array access should use all explicit indexes. + """ + array_ndims = len(global_ctx.arrays[memlet.field].shape) + field_decl = global_ctx.library_node.field_decls[memlet.field] + # calculate array subset on original memlet + memlet_subset = make_dace_subset( + global_ctx.library_node.access_infos[memlet.field], + memlet.access_info, + field_decl.data_dims, + ) + # select index values for single-point grid access + memlet_data_index = [ + dcir.Literal(value=str(dim_range[0]), dtype=common.DataType.INT32) + for dim_range, dim_size in zip(memlet_subset, memlet_subset.size()) + if dim_size == 1 + ] + if len(memlet_data_index) < array_ndims: + reshape_memlet = False + for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess): + if access_node.data_index and access_node.name == memlet.connector: + access_node.data_index = memlet_data_index + access_node.data_index + assert len(access_node.data_index) == array_ndims + reshape_memlet = True + if reshape_memlet: + # ensure that memlet symbols used for array indexing are defined in context + for sym in memlet.access_info.grid_subset.free_symbols: + symbol_collector.add_symbol(sym) + # set full shape on memlet + memlet.access_info = global_ctx.library_node.access_infos[memlet.field] + for item in reversed(expansion_items): iteration_ctx = iteration_ctx.pop() dcir_node = self._process_iteration_item( diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index ddcb719b5f..bd8c08034c 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -121,7 +121,7 @@ def __init__( *args, **kwargs, ): - super().__init__(name=name, *args, **kwargs) + 
super().__init__(*args, name=name, **kwargs) from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py index 28ebc8cd8e..0366317360 100644 --- a/src/gt4py/cartesian/gtc/daceir.py +++ b/src/gt4py/cartesian/gtc/daceir.py @@ -536,7 +536,7 @@ def union(self, other): else: assert ( isinstance(interval2, (TileInterval, DomainInterval)) - and isinstance(interval1, IndexWithExtent) + and isinstance(interval1, (IndexWithExtent, DomainInterval)) ) or ( isinstance(interval1, (TileInterval, DomainInterval)) and isinstance(interval2, IndexWithExtent) @@ -573,7 +573,7 @@ def overapproximated_shape(self): def apply_iteration(self, grid_subset: GridSubset): res_intervals = dict(self.grid_subset.intervals) for axis, field_interval in self.grid_subset.intervals.items(): - if axis in grid_subset.intervals: + if axis in grid_subset.intervals and not isinstance(field_interval, DomainInterval): grid_interval = grid_subset.intervals[axis] assert isinstance(field_interval, IndexWithExtent) extent = field_interval.extent diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index fcd53d1312..bc744b3ccc 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -814,7 +814,7 @@ def concretize( """ # noqa: RST301 # doctest conventions confuse RST validator concrete_cls: Type[DataModelT] = _make_concrete_with_cache( - datamodel_cls, *type_args, class_name=class_name, module=module + datamodel_cls, *type_args, class_name=class_name, module=module # type: ignore[arg-type] ) assert isinstance(concrete_cls, type) and is_datamodel(concrete_cls) @@ -883,17 +883,6 @@ def _substitute_typevars( return type_params_map[type_hint], True elif getattr(type_hint, "__parameters__", []): return type_hint[tuple(type_params_map[tp] for tp in type_hint.__parameters__)], True - # TODO(egparedes): WIP fix for partial specialization - # # Type hint is a generic model: 
replace all the concretized type vars - # noqa: e800 replaced = False - # noqa: e800 new_args = [] - # noqa: e800 for tp in type_hint.__parameters__: - # noqa: e800 if tp in type_params_map: - # noqa: e800 new_args.append(type_params_map[tp]) - # noqa: e800 replaced = True - # noqa: e800 else: - # noqa: e800 new_args.append(type_params_map[tp]) - # noqa: e800 return type_hint[tuple(new_args)], replaced else: return type_hint, False @@ -981,21 +970,14 @@ def __class_getitem__( """ type_args: Tuple[Type] = args if isinstance(args, tuple) else (args,) concrete_cls: Type[DataModelT] = concretize(cls, *type_args) - res = xtyping.StdGenericAliasType(concrete_cls, type_args) - if sys.version_info < (3, 9): - # in Python 3.8, xtyping.StdGenericAliasType (aka typing._GenericAlias) - # does not copy all required `__dict__` entries, so do it manually - for k, v in concrete_cls.__dict__.items(): - if k not in res.__dict__: - res.__dict__[k] = v - return res + return concrete_cls return classmethod(__class_getitem__) def _make_type_converter(type_annotation: TypeAnnotation, name: str) -> TypeConverter[_T]: - # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code as a tree traversal. - # + # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code + # as a tree traversal. if xtyping.is_actual_type(type_annotation) and not isinstance(None, type_annotation): assert not xtyping.get_args(type_annotation) assert isinstance(type_annotation, type) @@ -1316,11 +1298,7 @@ def _make_concrete_with_cache( # Replace field definitions with the new actual types for generic fields type_params_map = dict(zip(datamodel_cls.__parameters__, type_args)) model_fields = getattr(datamodel_cls, MODEL_FIELD_DEFINITIONS_ATTR) - new_annotations = { - # TODO(egparedes): ? 
- # noqa: e800 "__args__": "ClassVar[Tuple[Union[Type, TypeVar], ...]]", - # noqa: e800 "__parameters__": "ClassVar[Tuple[TypeVar, ...]]", - } + new_annotations = {} new_field_c_attrs = {} for field_name, field_type in xtyping.get_type_hints(datamodel_cls).items(): @@ -1353,8 +1331,16 @@ def _make_concrete_with_cache( "__module__": module if module else datamodel_cls.__module__, **new_field_c_attrs, } - concrete_cls = type(class_name, (datamodel_cls,), namespace) + + # Update the tuple of generic parameters in the new class, in case + # this is a partial concretization + assert hasattr(concrete_cls, "__parameters__") + concrete_cls.__parameters__ = tuple( + type_params_map[tp_var] + for tp_var in datamodel_cls.__parameters__ + if isinstance(type_params_map[tp_var], typing.TypeVar) + ) assert concrete_cls.__module__ == module or not module if MODEL_FIELD_DEFINITIONS_ATTR not in concrete_cls.__dict__: diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 17462a37ff..3ee447ca6c 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -493,7 +493,7 @@ def _patched_proto_hook(other): # type: ignore[no-untyped-def] if isinstance(_typing.Any, type): # Python >= 3.11 _ArtefactTypes = (*_ArtefactTypes, _typing.Any) -# `Any` is a class since typing_extensions >= 4.4 +# `Any` is a class since typing_extensions >= 4.4 and Python 3.11 if (typing_exts_any := getattr(_typing_extensions, "Any", None)) is not _typing.Any and isinstance( typing_exts_any, type ): @@ -504,11 +504,13 @@ def is_actual_type(obj: Any) -> TypeGuard[Type]: """Check if an object has an actual type and instead of a typing artefact like ``GenericAlias`` or ``Any``. 
This is needed because since Python 3.9: - ``isinstance(types.GenericAlias(), type) is True`` + ``isinstance(types.GenericAlias(), type) is True`` and since Python 3.11: - ``isinstance(typing.Any, type) is True`` + ``isinstance(typing.Any, type) is True`` """ - return isinstance(obj, type) and type(obj) not in _ArtefactTypes + return ( + isinstance(obj, type) and (obj not in _ArtefactTypes) and (type(obj) not in _ArtefactTypes) + ) if hasattr(_typing_extensions, "Any") and _typing.Any is not _typing_extensions.Any: # type: ignore[attr-defined] # _typing_extensions.Any only from >= 4.4 @@ -641,9 +643,12 @@ def get_partial_type_hints( resolved_hints = get_type_hints( # type: ignore[call-arg] # Python 3.8 does not define `include-extras` obj, globalns=globalns, localns=localns, include_extras=include_extras ) - hints.update(resolved_hints) + hints[name] = resolved_hints[name] except NameError as error: if isinstance(hint, str): + # This conversion could be probably skipped in Python versions containing + # the fix applied in bpo-41370. 
Check: + # https://github.com/python/cpython/commit/b465b606049f6f7dd0711cb031fdaa251818741a#diff-ddb987fca5f5df0c9a2f5521ed687919d70bb3d64eaeb8021f98833a2a716887R344 hints[name] = ForwardRef(hint) elif isinstance(hint, (ForwardRef, _typing.ForwardRef)): hints[name] = hint diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index cd7e71588f..74c5bd41bb 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -133,7 +133,7 @@ def _pre_walk_items( yield from _pre_walk_items(child, __key__=key) -def _pre_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]: +def _pre_walk_values(node: TreeLike) -> Iterable: """Create a pre-order tree traversal iterator of values.""" yield node for child in iter_children_values(node): @@ -153,7 +153,7 @@ def _post_walk_items( yield __key__, node -def _post_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]: +def _post_walk_values(node: TreeLike) -> Iterable: """Create a post-order tree traversal iterator of values.""" if (iter_children_values := getattr(node, "iter_children_values", None)) is not None: for child in iter_children_values(): diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 7104f7658f..624407f319 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -1225,7 +1225,7 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]: [('a', 'b', 'c'), (1, 2, 3)] """ - return XIterable(zip(*self.iterator)) # type: ignore # mypy gets confused with *args + return XIterable(zip(*self.iterator)) @typing.overload def islice(self, __stop: int) -> XIterable[T]: @@ -1536,7 +1536,7 @@ def reduceby( ) -> Dict[K, S]: ... 
- def reduceby( # type: ignore[misc] # signatures 2 and 4 are not satified due to inconsistencies with type variables + def reduceby( self, bin_op_func: Callable[[S, T], S], key: Union[str, List[K], Callable[[T], K]], diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index cbd5735949..1398af5f03 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -39,6 +39,11 @@ index_field, np_as_located_field, ) +from .program_processors.runners.gtfn import ( + run_gtfn_cached as gtfn_cpu, + run_gtfn_gpu_cached as gtfn_gpu, +) +from .program_processors.runners.roundtrip import backend as itir_python __all__ = [ @@ -74,5 +79,9 @@ "field_operator", "program", "scan_operator", + # from program_processor + "gtfn_cpu", + "gtfn_gpu", + "itir_python", *fbuiltins.__all__, ] diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 29d606ccc0..949f4b461a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -189,11 +189,12 @@ def __and__(self, other: UnitRange) -> UnitRange: return UnitRange(max(self.start, other.start), min(self.stop, other.stop)) def __contains__(self, value: Any) -> bool: - return ( - isinstance(value, core_defs.INTEGRAL_TYPES) - and value >= self.start - and value < self.stop - ) + # TODO(egparedes): use core_defs.IntegralScalar for `isinstance()` checks (see PEP 604) + # and remove int cast, once the related mypy bug (#16358) gets fixed + if isinstance(value, core_defs.INTEGRAL_TYPES): + return self.start <= cast(int, value) < self.stop + else: + return False def __le__(self, other: UnitRange) -> bool: return self.start >= other.start and self.stop <= other.stop @@ -574,38 +575,39 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _ ... -# TODO(havogt): replace this protocol with the new `GTFieldInterface` protocol -class NextGTDimsInterface(Protocol): +# TODO(havogt): we need to describe when this interface should be used instead of the `Field` protocol. 
+class GTFieldInterface(core_defs.GTDimsInterface, core_defs.GTOriginInterface, Protocol): """ - Protocol for objects providing the `__gt_dims__` property, naming :class:`Field` dimensions. + Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`. - The dimension names are objects of type :class:`Dimension`, in contrast to - :mod:`gt4py.cartesian`, where the labels are `str` s with implied semantics, - see :class:`~gt4py._core.definitions.GTDimsInterface` . + Note: + - A default implementation of the `__gt_dims__` interface from `gt4py.cartesian` is provided. + - No implementation of `__gt_origin__` is provided because of infinite fields. """ @property - def __gt_dims__(self) -> tuple[Dimension, ...]: + def __gt_domain__(self) -> Domain: + # TODO probably should be changed to `DomainLike` (with a new concept `DimensionLike`) + # to allow implementations without having to import gtx.Domain. ... - -# TODO(egparedes): add support for this new protocol in the cartesian module -class GTFieldInterface(Protocol): - """Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`.""" - @property - def __gt_domain__(self) -> Domain: - ... + def __gt_dims__(self) -> tuple[str, ...]: + return tuple(d.value for d in self.__gt_domain__.dims) @extended_runtime_checkable -class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]): +class Field(GTFieldInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @property def domain(self) -> Domain: ... + @property + def __gt_domain__(self) -> Domain: + return self.domain + @property def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... 
@@ -923,10 +925,6 @@ def asnumpy(self) -> Never: def domain(self) -> Domain: return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),)) - @property - def __gt_dims__(self) -> tuple[Dimension, ...]: - return self.domain.dims - @property def __gt_origin__(self) -> Never: raise TypeError("'CartesianConnectivity' does not support this operation.") diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 8bd2673db9..9fc1b42038 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -107,10 +107,6 @@ def domain(self) -> common.Domain: def shape(self) -> tuple[int, ...]: return self._ndarray.shape - @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return self._domain.dims - @property def __gt_origin__(self) -> tuple[int, ...]: assert common.Domain.is_finite(self._domain) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 76a0ddcde0..9f8537f59b 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -344,7 +344,9 @@ def _validate_args(self, *args, **kwargs) -> None: raise_exception=True, ) except ValueError as err: - raise TypeError(f"Invalid argument types in call to '{self.past_node.id}'.") from err + raise errors.DSLError( + None, f"Invalid argument types in call to '{self.past_node.id}'.\n{err}" + ) from err def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: self._validate_args(*args, **kwargs) @@ -453,27 +455,32 @@ def _process_args(self, args: tuple, kwargs: dict): ) from err full_args = [*args] + full_kwargs = {**kwargs} for index, param in enumerate(self.past_node.params): if param.id in self.bound_args.keys(): - full_args.insert(index, self.bound_args[param.id]) + if index < len(full_args): + full_args.insert(index, self.bound_args[param.id]) + else: + full_kwargs[str(param.id)] = self.bound_args[param.id] - return 
super()._process_args(tuple(full_args), kwargs) + return super()._process_args(tuple(full_args), full_kwargs) @functools.cached_property def itir(self): new_itir = super().itir for new_clos in new_itir.closures: - for key in self.bound_args.keys(): + new_args = [ref(inp.id) for inp in new_clos.inputs] + for key, value in self.bound_args.items(): index = next( index for index, closure_input in enumerate(new_clos.inputs) if closure_input.id == key ) + new_args[new_args.index(new_clos.inputs[index])] = promote_to_const_iterator( + literal_from_value(value) + ) new_clos.inputs.pop(index) - new_args = [ref(inp.id) for inp in new_clos.inputs] params = [sym(inp.id) for inp in new_clos.inputs] - for value in self.bound_args.values(): - new_args.append(promote_to_const_iterator(literal_from_value(value))) expr = itir.FunCall( fun=new_clos.stencil, args=new_args, @@ -847,6 +854,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: return FieldOperator.from_function( definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, + grid_type, operator_node_cls=foast.ScanOperator, operator_attributes={"axis": axis, "forward": forward, "init": init}, ) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 278dde9180..cd75538da7 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -15,7 +15,7 @@ import dataclasses import functools import inspect -from builtins import bool, float, int, tuple +from builtins import bool, float, int, tuple # noqa: A004 from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 639e5ff009..5e289af664 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -694,7 +694,7 @@ def visit_Call(self, node: 
foast.Call, **kwargs) -> foast.Call: ) except ValueError as err: raise errors.DSLError( - node.location, f"Invalid argument types in call to '{new_func}'." + node.location, f"Invalid argument types in call to '{new_func}'.\n{err}" ) from err return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index fc353d64e4..af8f5e8368 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -229,7 +229,7 @@ def visit_Call(self, node: past.Call, **kwargs): ) except ValueError as ex: - raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.") from ex + raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.\n{ex}") from ex return past.Call( func=new_func, diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index ef70a2e645..390bec4312 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -172,7 +172,7 @@ class LocatedField(Protocol): @property @abc.abstractmethod - def __gt_dims__(self) -> tuple[common.Dimension, ...]: + def __gt_domain__(self) -> common.Domain: ... 
# TODO(havogt): define generic Protocol to provide a concrete return type @@ -182,7 +182,7 @@ def field_getitem(self, indices: NamedFieldIndices) -> Any: @property def __gt_origin__(self) -> tuple[int, ...]: - return tuple([0] * len(self.__gt_dims__)) + return tuple([0] * len(self.__gt_domain__.dims)) @runtime_checkable @@ -675,12 +675,18 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]: def _get_axes( field_or_tuple: LocatedField | tuple, ) -> Sequence[common.Dimension]: # arbitrary nesting of tuples of LocatedField + return _get_domain(field_or_tuple).dims + + +def _get_domain( + field_or_tuple: LocatedField | tuple, +) -> common.Domain: # arbitrary nesting of tuples of LocatedField if isinstance(field_or_tuple, tuple): - first = _get_axes(field_or_tuple[0]) - assert all(first == _get_axes(f) for f in field_or_tuple) + first = _get_domain(field_or_tuple[0]) + assert all(first == _get_domain(f) for f in field_or_tuple) return first else: - return field_or_tuple.__gt_dims__ + return field_or_tuple.__gt_domain__ def _single_vertical_idx( @@ -894,14 +900,14 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField): _ndarrayfield: common.Field @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return self._ndarrayfield.__gt_dims__ + def __gt_domain__(self) -> common.Domain: + return self._ndarrayfield.__gt_domain__ def _translate_named_indices( self, _named_indices: NamedFieldIndices ) -> common.AbsoluteIndexSequence: named_indices: Mapping[common.Dimension, FieldIndex | SparsePositionEntry] = { - d: _named_indices[d.value] for d in self._ndarrayfield.__gt_dims__ + d: _named_indices[d.value] for d in self._ndarrayfield.__gt_domain__.dims } domain_slice: list[common.NamedRange | common.NamedIndex] = [] for d, v in named_indices.items(): @@ -1046,8 +1052,8 @@ class IndexField(common.Field): _dimension: common.Dimension @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return (self._dimension,) + def 
__gt_domain__(self) -> common.Domain: + return self.domain @property def __gt_origin__(self) -> tuple[int, ...]: @@ -1165,8 +1171,8 @@ class ConstantField(common.Field[Any, core_defs.ScalarT]): _value: core_defs.ScalarT @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return tuple() + def __gt_domain__(self) -> common.Domain: + return self.domain @property def __gt_origin__(self) -> tuple[int, ...]: @@ -1452,7 +1458,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices: class TupleOfFields(TupleField): def __init__(self, data): self.data = data - self.__gt_dims__ = _get_axes(data) + self.__gt_domain__ = _get_domain(data) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: return _build_tuple_result(self.data, named_indices) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 30fec1f9fd..05ebd02352 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -254,7 +254,7 @@ def _contains_tuple_dtype_field(arg): # other `np.int32`). We just ignore the error here and postpone fixing this to when # the new storages land (The implementation here works for LocatedFieldImpl). 
- return common.is_field(arg) and any(dim is None for dim in arg.__gt_dims__) + return common.is_field(arg) and any(dim is None for dim in arg.domain.dims) def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 2609e35735..7ad55d0a87 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -23,6 +23,7 @@ from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator from gt4py.eve.visitors import PreserveLocationVisitor +from gt4py.next import common from gt4py.next.iterator import ir, type_inference from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -442,9 +443,12 @@ def _group_offsets( return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly -def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]): +def update_domains( + node: FencilWithTemporaries, + offset_provider: Mapping[str, Any], + symbolic_sizes: Optional[dict[str, str]], +): horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) - closures: list[ir.StencilClosure] = [] domains = dict[str, ir.FunCall]() for closure in reversed(node.fencil.closures): @@ -485,16 +489,29 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An # cartesian shift dim = offset_provider[offset_name].value consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset) - elif isinstance(offset_provider[offset_name], gtx.NeighborTableOffsetProvider): + elif isinstance(offset_provider[offset_name], common.Connectivity): # unstructured shift nbt_provider = offset_provider[offset_name] old_axis = nbt_provider.origin_axis.value new_axis = nbt_provider.neighbor_axis.value - 
consumed_domain.ranges.pop(old_axis) - assert new_axis not in consumed_domain.ranges - consumed_domain.ranges[new_axis] = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN), + + assert new_axis not in consumed_domain.ranges or old_axis == new_axis + + if symbolic_sizes is None: + new_range = SymbolicRange( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.literal( + str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN + ), + ) + else: + new_range = SymbolicRange( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.ref(symbolic_sizes[new_axis]), + ) + consumed_domain.ranges = dict( + (axis, range_) if axis != old_axis else (new_axis, new_range) + for axis, range_ in consumed_domain.ranges.items() ) else: raise NotImplementedError @@ -577,7 +594,11 @@ class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator): """ def visit_FencilDefinition( - self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any] + self, + node: ir.FencilDefinition, + *, + offset_provider: Mapping[str, Any], + symbolic_sizes: Optional[dict[str, str]], ) -> FencilWithTemporaries: # Split closures on lifted function calls and introduce temporaries res = split_closures(node, offset_provider=offset_provider) @@ -588,6 +609,6 @@ def visit_FencilDefinition( # Perform an eta-reduction which should put all calls at the highest level of a closure res = EtaReduction().visit(res) # Perform a naive extent analysis to compute domain sizes of closures and temporaries - res = update_domains(res, offset_provider) + res = update_domains(res, offset_provider, symbolic_sizes) # Use type inference to determine the data type of the temporaries return collect_tmps_info(res, offset_provider=offset_provider) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2e05391634..08897861c2 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ 
b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import enum +from typing import Optional from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import simple_inline_heuristic @@ -81,6 +82,7 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ): if lift_mode is None: lift_mode = LiftMode.FORCE_INLINE @@ -147,7 +149,9 @@ def apply_common_transforms( if lift_mode != LiftMode.FORCE_INLINE: assert offset_provider is not None - ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider) + ir = CreateGlobalTmps().visit( + ir, offset_provider=offset_provider, symbolic_sizes=symbolic_domain_sizes + ) ir = InlineLifts().visit(ir) # If after creating temporaries, the scan is not at the top, we inline. # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. diff --git a/src/gt4py/next/iterator/transforms/power_unrolling.py b/src/gt4py/next/iterator/transforms/power_unrolling.py new file mode 100644 index 0000000000..ac71f2747d --- /dev/null +++ b/src/gt4py/next/iterator/transforms/power_unrolling.py @@ -0,0 +1,84 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses +import math + +from gt4py.eve import NodeTranslator +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas + + +def _is_power_call( + node: ir.FunCall, +) -> bool: + """Match expressions of the form `power(base, integral_literal)`.""" + return ( + isinstance(node.fun, ir.SymRef) + and node.fun.id == "power" + and isinstance(node.args[1], ir.Literal) + and float(node.args[1].value) == int(node.args[1].value) + and node.args[1].value >= im.literal_from_value(0).value + ) + + +def _compute_integer_power_of_two(exp: int) -> int: + return math.floor(math.log2(exp)) + + +@dataclasses.dataclass +class PowerUnrolling(NodeTranslator): + max_unroll: int + + @classmethod + def apply(cls, node: ir.Node, max_unroll: int = 5) -> ir.Node: + return cls(max_unroll=max_unroll).visit(node) + + def visit_FunCall(self, node: ir.FunCall): + new_node = self.generic_visit(node) + + if _is_power_call(new_node): + assert len(new_node.args) == 2 + # Check if unroll should be performed or if exponent is too large + base, exponent = new_node.args[0], int(new_node.args[1].value) + if 1 <= exponent <= self.max_unroll: + # Calculate and store powers of two of the base as long as they are smaller than the exponent. + # Do the same (using the stored values) with the remainder and multiply computed values. 
+ pow_cur = _compute_integer_power_of_two(exponent) + pow_max = pow_cur + remainder = exponent + + # Build target expression + ret = im.ref(f"power_{2 ** pow_max}") + remainder -= 2**pow_cur + while remainder > 0: + pow_cur = _compute_integer_power_of_two(remainder) + remainder -= 2**pow_cur + + ret = im.multiplies_(ret, f"power_{2 ** pow_cur}") + + # Nest target expression to avoid multiple redundant evaluations + for i in range(pow_max, 0, -1): + ret = im.let( + f"power_{2 ** i}", + im.multiplies_(f"power_{2**(i-1)}", f"power_{2**(i-1)}"), + )(ret) + ret = im.let("power_1", base)(ret) + + # Simplify expression in case of SymRef by resolving let statements + if isinstance(base, ir.SymRef): + return InlineLambdas.apply(ret, opcount_preserving=True) + else: + return ret + return new_node diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 68627cfd89..d65f67b266 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -567,9 +567,7 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints): axis = offset_provider[offset] if isinstance(axis, gtx.Dimension): continue # Cartesian shifts don’t change the location type - elif isinstance( - axis, (gtx.NeighborTableOffsetProvider, gtx.StridedNeighborOffsetProvider) - ): + elif isinstance(axis, Connectivity): assert ( axis.origin_axis.kind == axis.neighbor_axis.kind @@ -964,7 +962,7 @@ def visit_FencilDefinition( def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None: for child_node in node.pre_walk_values().if_isinstance(*TYPED_IR_NODES): try: - child_node.annex.type = types[id(child_node)] # type: ignore[attr-defined] + child_node.annex.type = types[id(child_node)] except KeyError: if not ( isinstance(child_node, ir.SymRef) diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index ed8b768972..3a82f9c738 100644 --- a/src/gt4py/next/otf/workflow.py +++ 
b/src/gt4py/next/otf/workflow.py @@ -82,7 +82,7 @@ def replace(self, **kwargs: Any) -> Self: if not dataclasses.is_dataclass(self): raise TypeError(f"'{self.__class__}' is not a dataclass.") assert not isinstance(self, type) - return dataclasses.replace(self, **kwargs) # type: ignore[misc] # `self` is guaranteed to be a dataclass (is_dataclass) should be a `TypeGuard`? + return dataclasses.replace(self, **kwargs) class ChainableWorkflowMixin(Workflow[StartT, EndT]): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py deleted file mode 100644 index 4183f52550..0000000000 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py +++ /dev/null @@ -1,77 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . 
-# -# SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any - -import gt4py.next.iterator.ir as itir -from gt4py.eve import codegen -from gt4py.eve.exceptions import EveValueError -from gt4py.next.iterator.transforms.pass_manager import apply_common_transforms -from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen -from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering -from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering - - -def _lower( - program: itir.FencilDefinition, enable_itir_transforms: bool, do_unroll: bool, **kwargs: Any -): - offset_provider = kwargs.get("offset_provider") - assert isinstance(offset_provider, dict) - if enable_itir_transforms: - program = apply_common_transforms( - program, - lift_mode=kwargs.get("lift_mode"), - offset_provider=offset_provider, - unroll_reduce=do_unroll, - unconditionally_collapse_tuples=True, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements - ) - gtfn_ir = GTFN_lowering.apply( - program, - offset_provider=offset_provider, - column_axis=kwargs.get("column_axis"), - ) - return gtfn_ir - - -def generate( - program: itir.FencilDefinition, enable_itir_transforms: bool = True, **kwargs: Any -) -> str: - if kwargs.get("imperative", False): - try: - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=False, - **kwargs, - ) - except EveValueError: - # if we don't unroll, there may be lifts left in the itir which can't be lowered to - # gtfn. In this case, just retry with unrolled reductions. 
- gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=True, - **kwargs, - ) - gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir, **kwargs) - generated_code = GTFNIMCodegen.apply(gtfn_im_ir, **kwargs) - else: - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=True, - **kwargs, - ) - generated_code = GTFNCodegen.apply(gtfn_ir, **kwargs) - return codegen.format_source("cpp", generated_code, style="LLVM") diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 4abdaa6eea..718fef72af 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -15,21 +15,24 @@ from __future__ import annotations import dataclasses +import functools import warnings from typing import Any, Final, Optional import numpy as np from gt4py._core import definitions as core_defs -from gt4py.eve import trees, utils +from gt4py.eve import codegen, trees, utils from gt4py.next import common from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.transforms import LiftMode, pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface -from gt4py.next.program_processors.codegens.gtfn import gtfn_backend +from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen +from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering +from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering from gt4py.next.type_system import type_specifications as ts, type_translation @@ -54,6 +57,7 @@ class GTFNTranslationStep( 
use_imperative_backend: bool = False lift_mode: Optional[LiftMode] = None device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + symbolic_domain_sizes: Optional[dict[str, str]] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -171,6 +175,70 @@ def _process_connectivity_args( return parameters, arg_exprs + def _preprocess_program( + self, + program: itir.FencilDefinition, + offset_provider: dict[str, Connectivity | Dimension], + runtime_lift_mode: Optional[LiftMode] = None, + ) -> itir.FencilDefinition: + # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added + # to the interface of all (or at least all of concern) backends, but instead should be + # configured in the backend itself (like it is here), until then we respect the argument + # here and warn the user if it differs from the one configured. + lift_mode = runtime_lift_mode or self.lift_mode + if lift_mode != self.lift_mode: + warnings.warn( + f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " + f"overriden to be {str(runtime_lift_mode)} at runtime." + ) + + if not self.enable_itir_transforms: + return program + + apply_common_transforms = functools.partial( + pass_manager.apply_common_transforms, + lift_mode=lift_mode, + offset_provider=offset_provider, + # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements + unconditionally_collapse_tuples=True, + symbolic_domain_sizes=self.symbolic_domain_sizes, + ) + + new_program = apply_common_transforms( + program, unroll_reduce=not self.use_imperative_backend + ) + + if self.use_imperative_backend and any( + node.id == "neighbors" + for node in new_program.pre_walk_values().if_isinstance(itir.SymRef) + ): + # if we don't unroll, there may be lifts left in the itir which can't be lowered to + # gtfn. In this case, just retry with unrolled reductions. 
+ new_program = apply_common_transforms(program, unroll_reduce=True) + + return new_program + + def generate_stencil_source( + self, + program: itir.FencilDefinition, + offset_provider: dict[str, Connectivity | Dimension], + column_axis: Optional[common.Dimension], + runtime_lift_mode: Optional[LiftMode] = None, + ) -> str: + new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode) + gtfn_ir = GTFN_lowering.apply( + new_program, + offset_provider=offset_provider, + column_axis=column_axis, + ) + + if self.use_imperative_backend: + gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir) + generated_code = GTFNIMCodegen.apply(gtfn_im_ir) + else: + generated_code = GTFNCodegen.apply(gtfn_ir) + return codegen.format_source("cpp", generated_code, style="LLVM") + def __call__( self, inp: stages.ProgramCall, @@ -190,18 +258,6 @@ def __call__( inp.kwargs["offset_provider"] ) - # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added - # to the interface of all (or at least all of concern) backends, but instead should be - # configured in the backend itself (like it is here), until then we respect the argument - # here and warn the user if it differs from the one configured. - runtime_lift_mode = inp.kwargs.pop("lift_mode", None) - lift_mode = runtime_lift_mode or self.lift_mode - if runtime_lift_mode != self.lift_mode: - warnings.warn( - f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " - "overriden to be {str(runtime_lift_mode)} at runtime." 
- ) - # combine into a format that is aligned with what the backend expects parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters backend_arg = self._backend_type() @@ -213,12 +269,11 @@ def __call__( f"{', '.join(connectivity_args_expr)})({', '.join(args_expr)});" ) decl_src = cpp_interface.render_function_declaration(function, body=decl_body) - stencil_src = gtfn_backend.generate( + stencil_src = self.generate_stencil_source( program, - enable_itir_transforms=self.enable_itir_transforms, - lift_mode=lift_mode, - imperative=self.use_imperative_backend, - **inp.kwargs, + inp.kwargs["offset_provider"], + inp.kwargs.get("column_axis", None), + inp.kwargs.get("lift_mode", None), ) source_code = interface.format_source( self._language_settings(), diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index f9fa154641..27dec77ed1 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -15,10 +15,19 @@ from typing import Any from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.codegens.gtfn.gtfn_module import GTFNTranslationStep from gt4py.next.program_processors.processor_interface import program_formatter +from gt4py.next.program_processors.runners.gtfn import gtfn_executor @program_formatter def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - return generate(program, **kwargs) + # TODO(tehrengruber): This is a little ugly. Revisit. 
+ gtfn_translation = gtfn_executor.otf_workflow.translation + assert isinstance(gtfn_translation, GTFNTranslationStep) + return gtfn_translation.generate_stencil_source( + program, + offset_provider=kwargs.get("offset_provider", None), + column_axis=kwargs.get("column_axis", None), + runtime_lift_mode=kwargs.get("lift_mode", None), + ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index cd70f8b588..54ca08fe6e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -261,10 +261,11 @@ def build_sdfg_from_itir( # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) - # TODO: According to Lex one should build the SDFG first in a general mannor. - # Generalisation to a particular device should happen only at the end. - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) sdfg = sdfg_genenerator.visit(program) + if sdfg is None: + raise RuntimeError(f"Visit failed for program {program.id}.") + for nested_sdfg in sdfg.all_sdfgs_recursive(): if not nested_sdfg.debuginfo: _, frameinfo = warnings.warn( @@ -277,6 +278,8 @@ def build_sdfg_from_itir( end_line=frameinfo.lineno, filename=frameinfo.filename, ) + + # run DaCe transformations to simplify the SDFG sdfg.simplify() # run DaCe auto-optimization heuristics @@ -287,6 +290,9 @@ def build_sdfg_from_itir( device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) + if on_gpu: + sdfg.apply_gpu_transformations() + return sdfg @@ -296,7 +302,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): compiler_args = kwargs.get("compiler_args", None) # `None` will take default. 
build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", False) - auto_optimize = kwargs.get("auto_optimize", False) + auto_optimize = kwargs.get("auto_optimize", True) lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE) # ITIR parameters column_axis = kwargs.get("column_axis", None) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index e03e3654a2..dc194c0436 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -100,20 +100,17 @@ class ItirToSDFG(eve.NodeVisitor): offset_provider: dict[str, Any] node_types: dict[int, next_typing.Type] unique_id: int - use_gpu_storage: bool def __init__( self, param_types: list[ts.TypeSpec], offset_provider: dict[str, NeighborTableOffsetProvider], column_axis: Optional[Dimension] = None, - use_gpu_storage: bool = False, ): self.param_types = param_types self.column_axis = column_axis self.offset_provider = offset_provider self.storage_types = {} - self.use_gpu_storage = use_gpu_storage def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): @@ -124,14 +121,7 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset else None ) dtype = as_dace_type(type_.dtype) - storage = ( - dace.dtypes.StorageType.GPU_Global - if self.use_gpu_storage - else dace.dtypes.StorageType.Default - ) - sdfg.add_array( - name, shape=shape, strides=strides, offset=offset, dtype=dtype, storage=storage - ) + sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) @@ -250,7 +240,6 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, 
dtype=array_table[name].dtype, - storage=array_table[name].storage, transient=True, ) closure_init_state.add_nedge( @@ -265,7 +254,6 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, - storage=array_table[name].storage, ) else: assert isinstance(self.storage_types[name], ts.ScalarType) diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 88a8347fe4..12649bf620 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -184,7 +184,7 @@ def from_value(value: Any) -> ts.TypeSpec: elif isinstance(value, common.Dimension): symbol_type = ts.DimensionType(dim=value) elif common.is_field(value): - dims = list(value.__gt_dims__) + dims = list(value.domain.dims) dtype = from_type_hint(value.dtype.scalar_type) symbol_type = ts.FieldType(dims=dims, dtype=dtype) elif isinstance(value, tuple): diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 0f7cf5d0ab..4e7ebb0c21 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -192,6 +192,10 @@ def cpu_copy(array: Union[np.ndarray, "cp.ndarray"]) -> np.ndarray: def asarray( array: FieldLike, *, device: Literal["cpu", "gpu", None] = None ) -> np.ndarray | cp.ndarray: + if hasattr(array, "ndarray"): + # extract the buffer from a gt4py.next.Field + # TODO(havogt): probably `Field` should provide the array interface methods when applicable + array = array.ndarray if device == "gpu" or (not device and hasattr(array, "__cuda_array_interface__")): return cp.asarray(array) if device == "cpu" or ( diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py index e580333bc8..8cfff12df4 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py +++ 
b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py @@ -312,7 +312,7 @@ def test_symbolref_validation_for_valid_tree(): SymbolTableRootNode( nodes=[SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo")], ) - SymbolTableRootNode( + SymbolTableRootNode( # noqa: B018 nodes=[ SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo"), diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index 8fa9e02cb6..0abb893dd4 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -15,6 +15,7 @@ from __future__ import annotations import enum +import numbers import types import typing from typing import Set # noqa: F401 # imported but unused (used in exec() context) @@ -1150,66 +1151,80 @@ class PartialGenericModel(datamodels.GenericDataModel, Generic[T]): with pytest.raises(TypeError, match="'PartialGenericModel__int.value'"): PartialGenericModel__int(value=["1"]) - def test_partial_specialization(self): - class PartialGenericModel(datamodels.GenericDataModel, Generic[T, U]): + def test_partial_concretization(self): + class BaseGenericModel(datamodels.GenericDataModel, Generic[T, U]): value: List[Tuple[T, U]] - PartialGenericModel(value=[]) - PartialGenericModel(value=[("value", 3)]) - PartialGenericModel(value=[(1, "value")]) - PartialGenericModel(value=[(-1.0, "value")]) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=1) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=(1, 2)) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=[()]) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=[(1,)]) + assert len(BaseGenericModel.__parameters__) == 2 + + BaseGenericModel(value=[]) + BaseGenericModel(value=[("value", 3)]) + BaseGenericModel(value=[(1, "value")]) + 
BaseGenericModel(value=[(-1.0, "value")]) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=1) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=(1, 2)) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=[()]) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=[(1,)]) + + PartiallyConcretizedGenericModel = BaseGenericModel[int, U] + + assert len(PartiallyConcretizedGenericModel.__parameters__) == 1 + + PartiallyConcretizedGenericModel(value=[]) + PartiallyConcretizedGenericModel(value=[(1, 2)]) + PartiallyConcretizedGenericModel(value=[(1, "value")]) + PartiallyConcretizedGenericModel(value=[(1, (11, 12))]) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=1) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=(1, 2)) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=[1.0]) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=["1"]) - print(f"{PartialGenericModel.__parameters__=}") - print(f"{hasattr(PartialGenericModel ,'__args__')=}") + FullyConcretizedGenericModel = PartiallyConcretizedGenericModel[str] - PartiallySpecializedGenericModel = PartialGenericModel[int, U] - print(f"{PartiallySpecializedGenericModel.__datamodel_fields__=}") - print(f"{PartiallySpecializedGenericModel.__parameters__=}") - print(f"{PartiallySpecializedGenericModel.__args__=}") + assert len(FullyConcretizedGenericModel.__parameters__) == 0 - PartiallySpecializedGenericModel(value=[]) - PartiallySpecializedGenericModel(value=[(1, 2)]) - PartiallySpecializedGenericModel(value=[(1, "value")]) - PartiallySpecializedGenericModel(value=[(1, (11, 12))]) + FullyConcretizedGenericModel(value=[]) + FullyConcretizedGenericModel(value=[(1, "value")]) + with 
pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=1) + with pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=(1, 2)) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=1) + FullyConcretizedGenericModel(value=[1.0]) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=(1, 2)) + FullyConcretizedGenericModel(value=["1"]) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=[1.0]) + FullyConcretizedGenericModel(value=1) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=["1"]) - - # TODO(egparedes): after fixing partial nested datamodel specialization - # noqa: e800 FullySpecializedGenericModel = PartiallySpecializedGenericModel[str] - # noqa: e800 print(f"{FullySpecializedGenericModel.__datamodel_fields__=}") - # noqa: e800 print(f"{FullySpecializedGenericModel.__parameters__=}") - # noqa: e800 print(f"{FullySpecializedGenericModel.__args__=}") - - # noqa: e800 FullySpecializedGenericModel(value=[]) - # noqa: e800 FullySpecializedGenericModel(value=[(1, "value")]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=1) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=(1, 2)) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[1.0]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=["1"]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=1) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[(1, 2)]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 
FullySpecializedGenericModel(value=[(1, (11, 12))]) + FullyConcretizedGenericModel(value=[(1, 2)]) + with pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=[(1, (11, 12))]) + + def test_partial_concretization_with_typevar(self): + class PartialGenericModel(datamodels.GenericDataModel, Generic[T]): + a: T + values: List[T] + + B = TypeVar("B", bound=numbers.Number) + PartiallyConcretizedGenericModel = PartialGenericModel[B] + + PartiallyConcretizedGenericModel(a=1, values=[2, 3]) + PartiallyConcretizedGenericModel(a=-1.32, values=[2.2, 3j]) + + with pytest.raises(TypeError, match=".a'"): + PartiallyConcretizedGenericModel(a="1", values=[2, 3]) + with pytest.raises(TypeError, match=".values'"): + PartiallyConcretizedGenericModel(a=1, values=[1, "2"]) + with pytest.raises(TypeError, match=".values'"): + PartiallyConcretizedGenericModel(a=1, values=(1, 2)) # Reuse sample_type_data from test_field_type_hint @pytest.mark.parametrize(["type_hint", "valid_values", "wrong_values"], SAMPLE_TYPE_DATA) diff --git a/tests/eve_tests/unit_tests/test_type_validation.py b/tests/eve_tests/unit_tests/test_type_validation.py index 70ef033ff0..d9977f0d3a 100644 --- a/tests/eve_tests/unit_tests/test_type_validation.py +++ b/tests/eve_tests/unit_tests/test_type_validation.py @@ -28,6 +28,7 @@ ) from gt4py.eve.extended_typing import ( Any, + Callable, Dict, Final, ForwardRef, @@ -41,8 +42,8 @@ ) -VALIDATORS: Final = [type_val.simple_type_validator] -FACTORIES: Final = [type_val.simple_type_validator_factory] +VALIDATORS: Final[list[Callable]] = [type_val.simple_type_validator] +FACTORIES: Final[list[Callable]] = [type_val.simple_type_validator_factory] class SampleEnum(enum.Enum): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py new file mode 100644 index 0000000000..0de953d85f --- /dev/null +++ 
b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np + +import gt4py.next as gtx +from gt4py.next import int32 + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import cartesian_case +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + fieldview_backend, + reduction_setup, +) + + +def test_with_bound_args(cartesian_case): + @gtx.field_operator + def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField: + if not condition: + scalar = 0 + return a + scalar + + @gtx.program + def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField): + fieldop_bound_args(a, scalar, condition, out=out) + + a = cases.allocate(cartesian_case, program_bound_args, "a")() + scalar = int32(1) + ref = a + scalar + out = cases.allocate(cartesian_case, program_bound_args, "out")() + + prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True) + cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref) + + +def test_with_bound_args_order_args(cartesian_case): + @gtx.field_operator + def fieldop_args(a: cases.IField, condition: bool, scalar: int32) -> cases.IField: + scalar = 0 if not condition else scalar + return a + scalar + + @gtx.program(backend=cartesian_case.backend) + def 
program_args(a: cases.IField, condition: bool, scalar: int32, out: cases.IField): + fieldop_args(a, condition, scalar, out=out) + + a = cases.allocate(cartesian_case, program_args, "a")() + out = cases.allocate(cartesian_case, program_args, "out")() + + prog_bounds = program_args.with_bound_args(condition=True) + prog_bounds(a=a, scalar=int32(1), out=out, offset_provider={}) + np.allclose(out.asnumpy(), a.asnumpy() + int32(1)) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index a08931628b..70c79d7b6c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -898,26 +898,6 @@ def test_docstring(a: cases.IField): cases.verify(cartesian_case, test_docstring, a, inout=a, ref=a) -def test_with_bound_args(cartesian_case): - @gtx.field_operator - def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField: - if not condition: - scalar = 0 - return a + a + scalar - - @gtx.program - def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField): - fieldop_bound_args(a, scalar, condition, out=out) - - a = cases.allocate(cartesian_case, program_bound_args, "a")() - scalar = int32(1) - ref = a + a + 1 - out = cases.allocate(cartesian_case, program_bound_args, "out")() - - prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True) - cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref) - - def test_domain(cartesian_case): @gtx.field_operator def fieldop_domain(a: cases.IField) -> cases.IField: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 698dce2b5c..d100cd380c 100644 --- 
a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -30,16 +30,6 @@ def test_external_local_field(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def testee( inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index e8d0c8b163..e2434d860a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -46,16 +46,6 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - if unstructured_case.backend in [ gtfn.run_gtfn, gtfn.run_gtfn_gpu, @@ -79,16 +69,6 @@ def testee(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_minover_execution(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if 
unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) @@ -102,16 +82,6 @@ def minover(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_reduction_execution(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def reduction(edge_f: cases.EField) -> cases.VField: return neighbor_sum(edge_f(V2E), axis=V2EDim) @@ -150,16 +120,6 @@ def fencil(edge_f: cases.EField, out: cases.VField): @pytest.mark.uses_unstructured_shift def test_reduction_with_common_expression(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index c86881ab7c..938c69fb52 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -20,6 +20,7 @@ import pytest import 
gt4py.next as gtx +from gt4py.next import errors from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, Ioff, JDim, cartesian_case, fieldview_backend @@ -222,7 +223,7 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): inp = cartesian_case.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],))) out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))() - with pytest.raises(TypeError) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: # program is defined on Field[[IDim], ...], but we call with # Field[[JDim], ...] copy_program(inp, out, offset_provider={}) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py new file mode 100644 index 0000000000..788081b81e --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -0,0 +1,119 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest +from numpy import int32, int64 + +from gt4py import next as gtx +from gt4py.next import common +from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms +from gt4py.next.program_processors import otf_compile_executor +from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import ( + E2V, + Case, + KDim, + Vertex, + cartesian_case, + unstructured_case, +) +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + reduction_setup, +) +from next_tests.toy_connectivity import Cell, Edge + + +@pytest.fixture +def run_gtfn_with_temporaries_and_symbolic_sizes(): + return otf_compile_executor.OTFBackend( + executor=otf_compile_executor.OTFCompileExecutor( + name="run_gtfn_with_temporaries_and_sizes", + otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace( + translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace( + symbolic_domain_sizes={ + "Cell": "num_cells", + "Edge": "num_edges", + "Vertex": "num_vertices", + }, + ), + ), + ), + allocator=run_gtfn_with_temporaries.allocator, + ) + + +@pytest.fixture +def testee(): + @gtx.field_operator + def testee_op(a: cases.VField) -> cases.EField: + amul = a * 2 + return amul(E2V[0]) + amul(E2V[1]) + + @gtx.program + def prog( + a: cases.VField, + out: cases.EField, + num_vertices: int32, + num_edges: int64, + num_cells: int32, + ): + testee_op(a, out=out) + + return prog + + +def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, reduction_setup): + unstructured_case = Case( + run_gtfn_with_temporaries_and_symbolic_sizes, + offset_provider=reduction_setup.offset_provider, + default_sizes={ + Vertex: reduction_setup.num_vertices, + Edge: reduction_setup.num_edges, + Cell: reduction_setup.num_cells, + KDim: reduction_setup.k_levels, + }, + 
grid_type=common.GridType.UNSTRUCTURED, + ) + + a = cases.allocate(unstructured_case, testee, "a")() + out = cases.allocate(unstructured_case, testee, "out")() + + first_nbs, second_nbs = (reduction_setup.offset_provider["E2V"].table[:, i] for i in [0, 1]) + ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] + + cases.verify( + unstructured_case, + testee, + a, + out, + reduction_setup.num_vertices, + reduction_setup.num_edges, + reduction_setup.num_cells, + inout=out, + ref=ref, + ) + + +def test_temporary_symbols(testee, reduction_setup): + itir_with_tmp = apply_common_transforms( + testee.itir, + lift_mode=LiftMode.FORCE_TEMPORARIES, + offset_provider=reduction_setup.offset_provider, + ) + + params = ["num_vertices", "num_edges", "num_cells"] + for param in params: + assert any([param == str(p) for p in itir_with_tmp.fencil.params]) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py index e851e7b130..5af4605988 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py @@ -18,7 +18,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef, offset from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn @fundef @@ -69,7 +69,9 @@ def lap_fencil(i_size, j_size, k_size, i_off, j_off, k_off, out, inp): output_file = sys.argv[1] prog = trace_fencil_definition(lap_fencil, [None] * 8, use_arg_types=False) - generated_code = generate(prog, offset_provider={"i": IDim, "j": JDim}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={"i": IDim, 
"j": JDim}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py index 33c7d5baa7..3e8b88ac66 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py @@ -18,7 +18,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn IDim = gtx.Dimension("IDim") @@ -48,7 +48,9 @@ def copy_fencil(isize, jsize, ksize, inp, out): output_file = sys.argv[1] prog = trace_fencil_definition(copy_fencil, [None] * 5, use_arg_types=False) - generated_code = generate(prog, offset_provider={}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py index f7472d4ac3..fdc57449ee 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py @@ -18,7 +18,7 @@ import gt4py.next as gtx from gt4py.next import Field, field_operator, program -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import 
run_gtfn IDim = gtx.Dimension("IDim") @@ -47,7 +47,9 @@ def copy_program( output_file = sys.argv[1] prog = copy_program.itir - generated_code = generate(prog, offset_provider={}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py index 1dfd74baca..abc3755dca 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py @@ -19,7 +19,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef, offset from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn, run_gtfn_imperative E2V = offset("E2V") @@ -92,13 +92,20 @@ def mapped_index(_, __) -> int: output_file = sys.argv[1] imperative = sys.argv[2].lower() == "true" + if imperative: + backend = run_gtfn_imperative + else: + backend = run_gtfn + # prog = trace(zavgS_fencil, [None] * 4) # TODO allow generating of 2 fencils prog = trace_fencil_definition(nabla_fencil, [None] * 7, use_arg_types=False) offset_provider = { "V2E": DummyConnectivity(max_neighbors=6, has_skip_values=True), "E2V": DummyConnectivity(max_neighbors=2, has_skip_values=False), } - generated_code = generate(prog, offset_provider=offset_provider, imperative=imperative) + generated_code = backend.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider=offset_provider, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git 
a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py index 578a19faab..9755774fd0 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py @@ -19,7 +19,7 @@ from gt4py.next.iterator.runtime import closure, fundef from gt4py.next.iterator.tracing import trace_fencil_definition from gt4py.next.iterator.transforms import LiftMode -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn IDim = gtx.Dimension("IDim") @@ -67,10 +67,10 @@ def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x): prog = trace_fencil_definition(tridiagonal_solve_fencil, [None] * 8, use_arg_types=False) offset_provider = {"I": gtx.Dimension("IDim"), "J": gtx.Dimension("JDim")} - generated_code = generate( + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( prog, offset_provider=offset_provider, - lift_mode=LiftMode.SIMPLE_HEURISTIC, + runtime_lift_mode=LiftMode.SIMPLE_HEURISTIC, column_axis=KDim, ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 130f6bd29c..5bd255f80f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass +import dataclasses import numpy as np import pytest @@ -201,22 +201,26 @@ def test_setup(fieldview_backend): grid_type=common.GridType.UNSTRUCTURED, ) - 
@dataclass(frozen=True) + @dataclasses.dataclass(frozen=True) class setup: - case: cases.Case = test_case - cell_size = case.default_sizes[Cell] - k_size = case.default_sizes[KDim] - z_alpha = case.as_field( + case: cases.Case = dataclasses.field(default_factory=lambda: test_case) + cell_size = test_case.default_sizes[Cell] + k_size = test_case.default_sizes[KDim] + z_alpha = test_case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1)) ) - z_beta = case.as_field( + z_beta = test_case.as_field( + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) + ) + z_q = test_case.as_field( + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) + ) + w = test_case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) ) - z_q = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) - w = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) - dummy = case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) - z_q_out = case.as_field([Cell, KDim], np.zeros((cell_size, k_size))) + dummy = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) + z_q_out = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size))) return setup() diff --git a/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py b/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py index 052f272d22..ea1cdb82a6 100644 --- a/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py +++ b/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py @@ -108,7 +108,10 @@ def test_unpacking_swap(): lines = ast.unparse(ssa_ast).split("\n") assert lines[0] == f"a{SEP}0 = 5" assert lines[1] == f"b{SEP}0 = 1" - assert 
lines[2] == f"(b{SEP}1, a{SEP}1) = (a{SEP}0, b{SEP}0)" + assert lines[2] in [ + f"(b{SEP}1, a{SEP}1) = (a{SEP}0, b{SEP}0)", + f"b{SEP}1, a{SEP}1 = (a{SEP}0, b{SEP}0)", + ] # unparse produces different parentheses in different Python versions def test_annotated_assign(): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 86c3c98c62..5c2802f90c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -323,7 +323,7 @@ def test_update_cartesian_domains(): for a, s in (("JDim", "j"), ("KDim", "k")) ], ) - actual = update_domains(testee, {"I": gtx.Dimension("IDim")}) + actual = update_domains(testee, {"I": gtx.Dimension("IDim")}, symbolic_sizes=None) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py new file mode 100644 index 0000000000..ae23becb4c --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py @@ -0,0 +1,161 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest + +from gt4py.eve import SymbolRef +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.power_unrolling import PowerUnrolling + + +def test_power_unrolling_zero(): + pytest.xfail( + "Not implementeds we don't have an easy way to determine the type of the one literal (type inference is to expensive)." + ) + testee = im.call("power")("x", 0) + expected = im.literal_from_value(1) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_one(): + testee = im.call("power")("x", 1) + expected = ir.SymRef(id=SymbolRef("x")) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_two(): + testee = im.call("power")("x", 2) + expected = im.multiplies_("x", "x") + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_two_x_plus_two(): + testee = im.call("power")(im.plus("x", 2), 2) + expected = im.let("power_1", im.plus("x", 2))( + im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2") + ) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_two_x_plus_one_times_three(): + testee = im.call("power")(im.multiplies_(im.plus("x", 1), 3), 2) + expected = im.let("power_1", im.multiplies_(im.plus("x", 1), 3))( + im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2") + ) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_three(): + testee = im.call("power")("x", 3) + expected = im.multiplies_(im.multiplies_("x", "x"), "x") + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_four(): + testee = im.call("power")("x", 4) + expected = im.let("power_2", im.multiplies_("x", "x"))(im.multiplies_("power_2", "power_2")) + + actual = PowerUnrolling.apply(testee) 
+ assert actual == expected + + +def test_power_unrolling_five(): + testee = im.call("power")("x", 5) + tmp2 = im.multiplies_("x", "x") + expected = im.multiplies_(im.multiplies_(tmp2, tmp2), "x") + expected = im.let("power_2", im.multiplies_("x", "x"))( + im.multiplies_(im.multiplies_("power_2", "power_2"), "x") + ) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_seven(): + testee = im.call("power")("x", 7) + expected = im.call("power")("x", 7) + + actual = PowerUnrolling.apply(testee, max_unroll=5) + assert actual == expected + + +def test_power_unrolling_seven_unrolled(): + testee = im.call("power")("x", 7) + expected = im.let("power_2", im.multiplies_("x", "x"))( + im.multiplies_(im.multiplies_(im.multiplies_("power_2", "power_2"), "power_2"), "x") + ) + + actual = PowerUnrolling.apply(testee, max_unroll=7) + assert actual == expected + + +def test_power_unrolling_seven_x_plus_one_unrolled(): + testee = im.call("power")(im.plus("x", 1), 7) + expected = im.let("power_1", im.plus("x", 1))( + im.let("power_2", im.multiplies_("power_1", "power_1"))( + im.let("power_4", im.multiplies_("power_2", "power_2"))( + im.multiplies_(im.multiplies_("power_4", "power_2"), "power_1") + ) + ) + ) + + actual = PowerUnrolling.apply(testee, max_unroll=7) + assert actual == expected + + +def test_power_unrolling_eight(): + testee = im.call("power")("x", 8) + expected = im.call("power")("x", 8) + + actual = PowerUnrolling.apply(testee, max_unroll=5) + assert actual == expected + + +def test_power_unrolling_eight_unrolled(): + testee = im.call("power")("x", 8) + expected = im.let("power_2", im.multiplies_("x", "x"))( + im.let("power_4", im.multiplies_("power_2", "power_2"))( + im.multiplies_("power_4", "power_4") + ) + ) + + actual = PowerUnrolling.apply(testee, max_unroll=8) + assert actual == expected + + +def test_power_unrolling_eight_x_plus_one_unrolled(): + testee = im.call("power")(im.plus("x", 1), 8) + expected = 
im.let("power_1", im.plus("x", 1))( + im.let("power_2", im.multiplies_("power_1", "power_1"))( + im.let("power_4", im.multiplies_("power_2", "power_2"))( + im.let("power_8", im.multiplies_("power_4", "power_4"))("power_8") + ) + ) + ) + + actual = PowerUnrolling.apply(testee, max_unroll=8) + assert actual == expected diff --git a/tox.ini b/tox.ini index 44dc912c8a..817f721f71 100644 --- a/tox.ini +++ b/tox.ini @@ -11,21 +11,24 @@ envlist = # docs labels = test-cartesian-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, \ - cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu + cartesian-py311-internal-cpu, cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \ + cartesian-py311-dace-cpu - test-eve-cpu = eve-py38, eve-py39, eve-py310 + test-eve-cpu = eve-py38, eve-py39, eve-py310, eve-py311 - test-next-cpu = next-py310-nomesh, next-py310-atlas + test-next-cpu = next-py310-nomesh, next-py311-nomesh, next-py310-atlas, next-py311-atlas test-storage-cpu = storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, \ - storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu + storage-py311-internal-cpu, storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, \ + storage-py311-dace-cpu test-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, \ - cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \ - eve-py38, eve-py39, eve-py310, \ - next-py310-nomesh, next-py310-atlas, \ - storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, \ - storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu + cartesian-py311-internal-cpu, cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \ + cartesian-py311-dace-cpu, \ + eve-py38, eve-py39, eve-py310, eve-py311, \ + next-py310-nomesh, next-py311-nomesh, next-py310-atlas, 
next-py311-atlas, \ + storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, storage-py311-internal-cpu, \ + storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu [testenv] deps = -r {tox_root}{/}{env:ENV_REQUIREMENTS_FILE:requirements-dev.txt} @@ -44,7 +47,7 @@ pass_env = NUM_PROCESSES set_env = PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning} -[testenv:cartesian-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] +[testenv:cartesian-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.cartesian' tests pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH, CXX, CC, OPENMP_CPPFLAGS, OPENMP_LDFLAGS, PIP_USER, PYTHONUSERBASE allowlist_externals = @@ -65,13 +68,13 @@ commands = ; coverage json --rcfile=setup.cfg ; coverage html --rcfile=setup.cfg --show-contexts -[testenv:eve-py{38,39,310}] +[testenv:eve-py{38,39,310,311}] description = Run 'gt4py.eve' tests commands = python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} {posargs} tests{/}eve_tests python -m pytest --doctest-modules src{/}gt4py{/}eve -[testenv:next-py{310}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}] +[testenv:next-py{310,311}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.next' tests pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH deps = @@ -87,14 +90,14 @@ commands = # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist pytest --doctest-modules src{/}gt4py{/}next -[testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] +[testenv:storage-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.storage' tests commands = cpu: python -m pytest 
--cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_gpu" {posargs} tests{/}storage_tests {cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_gpu" {posargs} tests{/}storage_tests #pytest doctest-modules {posargs} src{/}gt4py{/}storage -[testenv:linters-py{38,39,310}] +[testenv:linters-py{38,39,310,311}] description = Run linters commands = flake8 .{/}src @@ -134,11 +137,13 @@ description = py38: Update requirements for testing a specific python version py39: Update requirements for testing a specific python version py310: Update requirements for testing a specific python version + py311: Update requirements for testing a specific python version base_python = common: py38 py38: py38 py39: py39 py310: py310 + py311: py311 deps = cogapp>=3.3 pip-tools>=6.10 @@ -178,7 +183,7 @@ commands = # Run cog to update .pre-commit-config.yaml with new versions common: cog -r -P .pre-commit-config.yaml -[testenv:dev-py{38,39,310}{-atlas,}] +[testenv:dev-py{38,39,310,311}{-atlas,}] description = Initialize development environment for gt4py deps = -r {tox_root}{/}requirements-dev.txt