From b900b474566f21339d5c99aa2365f9bed86bf1ec Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 19 Jan 2024 11:46:04 +0100 Subject: [PATCH] build[cartesian][next]: Bump dace version from 0.14.4 to 0.15.1 (#1391) Bumping dace version to 0.15.1 affects both cartesian and next gt4py: * cartesian: removed try/except for dace backward compatibility * next: re-enabled some tests that were broken on dace 0.14.4 * all: fixed and/or suppressed flake8 and mypy errors --- .pre-commit-config.yaml | 38 ++-- constraints.txt | 191 ++++++++++-------- min-extra-requirements-test.txt | 4 +- pyproject.toml | 6 +- requirements-dev.txt | 191 ++++++++++-------- src/gt4py/__init__.py | 2 +- src/gt4py/cartesian/backend/dace_backend.py | 8 +- src/gt4py/cartesian/gtc/dace/nodes.py | 2 +- src/gt4py/eve/datamodels/core.py | 2 +- src/gt4py/eve/utils.py | 4 +- src/gt4py/next/common.py | 11 +- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/otf/workflow.py | 2 +- .../runners/dace_iterator/__init__.py | 13 +- .../runners/dace_iterator/itir_to_sdfg.py | 14 +- .../unit_tests/test_gtc/test_common.py | 2 +- .../ffront_tests/test_external_local_field.py | 10 - .../ffront_tests/test_gt4py_builtins.py | 40 ---- .../test_temporaries_with_sizes.py | 12 +- 19 files changed, 262 insertions(+), 292 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1092fafd0..d9cfa0ff48 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: ## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '23.9.1' # version from constraints.txt + rev: '23.11.0' # version from constraints.txt ##[[[end]]] hooks: - id: black @@ -73,7 +73,7 @@ repos: ## version = re.search('isort==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '5.12.0' # version from constraints.txt + rev: '5.13.0' # version from constraints.txt ##[[[end]]] hooks: - id: isort @@ -97,14 +97,14 @@ repos: ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) ##]]] - darglint==1.8.1 - - flake8-bugbear==23.9.16 - - flake8-builtins==2.1.0 + - flake8-bugbear==23.12.2 + - flake8-builtins==2.2.0 - flake8-debugger==4.1.2 - flake8-docstrings==1.7.0 - flake8-eradicate==1.5.0 - flake8-mutable==1.2.0 - flake8-pyproject==1.2.3 - - pygments==2.16.1 + - pygments==2.17.2 ##[[[end]]] # - flake8-rst-docstrings # Disabled for now due to random false positives exclude: | @@ -146,9 +146,9 @@ repos: ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"#========= FROM constraints.txt: v{version} =========") ##]]] - #========= FROM constraints.txt: v1.5.1 ========= + #========= FROM constraints.txt: v1.7.1 ========= ##[[[end]]] - rev: v1.5.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) + rev: v1.7.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) hooks: - id: mypy additional_dependencies: # versions from constraints.txt @@ -162,26 +162,26 @@ repos: ##]]] - astunparse==1.6.3 - attrs==23.1.0 - - black==23.9.1 - - boltons==23.0.0 + - black==23.11.0 + - boltons==23.1.1 - cached-property==1.5.2 - click==8.1.7 - - cmake==3.27.5 + - cmake==3.27.9 - cytoolz==0.12.2 - - deepdiff==6.5.0 + - deepdiff==6.7.1 - devtools==0.12.2 - - frozendict==2.3.8 + - frozendict==2.3.10 - gridtools-cpp==2.3.1 - - importlib-resources==6.0.1 + - importlib-resources==6.1.1 - jinja2==3.1.2 - - lark==1.1.7 - - mako==1.2.4 - - nanobind==1.5.2 - - ninja==1.11.1 + - lark==1.1.8 + - mako==1.3.0 + - nanobind==1.8.0 + - ninja==1.11.1.1 - numpy==1.24.4 - - packaging==23.1 + - packaging==23.2 - pybind11==2.11.1 - - setuptools==68.2.2 + - setuptools==69.0.2 - tabulate==0.9.0 - typing-extensions==4.5.0 - xxhash==3.0.0 diff --git a/constraints.txt b/constraints.txt index b334851af1..81abd64c6e 100644 --- a/constraints.txt +++ b/constraints.txt @@ -6,124 +6,136 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.0 # via devtools +asttokens==2.4.1 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.12.1 # via sphinx -black==23.9.1 # via gt4py (pyproject.toml) -blinker==1.6.2 # via flask -boltons==23.0.0 # via gt4py (pyproject.toml) +babel==2.13.1 # via sphinx +black==23.11.0 # via gt4py (pyproject.toml) +blinker==1.7.0 # via flask +boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.3.1 # via tox -certifi==2023.7.22 # via requests -cffi==1.15.1 # via cryptography +cachetools==5.3.2 # via tox +cerberus==1.3.5 # via plette +certifi==2023.11.17 # via requests +cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.2.0 # via requests -clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.3.2 # via requests +clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.5 # via gt4py (pyproject.toml) +cmake==3.27.9 # via dace, gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage==7.3.1 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis +coverage==7.3.2 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) -dace==0.14.4 # via gt4py (pyproject.toml) +dace==0.15.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.5.0 # via gt4py (pyproject.toml) +deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via virtualenv -docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme +distlib==0.3.7 # via requirementslib, virtualenv +distro==1.8.0 # via scikit-build +docopt==0.6.2 # via pipreqs +docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.1.3 # via hypothesis, pytest +exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==1.2.0 # via devtools +executing==2.0.1 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.6.1 # via factory-boy -fastjsonschema==2.18.0 # via nbformat -filelock==3.12.4 # via tox, virtualenv +faker==20.1.0 # via factory-boy +fastjsonschema==2.19.0 # via nbformat +filelock==3.13.1 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.9.16 # via -r requirements-dev.in -flake8-builtins==2.1.0 # via -r requirements-dev.in +flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.3 # via dace -frozendict==2.3.8 # via gt4py (pyproject.toml) +flask==3.0.0 # via dace +fparser==0.1.3 # via dace +frozendict==2.3.10 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.29 # via pre-commit -idna==3.4 # via requests +hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.33 # via pre-commit +idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via build, flask, sphinx -importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.0 # via build, flask, fparser, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.12.0 # via -r requirements-dev.in +isort==5.13.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.19.0 # via nbformat -jsonschema-specifications==2023.7.1 # via jsonschema -jupyter-core==5.3.1 # via nbformat -jupytext==1.15.2 # via -r requirements-dev.in -lark==1.1.7 # via gt4py (pyproject.toml) -mako==1.2.4 # via gt4py (pyproject.toml) +jsonschema==4.20.0 # via nbformat +jsonschema-specifications==2023.11.2 # via jsonschema +jupyter-core==5.5.0 # via nbformat +jupytext==1.16.0 # via -r requirements-dev.in +lark==1.1.8 # via gt4py (pyproject.toml) +mako==1.3.0 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.3 # via jinja2, mako, werkzeug mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.1 # via -r requirements-dev.in +mypy==1.7.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.2 # via gt4py (pyproject.toml) +nanobind==1.8.0 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace -ninja==1.11.1 # via gt4py (pyproject.toml) +ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox -pathspec==0.11.2 # via black +packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +pathspec==0.12.1 # via black +pep517==0.13.1 # via requirementslib +pip-api==0.0.30 # via isort pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.0 # via -r requirements-dev.in +pipdeptree==2.13.1 # via -r requirements-dev.in +pipreqs==0.4.13 # via isort pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv +plette==0.4.4 # via requirementslib pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.4.0 # via -r requirements-dev.in -psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist +pre-commit==3.5.0 # via -r requirements-dev.in +psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) -pycodestyle==2.11.0 # via flake8, flake8-debugger +pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi +pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in -pytest-factoryboy==2.5.1 # via -r requirements-dev.in -pytest-xdist==3.3.1 # via -r requirements-dev.in +pytest-factoryboy==2.6.0 # via -r requirements-dev.in +pytest-xdist==3.5.0 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.2 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, sphinx +referencing==0.32.0 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, requirementslib, sphinx, yarg +requirementslib==3.0.0 # via isort restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.10.3 # via jsonschema, referencing -ruff==0.0.290 # via -r requirements-dev.in +rpds-py==0.13.2 # via jsonschema, referencing +ruff==0.1.7 # via -r requirements-dev.in +scikit-build==0.17.6 # via dace +setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in +sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sympy==1.12 # via dace, gt4py (pyproject.toml) +sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox +tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox +tomlkit==0.12.3 # via plette, requirementslib toolz==0.12.0 # via cytoolz -tox==4.11.3 # via -r requirements-dev.in -traitlets==5.10.0 # via jupyter-core, nbformat +tox==4.11.4 # via -r requirements-dev.in +traitlets==5.14.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all types-backports==0.1.3 # via types-all types-backports-abc==0.5.2 # via types-all -types-bleach==6.0.0.4 # via types-all +types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all -types-cachetools==5.3.0.6 # via types-all +types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.15 # via types-jack-client +types-cffi==1.16.0.0 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.5 # via types-all +types-click-spinner==0.1.13.6 # via types-all types-colorama==0.4.15.12 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==1.4.0.1 # via types-all +types-croniter==2.0.0.0 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all types-dateparser==1.1.4.10 # via types-all @@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.9 # via types-all +types-jack-client==0.5.10.10 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.4.2.10 # via types-all +types-markdown==3.5.0.3 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.2 # via types-all +types-mock==5.1.0.3 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.0 # via types-all, types-pysftp +types-paramiko==3.3.0.2 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.3 # via types-all +types-pillow==10.1.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.1 # via types-all +types-protobuf==4.24.0.4 # via types-all types-pyaudio==0.2.16.7 # via types-all types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.2.0.2 # via types-redis +types-pyopenssl==23.3.0.0 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.0 # via types-all, types-tzlocal +types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all -types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.6 # via types-all -types-requests==2.31.0.2 # via types-all +types-pyyaml==6.0.12.12 # via types-all +types-redis==4.6.0.11 # via types-all +types-requests==2.31.0.10 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.2.0.0 # via types-cffi +types-setuptools==69.0.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all @@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all -types-tzlocal==5.0.1.1 # via types-all +types-tzlocal==5.1.0.1 # via types-all types-ujson==5.8.0.1 # via types-all -types-urllib3==1.26.25.14 # via types-requests types-waitress==2.1.4.9 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy -urllib3==2.0.4 # via requests -virtualenv==20.24.5 # via pre-commit, tox -websockets==11.0.3 # via dace -werkzeug==2.3.7 # via flask -wheel==0.41.2 # via astunparse, pip-tools +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +urllib3==2.1.0 # via requests, types-requests +virtualenv==20.25.0 # via pre-commit, tox +websockets==12.0 # via dace +werkzeug==3.0.1 # via flask +wheel==0.42.0 # via astunparse, pip-tools, scikit-build xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.16.2 # via importlib-metadata, importlib-resources +yarg==0.1.9 # via pipreqs +zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.2.1 # via pip-tools -setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools +pip==23.3.1 # via pip-api, pip-tools, requirementslib +setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 17709206a0..fd7724bac9 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -25,7 +25,7 @@ cmake==3.22 cogapp==3.3 coverage[toml]==5.0 cytoolz==0.12.0 -dace==0.14.2 +dace==0.15.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 @@ -70,7 +70,7 @@ scipy==1.7.2 setuptools==65.5.0 sphinx==4.4 sphinx_rtd_theme==1.0 -sympy==1.7 +sympy==1.9 tabulate==0.8.10 tomli==2.0.1 tox==3.2.0 diff --git a/pyproject.toml b/pyproject.toml index 5d7a2f2cb6..675bdae9d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,15 +69,15 @@ requires-python = '>=3.8' cuda = ['cupy>=12.0'] cuda11x = ['cupy-cuda11x>=12.0'] cuda12x = ['cupy-cuda12x>=12.0'] -dace = ['dace>=0.14.2,<0.15', 'sympy>=1.7'] +dace = ['dace>=0.15.1,<0.16', 'sympy>=1.9'] formatting = ['clang-format>=9.0'] # Always add all extra packages to 'full' for a simple full gt4py installation full = [ 'clang-format>=9.0', - 'dace>=0.14.2,<0.15', + 'dace>=0.15.1,<0.16', 'hypothesis>=6.0.0', 'pytest>=7.0', - 'sympy>=1.7', + 'sympy>=1.9', 'scipy>=1.7.2', 'jax[cpu]>=0.4.13' ] diff --git a/requirements-dev.txt b/requirements-dev.txt index d6dcc12d21..0fa523866f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,124 +6,136 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.0 # via devtools +asttokens==2.4.1 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.12.1 # via sphinx -black==23.9.1 # via gt4py (pyproject.toml) -blinker==1.6.2 # via flask -boltons==23.0.0 # via gt4py (pyproject.toml) +babel==2.13.1 # via sphinx +black==23.11.0 # via gt4py (pyproject.toml) +blinker==1.7.0 # via flask +boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.3.1 # via tox -certifi==2023.7.22 # via requests -cffi==1.15.1 # via cryptography +cachetools==5.3.2 # via tox +cerberus==1.3.5 # via plette +certifi==2023.11.17 # via requests +cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.2.0 # via requests -clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.3.2 # via requests +clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.5 # via gt4py (pyproject.toml) +cmake==3.27.9 # via dace, gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage[toml]==7.3.1 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis +coverage[toml]==7.3.2 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) -dace==0.14.4 # via gt4py (pyproject.toml) +dace==0.15.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.5.0 # via gt4py (pyproject.toml) +deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via virtualenv -docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme +distlib==0.3.7 # via requirementslib, virtualenv +distro==1.8.0 # via scikit-build +docopt==0.6.2 # via pipreqs +docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.1.3 # via hypothesis, pytest +exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==1.2.0 # via devtools +executing==2.0.1 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.6.1 # via factory-boy -fastjsonschema==2.18.0 # via nbformat -filelock==3.12.4 # via tox, virtualenv +faker==20.1.0 # via factory-boy +fastjsonschema==2.19.0 # via nbformat +filelock==3.13.1 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.9.16 # via -r requirements-dev.in -flake8-builtins==2.1.0 # via -r requirements-dev.in +flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.3 # via dace -frozendict==2.3.8 # via gt4py (pyproject.toml) +flask==3.0.0 # via dace +fparser==0.1.3 # via dace +frozendict==2.3.10 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.29 # via pre-commit -idna==3.4 # via requests +hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.33 # via pre-commit +idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via build, flask, sphinx -importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.0 # via build, flask, fparser, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.12.0 # via -r requirements-dev.in +isort==5.13.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.19.0 # via nbformat -jsonschema-specifications==2023.7.1 # via jsonschema -jupyter-core==5.3.1 # via nbformat -jupytext==1.15.2 # via -r requirements-dev.in -lark==1.1.7 # via gt4py (pyproject.toml) -mako==1.2.4 # via gt4py (pyproject.toml) +jsonschema==4.20.0 # via nbformat +jsonschema-specifications==2023.11.2 # via jsonschema +jupyter-core==5.5.0 # via nbformat +jupytext==1.16.0 # via -r requirements-dev.in +lark==1.1.8 # via gt4py (pyproject.toml) +mako==1.3.0 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.3 # via jinja2, mako, werkzeug mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.1 # via -r requirements-dev.in +mypy==1.7.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.2 # via gt4py (pyproject.toml) +nanobind==1.8.0 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace -ninja==1.11.1 # via gt4py (pyproject.toml) +ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox -pathspec==0.11.2 # via black +packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +pathspec==0.12.1 # via black +pep517==0.13.1 # via requirementslib +pip-api==0.0.30 # via isort pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.0 # via -r requirements-dev.in +pipdeptree==2.13.1 # via -r requirements-dev.in +pipreqs==0.4.13 # via isort pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv +plette[validation]==0.4.4 # via requirementslib pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.4.0 # via -r requirements-dev.in -psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist +pre-commit==3.5.0 # via -r requirements-dev.in +psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) -pycodestyle==2.11.0 # via flake8, flake8-debugger +pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi +pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in -pytest-factoryboy==2.5.1 # via -r requirements-dev.in -pytest-xdist[psutil]==3.3.1 # via -r requirements-dev.in +pytest-factoryboy==2.6.0 # via -r requirements-dev.in +pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.2 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, sphinx +referencing==0.32.0 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, requirementslib, sphinx, yarg +requirementslib==3.0.0 # via isort restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.10.3 # via jsonschema, referencing -ruff==0.0.290 # via -r requirements-dev.in +rpds-py==0.13.2 # via jsonschema, referencing +ruff==0.1.7 # via -r requirements-dev.in +scikit-build==0.17.6 # via dace +setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in +sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sympy==1.12 # via dace, gt4py (pyproject.toml) +sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox +tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox +tomlkit==0.12.3 # via plette, requirementslib toolz==0.12.0 # via cytoolz -tox==4.11.3 # via -r requirements-dev.in -traitlets==5.10.0 # via jupyter-core, nbformat +tox==4.11.4 # via -r requirements-dev.in +traitlets==5.14.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all types-backports==0.1.3 # via types-all types-backports-abc==0.5.2 # via types-all -types-bleach==6.0.0.4 # via types-all +types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all -types-cachetools==5.3.0.6 # via types-all +types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.15 # via types-jack-client +types-cffi==1.16.0.0 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.5 # via types-all +types-click-spinner==0.1.13.6 # via types-all types-colorama==0.4.15.12 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==1.4.0.1 # via types-all +types-croniter==2.0.0.0 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all types-dateparser==1.1.4.10 # via types-all @@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.9 # via types-all +types-jack-client==0.5.10.10 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.4.2.10 # via types-all +types-markdown==3.5.0.3 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.2 # via types-all +types-mock==5.1.0.3 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.0 # via types-all, types-pysftp +types-paramiko==3.3.0.2 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.3 # via types-all +types-pillow==10.1.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.1 # via types-all +types-protobuf==4.24.0.4 # via types-all types-pyaudio==0.2.16.7 # via types-all types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.2.0.2 # via types-redis +types-pyopenssl==23.3.0.0 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.0 # via types-all, types-tzlocal +types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all -types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.6 # via types-all -types-requests==2.31.0.2 # via types-all +types-pyyaml==6.0.12.12 # via types-all +types-redis==4.6.0.11 # via types-all +types-requests==2.31.0.10 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.2.0.0 # via types-cffi +types-setuptools==69.0.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all @@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all -types-tzlocal==5.0.1.1 # via types-all +types-tzlocal==5.1.0.1 # via types-all types-ujson==5.8.0.1 # via types-all -types-urllib3==1.26.25.14 # via types-requests types-waitress==2.1.4.9 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy -urllib3==2.0.4 # via requests -virtualenv==20.24.5 # via pre-commit, tox -websockets==11.0.3 # via dace -werkzeug==2.3.7 # via flask -wheel==0.41.2 # via astunparse, pip-tools +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +urllib3==2.1.0 # via requests, types-requests +virtualenv==20.25.0 # via pre-commit, tox +websockets==12.0 # via dace +werkzeug==3.0.1 # via flask +wheel==0.42.0 # via astunparse, pip-tools, scikit-build xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.16.2 # via importlib-metadata, importlib-resources +yarg==0.1.9 # via pipreqs +zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.2.1 # via pip-tools -setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools +pip==23.3.1 # via pip-api, pip-tools, requirementslib +setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 7d255de142..c28c5cf2d6 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -33,6 +33,6 @@ if _sys.version_info >= (3, 10): - from . import next + from . import next # noqa: A004 __all__ += ["next"] diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index b1e559a41e..5dae025acb 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -562,12 +562,6 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S omp_threads = "" omp_header = "" - # Backward compatible state struct name change in DaCe >=0.15.x - try: - dace_state_suffix = dace.Config.get("compiler.codegen_state_struct_suffix") - except (KeyError, TypeError): - dace_state_suffix = "_t" # old structure name - interface = cls.template.definition.render( name=sdfg.name, backend_specifics=omp_threads, @@ -575,7 +569,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S functor_args=self.generate_functor_args(sdfg), tmp_allocs=self.generate_tmp_allocs(sdfg), allocator="gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique", - state_suffix=dace_state_suffix, + state_suffix=dace.Config.get("compiler.codegen_state_struct_suffix"), ) generated_code = textwrap.dedent( f"""#include diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index ddcb719b5f..bd8c08034c 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -121,7 +121,7 @@ def __init__( *args, **kwargs, ): - super().__init__(name=name, *args, **kwargs) + super().__init__(*args, name=name, **kwargs) from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index fcd53d1312..5660fdbf76 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -814,7 +814,7 @@ def concretize( """ # noqa: RST301 # doctest conventions confuse RST validator concrete_cls: Type[DataModelT] = _make_concrete_with_cache( - datamodel_cls, *type_args, class_name=class_name, module=module + datamodel_cls, *type_args, class_name=class_name, module=module # type: ignore[arg-type] ) assert isinstance(concrete_cls, type) and is_datamodel(concrete_cls) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 7104f7658f..624407f319 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -1225,7 +1225,7 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]: [('a', 'b', 'c'), (1, 2, 3)] """ - return XIterable(zip(*self.iterator)) # type: ignore # mypy gets confused with *args + return XIterable(zip(*self.iterator)) @typing.overload def islice(self, __stop: int) -> XIterable[T]: @@ -1536,7 +1536,7 @@ def reduceby( ) -> Dict[K, S]: ... - def reduceby( # type: ignore[misc] # signatures 2 and 4 are not satified due to inconsistencies with type variables + def reduceby( self, bin_op_func: Callable[[S, T], S], key: Union[str, List[K], Callable[[T], K]], diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 6bf6858369..949f4b461a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -189,11 +189,12 @@ def __and__(self, other: UnitRange) -> UnitRange: return UnitRange(max(self.start, other.start), min(self.stop, other.stop)) def __contains__(self, value: Any) -> bool: - return ( - isinstance(value, core_defs.INTEGRAL_TYPES) - and value >= self.start - and value < self.stop - ) + # TODO(egparedes): use core_defs.IntegralScalar for `isinstance()` checks (see PEP 604) + # and remove int cast, once the related mypy bug (#16358) gets fixed + if isinstance(value, core_defs.INTEGRAL_TYPES): + return self.start <= cast(int, value) < self.stop + else: + return False def __le__(self, other: UnitRange) -> bool: return self.start >= other.start and self.stop <= other.stop diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 278dde9180..cd75538da7 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -15,7 +15,7 @@ import dataclasses import functools import inspect -from builtins import bool, float, int, tuple +from builtins import bool, float, int, tuple # noqa: A004 from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index ed8b768972..3a82f9c738 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -82,7 +82,7 @@ def replace(self, **kwargs: Any) -> Self: if not dataclasses.is_dataclass(self): raise TypeError(f"'{self.__class__}' is not a dataclass.") assert not isinstance(self, type) - return dataclasses.replace(self, **kwargs) # type: ignore[misc] # `self` is guaranteed to be a dataclass (is_dataclass) should be a `TypeGuard`? + return dataclasses.replace(self, **kwargs) class ChainableWorkflowMixin(Workflow[StartT, EndT]): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 7fd4794e57..fdd8a61054 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -260,10 +260,12 @@ def build_sdfg_from_itir( # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) - # TODO: According to Lex one should build the SDFG first in a general mannor. - # Generalisation to a particular device should happen only at the end. - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) sdfg = sdfg_genenerator.visit(program) + if sdfg is None: + raise RuntimeError(f"Visit failed for program {program.id}.") + + # run DaCe transformations to simplify the SDFG sdfg.simplify() # run DaCe auto-optimization heuristics @@ -274,6 +276,9 @@ def build_sdfg_from_itir( device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) + if on_gpu: + sdfg.apply_gpu_transformations() + return sdfg @@ -283,7 +288,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): compiler_args = kwargs.get("compiler_args", None) # `None` will take default. build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", False) - auto_optimize = kwargs.get("auto_optimize", False) + auto_optimize = kwargs.get("auto_optimize", True) lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE) # ITIR parameters column_axis = kwargs.get("column_axis", None) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index e3b5ddf2ac..fb2f82fed0 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -99,20 +99,17 @@ class ItirToSDFG(eve.NodeVisitor): offset_provider: dict[str, Any] node_types: dict[int, next_typing.Type] unique_id: int - use_gpu_storage: bool def __init__( self, param_types: list[ts.TypeSpec], offset_provider: dict[str, NeighborTableOffsetProvider], column_axis: Optional[Dimension] = None, - use_gpu_storage: bool = False, ): self.param_types = param_types self.column_axis = column_axis self.offset_provider = offset_provider self.storage_types = {} - self.use_gpu_storage = use_gpu_storage def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): @@ -123,14 +120,7 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset else None ) dtype = as_dace_type(type_.dtype) - storage = ( - dace.dtypes.StorageType.GPU_Global - if self.use_gpu_storage - else dace.dtypes.StorageType.Default - ) - sdfg.add_array( - name, shape=shape, strides=strides, offset=offset, dtype=dtype, storage=storage - ) + sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) @@ -246,7 +236,6 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, - storage=array_table[name].storage, transient=True, ) closure_init_state.add_nedge( @@ -261,7 +250,6 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, - storage=array_table[name].storage, ) else: assert isinstance(self.storage_types[name], ts.ScalarType) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py index e580333bc8..8cfff12df4 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py @@ -312,7 +312,7 @@ def test_symbolref_validation_for_valid_tree(): SymbolTableRootNode( nodes=[SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo")], ) - SymbolTableRootNode( + SymbolTableRootNode( # noqa: B018 nodes=[ SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo"), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 698dce2b5c..d100cd380c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -30,16 +30,6 @@ def test_external_local_field(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def testee( inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index e8d0c8b163..e2434d860a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -46,16 +46,6 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - if unstructured_case.backend in [ gtfn.run_gtfn, gtfn.run_gtfn_gpu, @@ -79,16 +69,6 @@ def testee(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_minover_execution(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) @@ -102,16 +82,6 @@ def minover(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_reduction_execution(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def reduction(edge_f: cases.EField) -> cases.VField: return neighbor_sum(edge_f(V2E), axis=V2EDim) @@ -150,16 +120,6 @@ def fencil(edge_f: cases.EField, out: cases.VField): @pytest.mark.uses_unstructured_shift def test_reduction_with_common_expression(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index da0945fe96..788081b81e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -20,14 +20,20 @@ from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.program_processors import otf_compile_executor from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries -from tests.next_tests.integration_tests.cases import Case -from tests.next_tests.toy_connectivity import Cell, Edge from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import E2V, KDim, Vertex, cartesian_case, unstructured_case +from next_tests.integration_tests.cases import ( + E2V, + Case, + KDim, + Vertex, + cartesian_case, + unstructured_case, +) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( reduction_setup, ) +from next_tests.toy_connectivity import Cell, Edge @pytest.fixture