From 6b269bdf4bdc9e6e6c896b03c43f73c6342eb443 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 4 Jan 2024 17:04:55 +0100 Subject: [PATCH 01/11] bug[next]: recover grid_type in scan_operator (#1408) lost in merge conflict in #1380 --- src/gt4py/next/ffront/decorator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 76a0ddcde0..147059b1bd 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -847,6 +847,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: return FieldOperator.from_function( definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, + grid_type, operator_node_cls=foast.ScanOperator, operator_attributes={"axis": axis, "forward": forward, "init": init}, ) From 6e6271c2c5d3ad0e31038d5bd597bd2327534a5e Mon Sep 17 00:00:00 2001 From: SF-N <65219381+SF-N@users.noreply.github.com> Date: Wed, 17 Jan 2024 14:47:23 +0100 Subject: [PATCH 02/11] feature[next]: Add power unrolling functionality and respective unit tests. (#1409) * Add power unrolling functionality and respective unit tests. * Define base and exponent variables for better readability in PowerUnrolling * Remove distinction between SymRef and FunCall in power unrolling * Optimize power unrolling to avoid multiple computations of FunCalls * Further improve power unrolling * Update wrt review and adapt expected results respectively * Add correct annotation --------- Co-authored-by: Sara Faghih-Naini --- .../iterator/transforms/power_unrolling.py | 84 +++++++++ .../transforms_tests/test_power_unrolling.py | 161 ++++++++++++++++++ 2 files changed, 245 insertions(+) create mode 100644 src/gt4py/next/iterator/transforms/power_unrolling.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py diff --git a/src/gt4py/next/iterator/transforms/power_unrolling.py b/src/gt4py/next/iterator/transforms/power_unrolling.py new file mode 100644 index 0000000000..ac71f2747d --- /dev/null +++ b/src/gt4py/next/iterator/transforms/power_unrolling.py @@ -0,0 +1,84 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses +import math + +from gt4py.eve import NodeTranslator +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas + + +def _is_power_call( + node: ir.FunCall, +) -> bool: + """Match expressions of the form `power(base, integral_literal)`.""" + return ( + isinstance(node.fun, ir.SymRef) + and node.fun.id == "power" + and isinstance(node.args[1], ir.Literal) + and float(node.args[1].value) == int(node.args[1].value) + and node.args[1].value >= im.literal_from_value(0).value + ) + + +def _compute_integer_power_of_two(exp: int) -> int: + return math.floor(math.log2(exp)) + + +@dataclasses.dataclass +class PowerUnrolling(NodeTranslator): + max_unroll: int + + @classmethod + def apply(cls, node: ir.Node, max_unroll: int = 5) -> ir.Node: + return cls(max_unroll=max_unroll).visit(node) + + def visit_FunCall(self, node: ir.FunCall): + new_node = self.generic_visit(node) + + if _is_power_call(new_node): + assert len(new_node.args) == 2 + # Check if unroll should be performed or if exponent is too large + base, exponent = new_node.args[0], int(new_node.args[1].value) + if 1 <= exponent <= self.max_unroll: + # Calculate and store powers of two of the base as long as they are smaller than the exponent. + # Do the same (using the stored values) with the remainder and multiply computed values. + pow_cur = _compute_integer_power_of_two(exponent) + pow_max = pow_cur + remainder = exponent + + # Build target expression + ret = im.ref(f"power_{2 ** pow_max}") + remainder -= 2**pow_cur + while remainder > 0: + pow_cur = _compute_integer_power_of_two(remainder) + remainder -= 2**pow_cur + + ret = im.multiplies_(ret, f"power_{2 ** pow_cur}") + + # Nest target expression to avoid multiple redundant evaluations + for i in range(pow_max, 0, -1): + ret = im.let( + f"power_{2 ** i}", + im.multiplies_(f"power_{2**(i-1)}", f"power_{2**(i-1)}"), + )(ret) + ret = im.let("power_1", base)(ret) + + # Simplify expression in case of SymRef by resolving let statements + if isinstance(base, ir.SymRef): + return InlineLambdas.apply(ret, opcount_preserving=True) + else: + return ret + return new_node diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py new file mode 100644 index 0000000000..ae23becb4c --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py @@ -0,0 +1,161 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest + +from gt4py.eve import SymbolRef +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.power_unrolling import PowerUnrolling + + +def test_power_unrolling_zero(): + pytest.xfail( + "Not implementeds we don't have an easy way to determine the type of the one literal (type inference is to expensive)." + ) + testee = im.call("power")("x", 0) + expected = im.literal_from_value(1) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_one(): + testee = im.call("power")("x", 1) + expected = ir.SymRef(id=SymbolRef("x")) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_two(): + testee = im.call("power")("x", 2) + expected = im.multiplies_("x", "x") + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_two_x_plus_two(): + testee = im.call("power")(im.plus("x", 2), 2) + expected = im.let("power_1", im.plus("x", 2))( + im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2") + ) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_two_x_plus_one_times_three(): + testee = im.call("power")(im.multiplies_(im.plus("x", 1), 3), 2) + expected = im.let("power_1", im.multiplies_(im.plus("x", 1), 3))( + im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2") + ) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_three(): + testee = im.call("power")("x", 3) + expected = im.multiplies_(im.multiplies_("x", "x"), "x") + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_four(): + testee = im.call("power")("x", 4) + expected = im.let("power_2", im.multiplies_("x", "x"))(im.multiplies_("power_2", "power_2")) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_five(): + testee = im.call("power")("x", 5) + tmp2 = im.multiplies_("x", "x") + expected = im.multiplies_(im.multiplies_(tmp2, tmp2), "x") + expected = im.let("power_2", im.multiplies_("x", "x"))( + im.multiplies_(im.multiplies_("power_2", "power_2"), "x") + ) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_seven(): + testee = im.call("power")("x", 7) + expected = im.call("power")("x", 7) + + actual = PowerUnrolling.apply(testee, max_unroll=5) + assert actual == expected + + +def test_power_unrolling_seven_unrolled(): + testee = im.call("power")("x", 7) + expected = im.let("power_2", im.multiplies_("x", "x"))( + im.multiplies_(im.multiplies_(im.multiplies_("power_2", "power_2"), "power_2"), "x") + ) + + actual = PowerUnrolling.apply(testee, max_unroll=7) + assert actual == expected + + +def test_power_unrolling_seven_x_plus_one_unrolled(): + testee = im.call("power")(im.plus("x", 1), 7) + expected = im.let("power_1", im.plus("x", 1))( + im.let("power_2", im.multiplies_("power_1", "power_1"))( + im.let("power_4", im.multiplies_("power_2", "power_2"))( + im.multiplies_(im.multiplies_("power_4", "power_2"), "power_1") + ) + ) + ) + + actual = PowerUnrolling.apply(testee, max_unroll=7) + assert actual == expected + + +def test_power_unrolling_eight(): + testee = im.call("power")("x", 8) + expected = im.call("power")("x", 8) + + actual = PowerUnrolling.apply(testee, max_unroll=5) + assert actual == expected + + +def test_power_unrolling_eight_unrolled(): + testee = im.call("power")("x", 8) + expected = im.let("power_2", im.multiplies_("x", "x"))( + im.let("power_4", im.multiplies_("power_2", "power_2"))( + im.multiplies_("power_4", "power_4") + ) + ) + + actual = PowerUnrolling.apply(testee, max_unroll=8) + assert actual == expected + + +def test_power_unrolling_eight_x_plus_one_unrolled(): + testee = im.call("power")(im.plus("x", 1), 8) + expected = im.let("power_1", im.plus("x", 1))( + im.let("power_2", im.multiplies_("power_1", "power_1"))( + im.let("power_4", im.multiplies_("power_2", "power_2"))( + im.let("power_8", im.multiplies_("power_4", "power_4"))("power_8") + ) + ) + ) + + actual = PowerUnrolling.apply(testee, max_unroll=8) + assert actual == expected From 6283ac930ee301a02300b2ad9bb440b3aab04b2d Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 17 Jan 2024 15:50:57 +0100 Subject: [PATCH 03/11] fix[cartesian]: DaCe array access in tasklet (#1410) Found some incompatible tasklet representation while upgrading to dace v0.15.1. Array access inside tasklet with partial index subset worked in v0.14.1, although not valid. The fix consists of modifying the memlets to pass the full array shape to such tasklet, and use all explicit indices inside the tasklet to access the array. This is the right representation in DaCe SDFG, as discussed with the DaCe developers. --- .../gtc/dace/expansion/daceir_builder.py | 35 +++++++++++++++++++ src/gt4py/cartesian/gtc/daceir.py | 4 +-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index db276a48b9..48b129fa87 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -30,6 +30,7 @@ compute_dcir_access_infos, flatten_list, get_tasklet_symbol, + make_dace_subset, union_inout_memlets, union_node_grid_subsets, untile_memlets, @@ -458,6 +459,40 @@ def visit_HorizontalExecution( write_memlets=write_memlets, ) + for memlet in [*read_memlets, *write_memlets]: + """ + This loop handles the special case of a tasklet performing array access. + The memlet should pass the full array shape (no tiling) and + the tasklet expression for array access should use all explicit indexes. + """ + array_ndims = len(global_ctx.arrays[memlet.field].shape) + field_decl = global_ctx.library_node.field_decls[memlet.field] + # calculate array subset on original memlet + memlet_subset = make_dace_subset( + global_ctx.library_node.access_infos[memlet.field], + memlet.access_info, + field_decl.data_dims, + ) + # select index values for single-point grid access + memlet_data_index = [ + dcir.Literal(value=str(dim_range[0]), dtype=common.DataType.INT32) + for dim_range, dim_size in zip(memlet_subset, memlet_subset.size()) + if dim_size == 1 + ] + if len(memlet_data_index) < array_ndims: + reshape_memlet = False + for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess): + if access_node.data_index and access_node.name == memlet.connector: + access_node.data_index = memlet_data_index + access_node.data_index + assert len(access_node.data_index) == array_ndims + reshape_memlet = True + if reshape_memlet: + # ensure that memlet symbols used for array indexing are defined in context + for sym in memlet.access_info.grid_subset.free_symbols: + symbol_collector.add_symbol(sym) + # set full shape on memlet + memlet.access_info = global_ctx.library_node.access_infos[memlet.field] + for item in reversed(expansion_items): iteration_ctx = iteration_ctx.pop() dcir_node = self._process_iteration_item( diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py index 28ebc8cd8e..0366317360 100644 --- a/src/gt4py/cartesian/gtc/daceir.py +++ b/src/gt4py/cartesian/gtc/daceir.py @@ -536,7 +536,7 @@ def union(self, other): else: assert ( isinstance(interval2, (TileInterval, DomainInterval)) - and isinstance(interval1, IndexWithExtent) + and isinstance(interval1, (IndexWithExtent, DomainInterval)) ) or ( isinstance(interval1, (TileInterval, DomainInterval)) and isinstance(interval2, IndexWithExtent) @@ -573,7 +573,7 @@ def overapproximated_shape(self): def apply_iteration(self, grid_subset: GridSubset): res_intervals = dict(self.grid_subset.intervals) for axis, field_interval in self.grid_subset.intervals.items(): - if axis in grid_subset.intervals: + if axis in grid_subset.intervals and not isinstance(field_interval, DomainInterval): grid_interval = grid_subset.intervals[axis] assert isinstance(field_interval, IndexWithExtent) extent = field_interval.extent From 3edf21e9c7fce64976068d19d2a15c5b856d94d6 Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Thu, 18 Jan 2024 09:19:00 +0100 Subject: [PATCH 04/11] bug[next]: Bound args kwargs edit (#1411) * edits for BoundArgs with kwargs in correct order --- src/gt4py/next/ffront/decorator.py | 17 +++-- .../ffront_tests/test_bound_args.py | 64 +++++++++++++++++++ .../ffront_tests/test_execution.py | 20 ------ 3 files changed, 75 insertions(+), 26 deletions(-) create mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 147059b1bd..05cbe1c882 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -453,27 +453,32 @@ def _process_args(self, args: tuple, kwargs: dict): ) from err full_args = [*args] + full_kwargs = {**kwargs} for index, param in enumerate(self.past_node.params): if param.id in self.bound_args.keys(): - full_args.insert(index, self.bound_args[param.id]) + if index < len(full_args): + full_args.insert(index, self.bound_args[param.id]) + else: + full_kwargs[str(param.id)] = self.bound_args[param.id] - return super()._process_args(tuple(full_args), kwargs) + return super()._process_args(tuple(full_args), full_kwargs) @functools.cached_property def itir(self): new_itir = super().itir for new_clos in new_itir.closures: - for key in self.bound_args.keys(): + new_args = [ref(inp.id) for inp in new_clos.inputs] + for key, value in self.bound_args.items(): index = next( index for index, closure_input in enumerate(new_clos.inputs) if closure_input.id == key ) + new_args[new_args.index(new_clos.inputs[index])] = promote_to_const_iterator( + literal_from_value(value) + ) new_clos.inputs.pop(index) - new_args = [ref(inp.id) for inp in new_clos.inputs] params = [sym(inp.id) for inp in new_clos.inputs] - for value in self.bound_args.values(): - new_args.append(promote_to_const_iterator(literal_from_value(value))) expr = itir.FunCall( fun=new_clos.stencil, args=new_args, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py new file mode 100644 index 0000000000..0de953d85f --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np + +import gt4py.next as gtx +from gt4py.next import int32 + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import cartesian_case +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + fieldview_backend, + reduction_setup, +) + + +def test_with_bound_args(cartesian_case): + @gtx.field_operator + def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField: + if not condition: + scalar = 0 + return a + scalar + + @gtx.program + def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField): + fieldop_bound_args(a, scalar, condition, out=out) + + a = cases.allocate(cartesian_case, program_bound_args, "a")() + scalar = int32(1) + ref = a + scalar + out = cases.allocate(cartesian_case, program_bound_args, "out")() + + prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True) + cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref) + + +def test_with_bound_args_order_args(cartesian_case): + @gtx.field_operator + def fieldop_args(a: cases.IField, condition: bool, scalar: int32) -> cases.IField: + scalar = 0 if not condition else scalar + return a + scalar + + @gtx.program(backend=cartesian_case.backend) + def program_args(a: cases.IField, condition: bool, scalar: int32, out: cases.IField): + fieldop_args(a, condition, scalar, out=out) + + a = cases.allocate(cartesian_case, program_args, "a")() + out = cases.allocate(cartesian_case, program_args, "out")() + + prog_bounds = program_args.with_bound_args(condition=True) + prog_bounds(a=a, scalar=int32(1), out=out, offset_provider={}) + np.allclose(out.asnumpy(), a.asnumpy() + int32(1)) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index a08931628b..70c79d7b6c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -898,26 +898,6 @@ def test_docstring(a: cases.IField): cases.verify(cartesian_case, test_docstring, a, inout=a, ref=a) -def test_with_bound_args(cartesian_case): - @gtx.field_operator - def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField: - if not condition: - scalar = 0 - return a + a + scalar - - @gtx.program - def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField): - fieldop_bound_args(a, scalar, condition, out=out) - - a = cases.allocate(cartesian_case, program_bound_args, "a")() - scalar = int32(1) - ref = a + a + 1 - out = cases.allocate(cartesian_case, program_bound_args, "out")() - - prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True) - cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref) - - def test_domain(cartesian_case): @gtx.field_operator def fieldop_domain(a: cases.IField) -> cases.IField: From ba368564c27807cbac207bd7d5631501f87b062a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 18 Jan 2024 11:14:43 +0100 Subject: [PATCH 05/11] example: cartesian with next compatibility (#1202) Add an example illustrating using gt4py.cartesian and gt4py.next computations next to each other using gt4py.next storages. Refactor GTFieldInterface and cleanup GTDimsInterface for next. --- examples/lap_cartesian_vs_next.ipynb | 189 ++++++++++++++++++ src/gt4py/next/__init__.py | 9 + src/gt4py/next/common.py | 35 ++-- src/gt4py/next/embedded/nd_array_field.py | 4 - src/gt4py/next/iterator/embedded.py | 32 +-- src/gt4py/next/iterator/tracing.py | 2 +- .../next/type_system/type_translation.py | 2 +- src/gt4py/storage/cartesian/utils.py | 4 + 8 files changed, 239 insertions(+), 38 deletions(-) create mode 100644 examples/lap_cartesian_vs_next.ipynb diff --git a/examples/lap_cartesian_vs_next.ipynb b/examples/lap_cartesian_vs_next.ipynb new file mode 100644 index 0000000000..cb80122570 --- /dev/null +++ b/examples/lap_cartesian_vs_next.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "GT4Py - GridTools for Python\n", + "\n", + "Copyright (c) 2014-2023, ETH Zurich\n", + "All rights reserved.\n", + "\n", + "This file is part the GT4Py project and the GridTools framework.\n", + "GT4Py is free software: you can redistribute it and/or modify it under\n", + "the terms of the GNU General Public License as published by the\n", + "Free Software Foundation, either version 3 of the License, or any later\n", + "version. See the LICENSE.txt file at the top-level directory of this\n", + "distribution for a copy of the license or check .\n", + "\n", + "SPDX-License-Identifier: GPL-3.0-or-later" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Demonstrates gt4py.cartesian with gt4py.next compatibility" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "nx = 32\n", + "ny = 32\n", + "nz = 1\n", + "dtype = np.float64" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Storages\n", + "--\n", + "\n", + "We create fields using the gt4py.next constructors. These fields are compatible with gt4py.cartesian when we use \"I\", \"J\", \"K\" as the dimension names." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import gt4py.next as gtx\n", + "\n", + "allocator = gtx.itir_embedded # should match the executor\n", + "# allocator = gtx.gtfn_cpu\n", + "# allocator = gtx.gtfn_gpu\n", + "\n", + "# Note: for gt4py.next, names don't matter, for gt4py.cartesian they have to be \"I\", \"J\", \"K\"\n", + "I = gtx.Dimension(\"I\")\n", + "J = gtx.Dimension(\"J\")\n", + "K = gtx.Dimension(\"K\", kind=gtx.DimensionKind.VERTICAL)\n", + "\n", + "domain = gtx.domain({I: nx, J: ny, K: nz})\n", + "\n", + "inp = gtx.as_field(domain, np.fromfunction(lambda x, y, z: x**2+y**2, shape=(nx, ny, nz)), dtype, allocator=allocator)\n", + "out_cartesian = gtx.zeros(domain, dtype, allocator=allocator)\n", + "out_next = gtx.zeros(domain, dtype, allocator=allocator)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "gt4py.cartesian\n", + "--" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import gt4py.cartesian.gtscript as gtscript\n", + "\n", + "cartesian_backend = \"numpy\"\n", + "# cartesian_backend = \"gt:cpu_ifirst\"\n", + "# cartesian_backend = \"gt:gpu\"\n", + "\n", + "@gtscript.stencil(backend=cartesian_backend)\n", + "def lap_cartesian(\n", + " inp: gtscript.Field[dtype],\n", + " out: gtscript.Field[dtype],\n", + "):\n", + " with computation(PARALLEL), interval(...):\n", + " out = -4.0 * inp[0, 0, 0] + inp[-1, 0, 0] + inp[1, 0, 0] + inp[0, -1, 0] + inp[0, 1, 0]\n", + "\n", + "lap_cartesian(inp=inp, out=out_cartesian, origin=(1, 1, 0), domain=(nx-2, ny-2, nz))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from gt4py.next import Field\n", + "\n", + "next_backend = gtx.itir_embedded\n", + "# next_backend = gtx.gtfn_cpu\n", + "# next_backend = gtx.gtfn_gpu\n", + "\n", + "Ioff = gtx.FieldOffset(\"I\", source=I, target=(I,))\n", + "Joff = gtx.FieldOffset(\"J\", source=J, target=(J,))\n", + "\n", + "@gtx.field_operator\n", + "def lap_next(inp: Field[[I, J, K], dtype]) -> Field[[I, J, K], dtype]:\n", + " return -4.0 * inp + inp(Ioff[-1]) + inp(Ioff[1]) + inp(Joff[-1]) + inp(Joff[1])\n", + "\n", + "@gtx.program(backend=next_backend)\n", + "def lap_next_program(inp: Field[[I, J, K], dtype], out: Field[[I, J, K], dtype]):\n", + " lap_next(inp, out=out[1:-1, 1:-1, :])\n", + "\n", + "lap_next_program(inp, out_next, offset_provider={\"Ioff\": I, \"Joff\": J})" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "assert np.allclose(out_cartesian.asnumpy(), out_next.asnumpy())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index cbd5735949..1398af5f03 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -39,6 +39,11 @@ index_field, np_as_located_field, ) +from .program_processors.runners.gtfn import ( + run_gtfn_cached as gtfn_cpu, + run_gtfn_gpu_cached as gtfn_gpu, +) +from .program_processors.runners.roundtrip import backend as itir_python __all__ = [ @@ -74,5 +79,9 @@ "field_operator", "program", "scan_operator", + # from program_processor + "gtfn_cpu", + "gtfn_gpu", + "itir_python", *fbuiltins.__all__, ] diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 29d606ccc0..6bf6858369 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -574,38 +574,39 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _ ... -# TODO(havogt): replace this protocol with the new `GTFieldInterface` protocol -class NextGTDimsInterface(Protocol): +# TODO(havogt): we need to describe when this interface should be used instead of the `Field` protocol. +class GTFieldInterface(core_defs.GTDimsInterface, core_defs.GTOriginInterface, Protocol): """ - Protocol for objects providing the `__gt_dims__` property, naming :class:`Field` dimensions. + Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`. - The dimension names are objects of type :class:`Dimension`, in contrast to - :mod:`gt4py.cartesian`, where the labels are `str` s with implied semantics, - see :class:`~gt4py._core.definitions.GTDimsInterface` . + Note: + - A default implementation of the `__gt_dims__` interface from `gt4py.cartesian` is provided. + - No implementation of `__gt_origin__` is provided because of infinite fields. """ @property - def __gt_dims__(self) -> tuple[Dimension, ...]: + def __gt_domain__(self) -> Domain: + # TODO probably should be changed to `DomainLike` (with a new concept `DimensionLike`) + # to allow implementations without having to import gtx.Domain. ... - -# TODO(egparedes): add support for this new protocol in the cartesian module -class GTFieldInterface(Protocol): - """Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`.""" - @property - def __gt_domain__(self) -> Domain: - ... + def __gt_dims__(self) -> tuple[str, ...]: + return tuple(d.value for d in self.__gt_domain__.dims) @extended_runtime_checkable -class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]): +class Field(GTFieldInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @property def domain(self) -> Domain: ... + @property + def __gt_domain__(self) -> Domain: + return self.domain + @property def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... @@ -923,10 +924,6 @@ def asnumpy(self) -> Never: def domain(self) -> Domain: return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),)) - @property - def __gt_dims__(self) -> tuple[Dimension, ...]: - return self.domain.dims - @property def __gt_origin__(self) -> Never: raise TypeError("'CartesianConnectivity' does not support this operation.") diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 8bd2673db9..9fc1b42038 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -107,10 +107,6 @@ def domain(self) -> common.Domain: def shape(self) -> tuple[int, ...]: return self._ndarray.shape - @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return self._domain.dims - @property def __gt_origin__(self) -> tuple[int, ...]: assert common.Domain.is_finite(self._domain) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index ef70a2e645..390bec4312 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -172,7 +172,7 @@ class LocatedField(Protocol): @property @abc.abstractmethod - def __gt_dims__(self) -> tuple[common.Dimension, ...]: + def __gt_domain__(self) -> common.Domain: ... # TODO(havogt): define generic Protocol to provide a concrete return type @@ -182,7 +182,7 @@ def field_getitem(self, indices: NamedFieldIndices) -> Any: @property def __gt_origin__(self) -> tuple[int, ...]: - return tuple([0] * len(self.__gt_dims__)) + return tuple([0] * len(self.__gt_domain__.dims)) @runtime_checkable @@ -675,12 +675,18 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]: def _get_axes( field_or_tuple: LocatedField | tuple, ) -> Sequence[common.Dimension]: # arbitrary nesting of tuples of LocatedField + return _get_domain(field_or_tuple).dims + + +def _get_domain( + field_or_tuple: LocatedField | tuple, +) -> common.Domain: # arbitrary nesting of tuples of LocatedField if isinstance(field_or_tuple, tuple): - first = _get_axes(field_or_tuple[0]) - assert all(first == _get_axes(f) for f in field_or_tuple) + first = _get_domain(field_or_tuple[0]) + assert all(first == _get_domain(f) for f in field_or_tuple) return first else: - return field_or_tuple.__gt_dims__ + return field_or_tuple.__gt_domain__ def _single_vertical_idx( @@ -894,14 +900,14 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField): _ndarrayfield: common.Field @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return self._ndarrayfield.__gt_dims__ + def __gt_domain__(self) -> common.Domain: + return self._ndarrayfield.__gt_domain__ def _translate_named_indices( self, _named_indices: NamedFieldIndices ) -> common.AbsoluteIndexSequence: named_indices: Mapping[common.Dimension, FieldIndex | SparsePositionEntry] = { - d: _named_indices[d.value] for d in self._ndarrayfield.__gt_dims__ + d: _named_indices[d.value] for d in self._ndarrayfield.__gt_domain__.dims } domain_slice: list[common.NamedRange | common.NamedIndex] = [] for d, v in named_indices.items(): @@ -1046,8 +1052,8 @@ class IndexField(common.Field): _dimension: common.Dimension @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return (self._dimension,) + def __gt_domain__(self) -> common.Domain: + return self.domain @property def __gt_origin__(self) -> tuple[int, ...]: @@ -1165,8 +1171,8 @@ class ConstantField(common.Field[Any, core_defs.ScalarT]): _value: core_defs.ScalarT @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return tuple() + def __gt_domain__(self) -> common.Domain: + return self.domain @property def __gt_origin__(self) -> tuple[int, ...]: @@ -1452,7 +1458,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices: class TupleOfFields(TupleField): def __init__(self, data): self.data = data - self.__gt_dims__ = _get_axes(data) + self.__gt_domain__ = _get_domain(data) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: return _build_tuple_result(self.data, named_indices) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 30fec1f9fd..05ebd02352 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -254,7 +254,7 @@ def _contains_tuple_dtype_field(arg): # other `np.int32`). We just ignore the error here and postpone fixing this to when # the new storages land (The implementation here works for LocatedFieldImpl). - return common.is_field(arg) and any(dim is None for dim in arg.__gt_dims__) + return common.is_field(arg) and any(dim is None for dim in arg.domain.dims) def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 88a8347fe4..12649bf620 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -184,7 +184,7 @@ def from_value(value: Any) -> ts.TypeSpec: elif isinstance(value, common.Dimension): symbol_type = ts.DimensionType(dim=value) elif common.is_field(value): - dims = list(value.__gt_dims__) + dims = list(value.domain.dims) dtype = from_type_hint(value.dtype.scalar_type) symbol_type = ts.FieldType(dims=dims, dtype=dtype) elif isinstance(value, tuple): diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 0f7cf5d0ab..4e7ebb0c21 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -192,6 +192,10 @@ def cpu_copy(array: Union[np.ndarray, "cp.ndarray"]) -> np.ndarray: def asarray( array: FieldLike, *, device: Literal["cpu", "gpu", None] = None ) -> np.ndarray | cp.ndarray: + if hasattr(array, "ndarray"): + # extract the buffer from a gt4py.next.Field + # TODO(havogt): probably `Field` should provide the array interface methods when applicable + array = array.ndarray if device == "gpu" or (not device and hasattr(array, "__cuda_array_interface__")): return cp.asarray(array) if device == "cpu" or ( From 49db7efadbe6c5329fdadbf3e3f3a0fd1728ee00 Mon Sep 17 00:00:00 2001 From: Samuel Date: Thu, 18 Jan 2024 17:00:33 +0100 Subject: [PATCH 06/11] feat[next]: Pass sizes to temporaries from gt4py program (#1359) --- src/gt4py/eve/trees.py | 4 +- .../next/iterator/transforms/global_tmps.py | 41 +++++-- .../next/iterator/transforms/pass_manager.py | 6 +- src/gt4py/next/iterator/type_inference.py | 6 +- .../codegens/gtfn/gtfn_backend.py | 77 ------------ .../codegens/gtfn/gtfn_module.py | 95 +++++++++++---- .../program_processors/formatters/gtfn.py | 13 +- .../test_temporaries_with_sizes.py | 113 ++++++++++++++++++ .../cpp_backend_tests/anton_lap.py | 6 +- .../cpp_backend_tests/copy_stencil.py | 6 +- .../copy_stencil_field_view.py | 6 +- .../cpp_backend_tests/fvm_nabla.py | 11 +- .../cpp_backend_tests/tridiagonal_solve.py | 6 +- .../transforms_tests/test_global_tmps.py | 2 +- 14 files changed, 264 insertions(+), 128 deletions(-) delete mode 100644 src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py create mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index cd7e71588f..74c5bd41bb 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -133,7 +133,7 @@ def _pre_walk_items( yield from _pre_walk_items(child, __key__=key) -def _pre_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]: +def _pre_walk_values(node: TreeLike) -> Iterable: """Create a pre-order tree traversal iterator of values.""" yield node for child in iter_children_values(node): @@ -153,7 +153,7 @@ def _post_walk_items( yield __key__, node -def _post_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]: +def _post_walk_values(node: TreeLike) -> Iterable: """Create a post-order tree traversal iterator of values.""" if (iter_children_values := getattr(node, "iter_children_values", None)) is not None: for child in iter_children_values(): diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index d9d3d18213..0033f36cab 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -22,6 +22,7 @@ from gt4py.eve import Coerced, NodeTranslator from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator +from gt4py.next import common from gt4py.next.iterator import ir, type_inference from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -437,9 +438,12 @@ def _group_offsets( return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly -def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]): +def update_domains( + node: FencilWithTemporaries, + offset_provider: Mapping[str, Any], + symbolic_sizes: Optional[dict[str, str]], +): horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) - closures: list[ir.StencilClosure] = [] domains = dict[str, ir.FunCall]() for closure in reversed(node.fencil.closures): @@ -479,16 +483,29 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An # cartesian shift dim = offset_provider[offset_name].value consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset) - elif isinstance(offset_provider[offset_name], gtx.NeighborTableOffsetProvider): + elif isinstance(offset_provider[offset_name], common.Connectivity): # unstructured shift nbt_provider = offset_provider[offset_name] old_axis = nbt_provider.origin_axis.value new_axis = nbt_provider.neighbor_axis.value - consumed_domain.ranges.pop(old_axis) - assert new_axis not in consumed_domain.ranges - consumed_domain.ranges[new_axis] = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN), + + assert new_axis not in consumed_domain.ranges or old_axis == new_axis + + if symbolic_sizes is None: + new_range = SymbolicRange( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.literal( + str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN + ), + ) + else: + new_range = SymbolicRange( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.ref(symbolic_sizes[new_axis]), + ) + consumed_domain.ranges = dict( + (axis, range_) if axis != old_axis else (new_axis, new_range) + for axis, range_ in consumed_domain.ranges.items() ) else: raise NotImplementedError @@ -570,7 +587,11 @@ class CreateGlobalTmps(NodeTranslator): """ def visit_FencilDefinition( - self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any] + self, + node: ir.FencilDefinition, + *, + offset_provider: Mapping[str, Any], + symbolic_sizes: Optional[dict[str, str]], ) -> FencilWithTemporaries: # Split closures on lifted function calls and introduce temporaries res = split_closures(node, offset_provider=offset_provider) @@ -581,6 +602,6 @@ def visit_FencilDefinition( # Perform an eta-reduction which should put all calls at the highest level of a closure res = EtaReduction().visit(res) # Perform a naive extent analysis to compute domain sizes of closures and temporaries - res = update_domains(res, offset_provider) + res = update_domains(res, offset_provider, symbolic_sizes) # Use type inference to determine the data type of the temporaries return collect_tmps_info(res, offset_provider=offset_provider) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2e05391634..08897861c2 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import enum +from typing import Optional from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import simple_inline_heuristic @@ -81,6 +82,7 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ): if lift_mode is None: lift_mode = LiftMode.FORCE_INLINE @@ -147,7 +149,9 @@ def apply_common_transforms( if lift_mode != LiftMode.FORCE_INLINE: assert offset_provider is not None - ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider) + ir = CreateGlobalTmps().visit( + ir, offset_provider=offset_provider, symbolic_sizes=symbolic_domain_sizes + ) ir = InlineLifts().visit(ir) # If after creating temporaries, the scan is not at the top, we inline. # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 68627cfd89..d65f67b266 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -567,9 +567,7 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints): axis = offset_provider[offset] if isinstance(axis, gtx.Dimension): continue # Cartesian shifts don’t change the location type - elif isinstance( - axis, (gtx.NeighborTableOffsetProvider, gtx.StridedNeighborOffsetProvider) - ): + elif isinstance(axis, Connectivity): assert ( axis.origin_axis.kind == axis.neighbor_axis.kind @@ -964,7 +962,7 @@ def visit_FencilDefinition( def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None: for child_node in node.pre_walk_values().if_isinstance(*TYPED_IR_NODES): try: - child_node.annex.type = types[id(child_node)] # type: ignore[attr-defined] + child_node.annex.type = types[id(child_node)] except KeyError: if not ( isinstance(child_node, ir.SymRef) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py deleted file mode 100644 index 4183f52550..0000000000 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py +++ /dev/null @@ -1,77 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any - -import gt4py.next.iterator.ir as itir -from gt4py.eve import codegen -from gt4py.eve.exceptions import EveValueError -from gt4py.next.iterator.transforms.pass_manager import apply_common_transforms -from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen -from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering -from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering - - -def _lower( - program: itir.FencilDefinition, enable_itir_transforms: bool, do_unroll: bool, **kwargs: Any -): - offset_provider = kwargs.get("offset_provider") - assert isinstance(offset_provider, dict) - if enable_itir_transforms: - program = apply_common_transforms( - program, - lift_mode=kwargs.get("lift_mode"), - offset_provider=offset_provider, - unroll_reduce=do_unroll, - unconditionally_collapse_tuples=True, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements - ) - gtfn_ir = GTFN_lowering.apply( - program, - offset_provider=offset_provider, - column_axis=kwargs.get("column_axis"), - ) - return gtfn_ir - - -def generate( - program: itir.FencilDefinition, enable_itir_transforms: bool = True, **kwargs: Any -) -> str: - if kwargs.get("imperative", False): - try: - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=False, - **kwargs, - ) - except EveValueError: - # if we don't unroll, there may be lifts left in the itir which can't be lowered to - # gtfn. In this case, just retry with unrolled reductions. - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=True, - **kwargs, - ) - gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir, **kwargs) - generated_code = GTFNIMCodegen.apply(gtfn_im_ir, **kwargs) - else: - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=True, - **kwargs, - ) - generated_code = GTFNCodegen.apply(gtfn_ir, **kwargs) - return codegen.format_source("cpp", generated_code, style="LLVM") diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 4abdaa6eea..718fef72af 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -15,21 +15,24 @@ from __future__ import annotations import dataclasses +import functools import warnings from typing import Any, Final, Optional import numpy as np from gt4py._core import definitions as core_defs -from gt4py.eve import trees, utils +from gt4py.eve import codegen, trees, utils from gt4py.next import common from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.transforms import LiftMode, pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface -from gt4py.next.program_processors.codegens.gtfn import gtfn_backend +from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen +from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering +from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering from gt4py.next.type_system import type_specifications as ts, type_translation @@ -54,6 +57,7 @@ class GTFNTranslationStep( use_imperative_backend: bool = False lift_mode: Optional[LiftMode] = None device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + symbolic_domain_sizes: Optional[dict[str, str]] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -171,6 +175,70 @@ def _process_connectivity_args( return parameters, arg_exprs + def _preprocess_program( + self, + program: itir.FencilDefinition, + offset_provider: dict[str, Connectivity | Dimension], + runtime_lift_mode: Optional[LiftMode] = None, + ) -> itir.FencilDefinition: + # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added + # to the interface of all (or at least all of concern) backends, but instead should be + # configured in the backend itself (like it is here), until then we respect the argument + # here and warn the user if it differs from the one configured. + lift_mode = runtime_lift_mode or self.lift_mode + if lift_mode != self.lift_mode: + warnings.warn( + f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " + f"overriden to be {str(runtime_lift_mode)} at runtime." + ) + + if not self.enable_itir_transforms: + return program + + apply_common_transforms = functools.partial( + pass_manager.apply_common_transforms, + lift_mode=lift_mode, + offset_provider=offset_provider, + # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements + unconditionally_collapse_tuples=True, + symbolic_domain_sizes=self.symbolic_domain_sizes, + ) + + new_program = apply_common_transforms( + program, unroll_reduce=not self.use_imperative_backend + ) + + if self.use_imperative_backend and any( + node.id == "neighbors" + for node in new_program.pre_walk_values().if_isinstance(itir.SymRef) + ): + # if we don't unroll, there may be lifts left in the itir which can't be lowered to + # gtfn. In this case, just retry with unrolled reductions. + new_program = apply_common_transforms(program, unroll_reduce=True) + + return new_program + + def generate_stencil_source( + self, + program: itir.FencilDefinition, + offset_provider: dict[str, Connectivity | Dimension], + column_axis: Optional[common.Dimension], + runtime_lift_mode: Optional[LiftMode] = None, + ) -> str: + new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode) + gtfn_ir = GTFN_lowering.apply( + new_program, + offset_provider=offset_provider, + column_axis=column_axis, + ) + + if self.use_imperative_backend: + gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir) + generated_code = GTFNIMCodegen.apply(gtfn_im_ir) + else: + generated_code = GTFNCodegen.apply(gtfn_ir) + return codegen.format_source("cpp", generated_code, style="LLVM") + def __call__( self, inp: stages.ProgramCall, @@ -190,18 +258,6 @@ def __call__( inp.kwargs["offset_provider"] ) - # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added - # to the interface of all (or at least all of concern) backends, but instead should be - # configured in the backend itself (like it is here), until then we respect the argument - # here and warn the user if it differs from the one configured. - runtime_lift_mode = inp.kwargs.pop("lift_mode", None) - lift_mode = runtime_lift_mode or self.lift_mode - if runtime_lift_mode != self.lift_mode: - warnings.warn( - f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " - "overriden to be {str(runtime_lift_mode)} at runtime." - ) - # combine into a format that is aligned with what the backend expects parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters backend_arg = self._backend_type() @@ -213,12 +269,11 @@ def __call__( f"{', '.join(connectivity_args_expr)})({', '.join(args_expr)});" ) decl_src = cpp_interface.render_function_declaration(function, body=decl_body) - stencil_src = gtfn_backend.generate( + stencil_src = self.generate_stencil_source( program, - enable_itir_transforms=self.enable_itir_transforms, - lift_mode=lift_mode, - imperative=self.use_imperative_backend, - **inp.kwargs, + inp.kwargs["offset_provider"], + inp.kwargs.get("column_axis", None), + inp.kwargs.get("lift_mode", None), ) source_code = interface.format_source( self._language_settings(), diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index f9fa154641..27dec77ed1 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -15,10 +15,19 @@ from typing import Any from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.codegens.gtfn.gtfn_module import GTFNTranslationStep from gt4py.next.program_processors.processor_interface import program_formatter +from gt4py.next.program_processors.runners.gtfn import gtfn_executor @program_formatter def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - return generate(program, **kwargs) + # TODO(tehrengruber): This is a little ugly. Revisit. + gtfn_translation = gtfn_executor.otf_workflow.translation + assert isinstance(gtfn_translation, GTFNTranslationStep) + return gtfn_translation.generate_stencil_source( + program, + offset_provider=kwargs.get("offset_provider", None), + column_axis=kwargs.get("column_axis", None), + runtime_lift_mode=kwargs.get("lift_mode", None), + ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py new file mode 100644 index 0000000000..da0945fe96 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -0,0 +1,113 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest +from numpy import int32, int64 + +from gt4py import next as gtx +from gt4py.next import common +from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms +from gt4py.next.program_processors import otf_compile_executor +from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries +from tests.next_tests.integration_tests.cases import Case +from tests.next_tests.toy_connectivity import Cell, Edge + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import E2V, KDim, Vertex, cartesian_case, unstructured_case +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + reduction_setup, +) + + +@pytest.fixture +def run_gtfn_with_temporaries_and_symbolic_sizes(): + return otf_compile_executor.OTFBackend( + executor=otf_compile_executor.OTFCompileExecutor( + name="run_gtfn_with_temporaries_and_sizes", + otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace( + translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace( + symbolic_domain_sizes={ + "Cell": "num_cells", + "Edge": "num_edges", + "Vertex": "num_vertices", + }, + ), + ), + ), + allocator=run_gtfn_with_temporaries.allocator, + ) + + +@pytest.fixture +def testee(): + @gtx.field_operator + def testee_op(a: cases.VField) -> cases.EField: + amul = a * 2 + return amul(E2V[0]) + amul(E2V[1]) + + @gtx.program + def prog( + a: cases.VField, + out: cases.EField, + num_vertices: int32, + num_edges: int64, + num_cells: int32, + ): + testee_op(a, out=out) + + return prog + + +def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, reduction_setup): + unstructured_case = Case( + run_gtfn_with_temporaries_and_symbolic_sizes, + offset_provider=reduction_setup.offset_provider, + default_sizes={ + Vertex: reduction_setup.num_vertices, + Edge: reduction_setup.num_edges, + Cell: reduction_setup.num_cells, + KDim: reduction_setup.k_levels, + }, + grid_type=common.GridType.UNSTRUCTURED, + ) + + a = cases.allocate(unstructured_case, testee, "a")() + out = cases.allocate(unstructured_case, testee, "out")() + + first_nbs, second_nbs = (reduction_setup.offset_provider["E2V"].table[:, i] for i in [0, 1]) + ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] + + cases.verify( + unstructured_case, + testee, + a, + out, + reduction_setup.num_vertices, + reduction_setup.num_edges, + reduction_setup.num_cells, + inout=out, + ref=ref, + ) + + +def test_temporary_symbols(testee, reduction_setup): + itir_with_tmp = apply_common_transforms( + testee.itir, + lift_mode=LiftMode.FORCE_TEMPORARIES, + offset_provider=reduction_setup.offset_provider, + ) + + params = ["num_vertices", "num_edges", "num_cells"] + for param in params: + assert any([param == str(p) for p in itir_with_tmp.fencil.params]) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py index e851e7b130..5af4605988 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py @@ -18,7 +18,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef, offset from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn @fundef @@ -69,7 +69,9 @@ def lap_fencil(i_size, j_size, k_size, i_off, j_off, k_off, out, inp): output_file = sys.argv[1] prog = trace_fencil_definition(lap_fencil, [None] * 8, use_arg_types=False) - generated_code = generate(prog, offset_provider={"i": IDim, "j": JDim}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={"i": IDim, "j": JDim}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py index 33c7d5baa7..3e8b88ac66 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py @@ -18,7 +18,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn IDim = gtx.Dimension("IDim") @@ -48,7 +48,9 @@ def copy_fencil(isize, jsize, ksize, inp, out): output_file = sys.argv[1] prog = trace_fencil_definition(copy_fencil, [None] * 5, use_arg_types=False) - generated_code = generate(prog, offset_provider={}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py index f7472d4ac3..fdc57449ee 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py @@ -18,7 +18,7 @@ import gt4py.next as gtx from gt4py.next import Field, field_operator, program -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn IDim = gtx.Dimension("IDim") @@ -47,7 +47,9 @@ def copy_program( output_file = sys.argv[1] prog = copy_program.itir - generated_code = generate(prog, offset_provider={}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py index 1dfd74baca..abc3755dca 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py @@ -19,7 +19,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef, offset from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn, run_gtfn_imperative E2V = offset("E2V") @@ -92,13 +92,20 @@ def mapped_index(_, __) -> int: output_file = sys.argv[1] imperative = sys.argv[2].lower() == "true" + if imperative: + backend = run_gtfn_imperative + else: + backend = run_gtfn + # prog = trace(zavgS_fencil, [None] * 4) # TODO allow generating of 2 fencils prog = trace_fencil_definition(nabla_fencil, [None] * 7, use_arg_types=False) offset_provider = { "V2E": DummyConnectivity(max_neighbors=6, has_skip_values=True), "E2V": DummyConnectivity(max_neighbors=2, has_skip_values=False), } - generated_code = generate(prog, offset_provider=offset_provider, imperative=imperative) + generated_code = backend.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider=offset_provider, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py index 578a19faab..9755774fd0 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py @@ -19,7 +19,7 @@ from gt4py.next.iterator.runtime import closure, fundef from gt4py.next.iterator.tracing import trace_fencil_definition from gt4py.next.iterator.transforms import LiftMode -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn IDim = gtx.Dimension("IDim") @@ -67,10 +67,10 @@ def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x): prog = trace_fencil_definition(tridiagonal_solve_fencil, [None] * 8, use_arg_types=False) offset_provider = {"I": gtx.Dimension("IDim"), "J": gtx.Dimension("JDim")} - generated_code = generate( + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( prog, offset_provider=offset_provider, - lift_mode=LiftMode.SIMPLE_HEURISTIC, + runtime_lift_mode=LiftMode.SIMPLE_HEURISTIC, column_axis=KDim, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 86c3c98c62..5c2802f90c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -323,7 +323,7 @@ def test_update_cartesian_domains(): for a, s in (("JDim", "j"), ("KDim", "k")) ], ) - actual = update_domains(testee, {"I": gtx.Dimension("IDim")}) + actual = update_domains(testee, {"I": gtx.Dimension("IDim")}, symbolic_sizes=None) assert actual == expected From b900b474566f21339d5c99aa2365f9bed86bf1ec Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 19 Jan 2024 11:46:04 +0100 Subject: [PATCH 07/11] build[cartesian][next]: Bump dace version from 0.14.4 to 0.15.1 (#1391) Bumping dace version to 0.15.1 affects both cartesian and next gt4py: * cartesian: removed try/except for dace backward compatibility * next: re-enabled some tests that were broken on dace 0.14.4 * all: fixed and/or suppressed flake8 and mypy errors --- .pre-commit-config.yaml | 38 ++-- constraints.txt | 191 ++++++++++-------- min-extra-requirements-test.txt | 4 +- pyproject.toml | 6 +- requirements-dev.txt | 191 ++++++++++-------- src/gt4py/__init__.py | 2 +- src/gt4py/cartesian/backend/dace_backend.py | 8 +- src/gt4py/cartesian/gtc/dace/nodes.py | 2 +- src/gt4py/eve/datamodels/core.py | 2 +- src/gt4py/eve/utils.py | 4 +- src/gt4py/next/common.py | 11 +- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/otf/workflow.py | 2 +- .../runners/dace_iterator/__init__.py | 13 +- .../runners/dace_iterator/itir_to_sdfg.py | 14 +- .../unit_tests/test_gtc/test_common.py | 2 +- .../ffront_tests/test_external_local_field.py | 10 - .../ffront_tests/test_gt4py_builtins.py | 40 ---- .../test_temporaries_with_sizes.py | 12 +- 19 files changed, 262 insertions(+), 292 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1092fafd0..d9cfa0ff48 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: ## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '23.9.1' # version from constraints.txt + rev: '23.11.0' # version from constraints.txt ##[[[end]]] hooks: - id: black @@ -73,7 +73,7 @@ repos: ## version = re.search('isort==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '5.12.0' # version from constraints.txt + rev: '5.13.0' # version from constraints.txt ##[[[end]]] hooks: - id: isort @@ -97,14 +97,14 @@ repos: ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) ##]]] - darglint==1.8.1 - - flake8-bugbear==23.9.16 - - flake8-builtins==2.1.0 + - flake8-bugbear==23.12.2 + - flake8-builtins==2.2.0 - flake8-debugger==4.1.2 - flake8-docstrings==1.7.0 - flake8-eradicate==1.5.0 - flake8-mutable==1.2.0 - flake8-pyproject==1.2.3 - - pygments==2.16.1 + - pygments==2.17.2 ##[[[end]]] # - flake8-rst-docstrings # Disabled for now due to random false positives exclude: | @@ -146,9 +146,9 @@ repos: ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"#========= FROM constraints.txt: v{version} =========") ##]]] - #========= FROM constraints.txt: v1.5.1 ========= + #========= FROM constraints.txt: v1.7.1 ========= ##[[[end]]] - rev: v1.5.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) + rev: v1.7.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) hooks: - id: mypy additional_dependencies: # versions from constraints.txt @@ -162,26 +162,26 @@ repos: ##]]] - astunparse==1.6.3 - attrs==23.1.0 - - black==23.9.1 - - boltons==23.0.0 + - black==23.11.0 + - boltons==23.1.1 - cached-property==1.5.2 - click==8.1.7 - - cmake==3.27.5 + - cmake==3.27.9 - cytoolz==0.12.2 - - deepdiff==6.5.0 + - deepdiff==6.7.1 - devtools==0.12.2 - - frozendict==2.3.8 + - frozendict==2.3.10 - gridtools-cpp==2.3.1 - - importlib-resources==6.0.1 + - importlib-resources==6.1.1 - jinja2==3.1.2 - - lark==1.1.7 - - mako==1.2.4 - - nanobind==1.5.2 - - ninja==1.11.1 + - lark==1.1.8 + - mako==1.3.0 + - nanobind==1.8.0 + - ninja==1.11.1.1 - numpy==1.24.4 - - packaging==23.1 + - packaging==23.2 - pybind11==2.11.1 - - setuptools==68.2.2 + - setuptools==69.0.2 - tabulate==0.9.0 - typing-extensions==4.5.0 - xxhash==3.0.0 diff --git a/constraints.txt b/constraints.txt index b334851af1..81abd64c6e 100644 --- a/constraints.txt +++ b/constraints.txt @@ -6,124 +6,136 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.0 # via devtools +asttokens==2.4.1 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.12.1 # via sphinx -black==23.9.1 # via gt4py (pyproject.toml) -blinker==1.6.2 # via flask -boltons==23.0.0 # via gt4py (pyproject.toml) +babel==2.13.1 # via sphinx +black==23.11.0 # via gt4py (pyproject.toml) +blinker==1.7.0 # via flask +boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.3.1 # via tox -certifi==2023.7.22 # via requests -cffi==1.15.1 # via cryptography +cachetools==5.3.2 # via tox +cerberus==1.3.5 # via plette +certifi==2023.11.17 # via requests +cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.2.0 # via requests -clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.3.2 # via requests +clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.5 # via gt4py (pyproject.toml) +cmake==3.27.9 # via dace, gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage==7.3.1 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis +coverage==7.3.2 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) -dace==0.14.4 # via gt4py (pyproject.toml) +dace==0.15.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.5.0 # via gt4py (pyproject.toml) +deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via virtualenv -docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme +distlib==0.3.7 # via requirementslib, virtualenv +distro==1.8.0 # via scikit-build +docopt==0.6.2 # via pipreqs +docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.1.3 # via hypothesis, pytest +exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==1.2.0 # via devtools +executing==2.0.1 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.6.1 # via factory-boy -fastjsonschema==2.18.0 # via nbformat -filelock==3.12.4 # via tox, virtualenv +faker==20.1.0 # via factory-boy +fastjsonschema==2.19.0 # via nbformat +filelock==3.13.1 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.9.16 # via -r requirements-dev.in -flake8-builtins==2.1.0 # via -r requirements-dev.in +flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.3 # via dace -frozendict==2.3.8 # via gt4py (pyproject.toml) +flask==3.0.0 # via dace +fparser==0.1.3 # via dace +frozendict==2.3.10 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.29 # via pre-commit -idna==3.4 # via requests +hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.33 # via pre-commit +idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via build, flask, sphinx -importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.0 # via build, flask, fparser, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.12.0 # via -r requirements-dev.in +isort==5.13.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.19.0 # via nbformat -jsonschema-specifications==2023.7.1 # via jsonschema -jupyter-core==5.3.1 # via nbformat -jupytext==1.15.2 # via -r requirements-dev.in -lark==1.1.7 # via gt4py (pyproject.toml) -mako==1.2.4 # via gt4py (pyproject.toml) +jsonschema==4.20.0 # via nbformat +jsonschema-specifications==2023.11.2 # via jsonschema +jupyter-core==5.5.0 # via nbformat +jupytext==1.16.0 # via -r requirements-dev.in +lark==1.1.8 # via gt4py (pyproject.toml) +mako==1.3.0 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.3 # via jinja2, mako, werkzeug mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.1 # via -r requirements-dev.in +mypy==1.7.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.2 # via gt4py (pyproject.toml) +nanobind==1.8.0 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace -ninja==1.11.1 # via gt4py (pyproject.toml) +ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox -pathspec==0.11.2 # via black +packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +pathspec==0.12.1 # via black +pep517==0.13.1 # via requirementslib +pip-api==0.0.30 # via isort pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.0 # via -r requirements-dev.in +pipdeptree==2.13.1 # via -r requirements-dev.in +pipreqs==0.4.13 # via isort pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv +plette==0.4.4 # via requirementslib pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.4.0 # via -r requirements-dev.in -psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist +pre-commit==3.5.0 # via -r requirements-dev.in +psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) -pycodestyle==2.11.0 # via flake8, flake8-debugger +pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi +pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in -pytest-factoryboy==2.5.1 # via -r requirements-dev.in -pytest-xdist==3.3.1 # via -r requirements-dev.in +pytest-factoryboy==2.6.0 # via -r requirements-dev.in +pytest-xdist==3.5.0 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.2 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, sphinx +referencing==0.32.0 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, requirementslib, sphinx, yarg +requirementslib==3.0.0 # via isort restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.10.3 # via jsonschema, referencing -ruff==0.0.290 # via -r requirements-dev.in +rpds-py==0.13.2 # via jsonschema, referencing +ruff==0.1.7 # via -r requirements-dev.in +scikit-build==0.17.6 # via dace +setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in +sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sympy==1.12 # via dace, gt4py (pyproject.toml) +sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox +tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox +tomlkit==0.12.3 # via plette, requirementslib toolz==0.12.0 # via cytoolz -tox==4.11.3 # via -r requirements-dev.in -traitlets==5.10.0 # via jupyter-core, nbformat +tox==4.11.4 # via -r requirements-dev.in +traitlets==5.14.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all types-backports==0.1.3 # via types-all types-backports-abc==0.5.2 # via types-all -types-bleach==6.0.0.4 # via types-all +types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all -types-cachetools==5.3.0.6 # via types-all +types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.15 # via types-jack-client +types-cffi==1.16.0.0 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.5 # via types-all +types-click-spinner==0.1.13.6 # via types-all types-colorama==0.4.15.12 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==1.4.0.1 # via types-all +types-croniter==2.0.0.0 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all types-dateparser==1.1.4.10 # via types-all @@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.9 # via types-all +types-jack-client==0.5.10.10 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.4.2.10 # via types-all +types-markdown==3.5.0.3 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.2 # via types-all +types-mock==5.1.0.3 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.0 # via types-all, types-pysftp +types-paramiko==3.3.0.2 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.3 # via types-all +types-pillow==10.1.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.1 # via types-all +types-protobuf==4.24.0.4 # via types-all types-pyaudio==0.2.16.7 # via types-all types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.2.0.2 # via types-redis +types-pyopenssl==23.3.0.0 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.0 # via types-all, types-tzlocal +types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all -types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.6 # via types-all -types-requests==2.31.0.2 # via types-all +types-pyyaml==6.0.12.12 # via types-all +types-redis==4.6.0.11 # via types-all +types-requests==2.31.0.10 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.2.0.0 # via types-cffi +types-setuptools==69.0.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all @@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all -types-tzlocal==5.0.1.1 # via types-all +types-tzlocal==5.1.0.1 # via types-all types-ujson==5.8.0.1 # via types-all -types-urllib3==1.26.25.14 # via types-requests types-waitress==2.1.4.9 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy -urllib3==2.0.4 # via requests -virtualenv==20.24.5 # via pre-commit, tox -websockets==11.0.3 # via dace -werkzeug==2.3.7 # via flask -wheel==0.41.2 # via astunparse, pip-tools +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +urllib3==2.1.0 # via requests, types-requests +virtualenv==20.25.0 # via pre-commit, tox +websockets==12.0 # via dace +werkzeug==3.0.1 # via flask +wheel==0.42.0 # via astunparse, pip-tools, scikit-build xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.16.2 # via importlib-metadata, importlib-resources +yarg==0.1.9 # via pipreqs +zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.2.1 # via pip-tools -setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools +pip==23.3.1 # via pip-api, pip-tools, requirementslib +setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 17709206a0..fd7724bac9 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -25,7 +25,7 @@ cmake==3.22 cogapp==3.3 coverage[toml]==5.0 cytoolz==0.12.0 -dace==0.14.2 +dace==0.15.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 @@ -70,7 +70,7 @@ scipy==1.7.2 setuptools==65.5.0 sphinx==4.4 sphinx_rtd_theme==1.0 -sympy==1.7 +sympy==1.9 tabulate==0.8.10 tomli==2.0.1 tox==3.2.0 diff --git a/pyproject.toml b/pyproject.toml index 5d7a2f2cb6..675bdae9d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,15 +69,15 @@ requires-python = '>=3.8' cuda = ['cupy>=12.0'] cuda11x = ['cupy-cuda11x>=12.0'] cuda12x = ['cupy-cuda12x>=12.0'] -dace = ['dace>=0.14.2,<0.15', 'sympy>=1.7'] +dace = ['dace>=0.15.1,<0.16', 'sympy>=1.9'] formatting = ['clang-format>=9.0'] # Always add all extra packages to 'full' for a simple full gt4py installation full = [ 'clang-format>=9.0', - 'dace>=0.14.2,<0.15', + 'dace>=0.15.1,<0.16', 'hypothesis>=6.0.0', 'pytest>=7.0', - 'sympy>=1.7', + 'sympy>=1.9', 'scipy>=1.7.2', 'jax[cpu]>=0.4.13' ] diff --git a/requirements-dev.txt b/requirements-dev.txt index d6dcc12d21..0fa523866f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,124 +6,136 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.0 # via devtools +asttokens==2.4.1 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.12.1 # via sphinx -black==23.9.1 # via gt4py (pyproject.toml) -blinker==1.6.2 # via flask -boltons==23.0.0 # via gt4py (pyproject.toml) +babel==2.13.1 # via sphinx +black==23.11.0 # via gt4py (pyproject.toml) +blinker==1.7.0 # via flask +boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.3.1 # via tox -certifi==2023.7.22 # via requests -cffi==1.15.1 # via cryptography +cachetools==5.3.2 # via tox +cerberus==1.3.5 # via plette +certifi==2023.11.17 # via requests +cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.2.0 # via requests -clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.3.2 # via requests +clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.5 # via gt4py (pyproject.toml) +cmake==3.27.9 # via dace, gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage[toml]==7.3.1 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis +coverage[toml]==7.3.2 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) -dace==0.14.4 # via gt4py (pyproject.toml) +dace==0.15.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.5.0 # via gt4py (pyproject.toml) +deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via virtualenv -docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme +distlib==0.3.7 # via requirementslib, virtualenv +distro==1.8.0 # via scikit-build +docopt==0.6.2 # via pipreqs +docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.1.3 # via hypothesis, pytest +exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==1.2.0 # via devtools +executing==2.0.1 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.6.1 # via factory-boy -fastjsonschema==2.18.0 # via nbformat -filelock==3.12.4 # via tox, virtualenv +faker==20.1.0 # via factory-boy +fastjsonschema==2.19.0 # via nbformat +filelock==3.13.1 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.9.16 # via -r requirements-dev.in -flake8-builtins==2.1.0 # via -r requirements-dev.in +flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.3 # via dace -frozendict==2.3.8 # via gt4py (pyproject.toml) +flask==3.0.0 # via dace +fparser==0.1.3 # via dace +frozendict==2.3.10 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.29 # via pre-commit -idna==3.4 # via requests +hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.33 # via pre-commit +idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via build, flask, sphinx -importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.0 # via build, flask, fparser, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.12.0 # via -r requirements-dev.in +isort==5.13.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.19.0 # via nbformat -jsonschema-specifications==2023.7.1 # via jsonschema -jupyter-core==5.3.1 # via nbformat -jupytext==1.15.2 # via -r requirements-dev.in -lark==1.1.7 # via gt4py (pyproject.toml) -mako==1.2.4 # via gt4py (pyproject.toml) +jsonschema==4.20.0 # via nbformat +jsonschema-specifications==2023.11.2 # via jsonschema +jupyter-core==5.5.0 # via nbformat +jupytext==1.16.0 # via -r requirements-dev.in +lark==1.1.8 # via gt4py (pyproject.toml) +mako==1.3.0 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.3 # via jinja2, mako, werkzeug mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.1 # via -r requirements-dev.in +mypy==1.7.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.2 # via gt4py (pyproject.toml) +nanobind==1.8.0 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace -ninja==1.11.1 # via gt4py (pyproject.toml) +ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox -pathspec==0.11.2 # via black +packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +pathspec==0.12.1 # via black +pep517==0.13.1 # via requirementslib +pip-api==0.0.30 # via isort pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.0 # via -r requirements-dev.in +pipdeptree==2.13.1 # via -r requirements-dev.in +pipreqs==0.4.13 # via isort pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv +plette[validation]==0.4.4 # via requirementslib pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.4.0 # via -r requirements-dev.in -psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist +pre-commit==3.5.0 # via -r requirements-dev.in +psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) -pycodestyle==2.11.0 # via flake8, flake8-debugger +pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi +pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in -pytest-factoryboy==2.5.1 # via -r requirements-dev.in -pytest-xdist[psutil]==3.3.1 # via -r requirements-dev.in +pytest-factoryboy==2.6.0 # via -r requirements-dev.in +pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.2 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, sphinx +referencing==0.32.0 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, requirementslib, sphinx, yarg +requirementslib==3.0.0 # via isort restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.10.3 # via jsonschema, referencing -ruff==0.0.290 # via -r requirements-dev.in +rpds-py==0.13.2 # via jsonschema, referencing +ruff==0.1.7 # via -r requirements-dev.in +scikit-build==0.17.6 # via dace +setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in +sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sympy==1.12 # via dace, gt4py (pyproject.toml) +sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox +tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox +tomlkit==0.12.3 # via plette, requirementslib toolz==0.12.0 # via cytoolz -tox==4.11.3 # via -r requirements-dev.in -traitlets==5.10.0 # via jupyter-core, nbformat +tox==4.11.4 # via -r requirements-dev.in +traitlets==5.14.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all types-backports==0.1.3 # via types-all types-backports-abc==0.5.2 # via types-all -types-bleach==6.0.0.4 # via types-all +types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all -types-cachetools==5.3.0.6 # via types-all +types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.15 # via types-jack-client +types-cffi==1.16.0.0 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.5 # via types-all +types-click-spinner==0.1.13.6 # via types-all types-colorama==0.4.15.12 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==1.4.0.1 # via types-all +types-croniter==2.0.0.0 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all types-dateparser==1.1.4.10 # via types-all @@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.9 # via types-all +types-jack-client==0.5.10.10 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.4.2.10 # via types-all +types-markdown==3.5.0.3 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.2 # via types-all +types-mock==5.1.0.3 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.0 # via types-all, types-pysftp +types-paramiko==3.3.0.2 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.3 # via types-all +types-pillow==10.1.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.1 # via types-all +types-protobuf==4.24.0.4 # via types-all types-pyaudio==0.2.16.7 # via types-all types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.2.0.2 # via types-redis +types-pyopenssl==23.3.0.0 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.0 # via types-all, types-tzlocal +types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all -types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.6 # via types-all -types-requests==2.31.0.2 # via types-all +types-pyyaml==6.0.12.12 # via types-all +types-redis==4.6.0.11 # via types-all +types-requests==2.31.0.10 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.2.0.0 # via types-cffi +types-setuptools==69.0.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all @@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all -types-tzlocal==5.0.1.1 # via types-all +types-tzlocal==5.1.0.1 # via types-all types-ujson==5.8.0.1 # via types-all -types-urllib3==1.26.25.14 # via types-requests types-waitress==2.1.4.9 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy -urllib3==2.0.4 # via requests -virtualenv==20.24.5 # via pre-commit, tox -websockets==11.0.3 # via dace -werkzeug==2.3.7 # via flask -wheel==0.41.2 # via astunparse, pip-tools +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +urllib3==2.1.0 # via requests, types-requests +virtualenv==20.25.0 # via pre-commit, tox +websockets==12.0 # via dace +werkzeug==3.0.1 # via flask +wheel==0.42.0 # via astunparse, pip-tools, scikit-build xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.16.2 # via importlib-metadata, importlib-resources +yarg==0.1.9 # via pipreqs +zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.2.1 # via pip-tools -setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools +pip==23.3.1 # via pip-api, pip-tools, requirementslib +setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 7d255de142..c28c5cf2d6 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -33,6 +33,6 @@ if _sys.version_info >= (3, 10): - from . import next + from . import next # noqa: A004 __all__ += ["next"] diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index b1e559a41e..5dae025acb 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -562,12 +562,6 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S omp_threads = "" omp_header = "" - # Backward compatible state struct name change in DaCe >=0.15.x - try: - dace_state_suffix = dace.Config.get("compiler.codegen_state_struct_suffix") - except (KeyError, TypeError): - dace_state_suffix = "_t" # old structure name - interface = cls.template.definition.render( name=sdfg.name, backend_specifics=omp_threads, @@ -575,7 +569,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S functor_args=self.generate_functor_args(sdfg), tmp_allocs=self.generate_tmp_allocs(sdfg), allocator="gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique", - state_suffix=dace_state_suffix, + state_suffix=dace.Config.get("compiler.codegen_state_struct_suffix"), ) generated_code = textwrap.dedent( f"""#include diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index ddcb719b5f..bd8c08034c 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -121,7 +121,7 @@ def __init__( *args, **kwargs, ): - super().__init__(name=name, *args, **kwargs) + super().__init__(*args, name=name, **kwargs) from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index fcd53d1312..5660fdbf76 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -814,7 +814,7 @@ def concretize( """ # noqa: RST301 # doctest conventions confuse RST validator concrete_cls: Type[DataModelT] = _make_concrete_with_cache( - datamodel_cls, *type_args, class_name=class_name, module=module + datamodel_cls, *type_args, class_name=class_name, module=module # type: ignore[arg-type] ) assert isinstance(concrete_cls, type) and is_datamodel(concrete_cls) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 7104f7658f..624407f319 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -1225,7 +1225,7 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]: [('a', 'b', 'c'), (1, 2, 3)] """ - return XIterable(zip(*self.iterator)) # type: ignore # mypy gets confused with *args + return XIterable(zip(*self.iterator)) @typing.overload def islice(self, __stop: int) -> XIterable[T]: @@ -1536,7 +1536,7 @@ def reduceby( ) -> Dict[K, S]: ... - def reduceby( # type: ignore[misc] # signatures 2 and 4 are not satified due to inconsistencies with type variables + def reduceby( self, bin_op_func: Callable[[S, T], S], key: Union[str, List[K], Callable[[T], K]], diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 6bf6858369..949f4b461a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -189,11 +189,12 @@ def __and__(self, other: UnitRange) -> UnitRange: return UnitRange(max(self.start, other.start), min(self.stop, other.stop)) def __contains__(self, value: Any) -> bool: - return ( - isinstance(value, core_defs.INTEGRAL_TYPES) - and value >= self.start - and value < self.stop - ) + # TODO(egparedes): use core_defs.IntegralScalar for `isinstance()` checks (see PEP 604) + # and remove int cast, once the related mypy bug (#16358) gets fixed + if isinstance(value, core_defs.INTEGRAL_TYPES): + return self.start <= cast(int, value) < self.stop + else: + return False def __le__(self, other: UnitRange) -> bool: return self.start >= other.start and self.stop <= other.stop diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 278dde9180..cd75538da7 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -15,7 +15,7 @@ import dataclasses import functools import inspect -from builtins import bool, float, int, tuple +from builtins import bool, float, int, tuple # noqa: A004 from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index ed8b768972..3a82f9c738 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -82,7 +82,7 @@ def replace(self, **kwargs: Any) -> Self: if not dataclasses.is_dataclass(self): raise TypeError(f"'{self.__class__}' is not a dataclass.") assert not isinstance(self, type) - return dataclasses.replace(self, **kwargs) # type: ignore[misc] # `self` is guaranteed to be a dataclass (is_dataclass) should be a `TypeGuard`? + return dataclasses.replace(self, **kwargs) class ChainableWorkflowMixin(Workflow[StartT, EndT]): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 7fd4794e57..fdd8a61054 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -260,10 +260,12 @@ def build_sdfg_from_itir( # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) - # TODO: According to Lex one should build the SDFG first in a general mannor. - # Generalisation to a particular device should happen only at the end. - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) sdfg = sdfg_genenerator.visit(program) + if sdfg is None: + raise RuntimeError(f"Visit failed for program {program.id}.") + + # run DaCe transformations to simplify the SDFG sdfg.simplify() # run DaCe auto-optimization heuristics @@ -274,6 +276,9 @@ def build_sdfg_from_itir( device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) + if on_gpu: + sdfg.apply_gpu_transformations() + return sdfg @@ -283,7 +288,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): compiler_args = kwargs.get("compiler_args", None) # `None` will take default. build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", False) - auto_optimize = kwargs.get("auto_optimize", False) + auto_optimize = kwargs.get("auto_optimize", True) lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE) # ITIR parameters column_axis = kwargs.get("column_axis", None) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index e3b5ddf2ac..fb2f82fed0 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -99,20 +99,17 @@ class ItirToSDFG(eve.NodeVisitor): offset_provider: dict[str, Any] node_types: dict[int, next_typing.Type] unique_id: int - use_gpu_storage: bool def __init__( self, param_types: list[ts.TypeSpec], offset_provider: dict[str, NeighborTableOffsetProvider], column_axis: Optional[Dimension] = None, - use_gpu_storage: bool = False, ): self.param_types = param_types self.column_axis = column_axis self.offset_provider = offset_provider self.storage_types = {} - self.use_gpu_storage = use_gpu_storage def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): @@ -123,14 +120,7 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset else None ) dtype = as_dace_type(type_.dtype) - storage = ( - dace.dtypes.StorageType.GPU_Global - if self.use_gpu_storage - else dace.dtypes.StorageType.Default - ) - sdfg.add_array( - name, shape=shape, strides=strides, offset=offset, dtype=dtype, storage=storage - ) + sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) @@ -246,7 +236,6 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, - storage=array_table[name].storage, transient=True, ) closure_init_state.add_nedge( @@ -261,7 +250,6 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, - storage=array_table[name].storage, ) else: assert isinstance(self.storage_types[name], ts.ScalarType) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py index e580333bc8..8cfff12df4 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py @@ -312,7 +312,7 @@ def test_symbolref_validation_for_valid_tree(): SymbolTableRootNode( nodes=[SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo")], ) - SymbolTableRootNode( + SymbolTableRootNode( # noqa: B018 nodes=[ SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo"), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 698dce2b5c..d100cd380c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -30,16 +30,6 @@ def test_external_local_field(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def testee( inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index e8d0c8b163..e2434d860a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -46,16 +46,6 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - if unstructured_case.backend in [ gtfn.run_gtfn, gtfn.run_gtfn_gpu, @@ -79,16 +69,6 @@ def testee(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_minover_execution(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) @@ -102,16 +82,6 @@ def minover(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_reduction_execution(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def reduction(edge_f: cases.EField) -> cases.VField: return neighbor_sum(edge_f(V2E), axis=V2EDim) @@ -150,16 +120,6 @@ def fencil(edge_f: cases.EField, out: cases.VField): @pytest.mark.uses_unstructured_shift def test_reduction_with_common_expression(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index da0945fe96..788081b81e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -20,14 +20,20 @@ from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.program_processors import otf_compile_executor from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries -from tests.next_tests.integration_tests.cases import Case -from tests.next_tests.toy_connectivity import Cell, Edge from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import E2V, KDim, Vertex, cartesian_case, unstructured_case +from next_tests.integration_tests.cases import ( + E2V, + Case, + KDim, + Vertex, + cartesian_case, + unstructured_case, +) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( reduction_setup, ) +from next_tests.toy_connectivity import Cell, Edge @pytest.fixture From 90e5d5a281fb02c1c4558e9097fc5fb980584321 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 19 Jan 2024 13:44:39 +0100 Subject: [PATCH 08/11] feature[next]: add support for Python3.11 by fixing typing-related bugs (#1418) Fixes hidden bugs in `eve.datamodels` and `eve.extended_typing` to support Python 3.11. Actual bug fixes: - Previous fix to support `typing.Any` implementation as a class (https://github.com/python/cpython/commit/5a4973e29f2f5c4ee8c086f40325786c62381540) didn't work in 3.11. - Partially concretization of generic datamodels replacing typevars was broken. - Partially concretization of generic datamodels leaving some parameters as typevars was broken. Other changes: - Add python 3.11 as supported version. - Remove dead code in comments. - Fix some imports style to comply with our coding guidelines. --- .github/workflows/daily-ci.yml | 2 +- .github/workflows/test-cartesian-fallback.yml | 2 +- .github/workflows/test-cartesian.yml | 2 +- .github/workflows/test-eve-fallback.yml | 2 +- .github/workflows/test-eve.yml | 3 +- .github/workflows/test-next-fallback.yml | 2 +- .github/workflows/test-next.yml | 2 +- .github/workflows/test-storage-fallback.yml | 2 +- .github/workflows/test-storage.yml | 3 +- src/gt4py/eve/datamodels/core.py | 40 ++---- src/gt4py/eve/extended_typing.py | 15 ++- tests/eve_tests/unit_tests/test_datamodels.py | 117 ++++++++++-------- .../unit_tests/test_type_validation.py | 5 +- .../ffront_tests/test_icon_like_scan.py | 26 ++-- .../test_single_static_assign.py | 5 +- tox.ini | 35 +++--- 16 files changed, 140 insertions(+), 123 deletions(-) diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml index 77ba39a361..8631390dbb 100644 --- a/.github/workflows/daily-ci.yml +++ b/.github/workflows/daily-ci.yml @@ -14,7 +14,7 @@ jobs: daily-ci: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] tox-module-factor: ["cartesian", "eve", "next", "storage"] os: ["ubuntu-latest"] requirements-file: ["requirements-dev.txt", "min-requirements-test.txt", "min-extra-requirements-test.txt"] diff --git a/.github/workflows/test-cartesian-fallback.yml b/.github/workflows/test-cartesian-fallback.yml index b2eaead47a..7e9a948e9c 100644 --- a/.github/workflows/test-cartesian-fallback.yml +++ b/.github/workflows/test-cartesian-fallback.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] steps: diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml index 2c2b97aaa6..ebdc4ce749 100644 --- a/.github/workflows/test-cartesian.yml +++ b/.github/workflows/test-cartesian.yml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/test-eve-fallback.yml b/.github/workflows/test-eve-fallback.yml index 93dc308a53..fd7ab5452c 100644 --- a/.github/workflows/test-eve-fallback.yml +++ b/.github/workflows/test-eve-fallback.yml @@ -17,7 +17,7 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] os: ["ubuntu-latest"] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test-eve.yml b/.github/workflows/test-eve.yml index 1322c573db..222b825f38 100644 --- a/.github/workflows/test-eve.yml +++ b/.github/workflows/test-eve.yml @@ -20,7 +20,7 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] os: ["ubuntu-latest"] fail-fast: false @@ -68,4 +68,3 @@ jobs: # with: # name: info-py${{ matrix.python-version }}-${{ matrix.os }} # path: info.txt - diff --git a/.github/workflows/test-next-fallback.yml b/.github/workflows/test-next-fallback.yml index 8490a3e393..bdcc061db0 100644 --- a/.github/workflows/test-next-fallback.yml +++ b/.github/workflows/test-next-fallback.yml @@ -15,7 +15,7 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.10", "3.11"] tox-env-factor: ["nomesh", "atlas"] os: ["ubuntu-latest"] diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml index 52f8c25386..4282a22da6 100644 --- a/.github/workflows/test-next.yml +++ b/.github/workflows/test-next.yml @@ -18,7 +18,7 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.10", "3.11"] tox-env-factor: ["nomesh", "atlas"] os: ["ubuntu-latest"] fail-fast: false diff --git a/.github/workflows/test-storage-fallback.yml b/.github/workflows/test-storage-fallback.yml index 0cbc735564..99e4923de8 100644 --- a/.github/workflows/test-storage-fallback.yml +++ b/.github/workflows/test-storage-fallback.yml @@ -18,7 +18,7 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] os: ["ubuntu-latest"] diff --git a/.github/workflows/test-storage.yml b/.github/workflows/test-storage.yml index 1133353f30..34841ed71c 100644 --- a/.github/workflows/test-storage.yml +++ b/.github/workflows/test-storage.yml @@ -21,7 +21,7 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] os: ["ubuntu-latest"] fail-fast: false @@ -70,4 +70,3 @@ jobs: # with: # name: info-py${{ matrix.python-version }}-${{ matrix.os }} # path: info.txt - diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index 5660fdbf76..bc744b3ccc 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -883,17 +883,6 @@ def _substitute_typevars( return type_params_map[type_hint], True elif getattr(type_hint, "__parameters__", []): return type_hint[tuple(type_params_map[tp] for tp in type_hint.__parameters__)], True - # TODO(egparedes): WIP fix for partial specialization - # # Type hint is a generic model: replace all the concretized type vars - # noqa: e800 replaced = False - # noqa: e800 new_args = [] - # noqa: e800 for tp in type_hint.__parameters__: - # noqa: e800 if tp in type_params_map: - # noqa: e800 new_args.append(type_params_map[tp]) - # noqa: e800 replaced = True - # noqa: e800 else: - # noqa: e800 new_args.append(type_params_map[tp]) - # noqa: e800 return type_hint[tuple(new_args)], replaced else: return type_hint, False @@ -981,21 +970,14 @@ def __class_getitem__( """ type_args: Tuple[Type] = args if isinstance(args, tuple) else (args,) concrete_cls: Type[DataModelT] = concretize(cls, *type_args) - res = xtyping.StdGenericAliasType(concrete_cls, type_args) - if sys.version_info < (3, 9): - # in Python 3.8, xtyping.StdGenericAliasType (aka typing._GenericAlias) - # does not copy all required `__dict__` entries, so do it manually - for k, v in concrete_cls.__dict__.items(): - if k not in res.__dict__: - res.__dict__[k] = v - return res + return concrete_cls return classmethod(__class_getitem__) def _make_type_converter(type_annotation: TypeAnnotation, name: str) -> TypeConverter[_T]: - # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code as a tree traversal. - # + # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code + # as a tree traversal. if xtyping.is_actual_type(type_annotation) and not isinstance(None, type_annotation): assert not xtyping.get_args(type_annotation) assert isinstance(type_annotation, type) @@ -1316,11 +1298,7 @@ def _make_concrete_with_cache( # Replace field definitions with the new actual types for generic fields type_params_map = dict(zip(datamodel_cls.__parameters__, type_args)) model_fields = getattr(datamodel_cls, MODEL_FIELD_DEFINITIONS_ATTR) - new_annotations = { - # TODO(egparedes): ? - # noqa: e800 "__args__": "ClassVar[Tuple[Union[Type, TypeVar], ...]]", - # noqa: e800 "__parameters__": "ClassVar[Tuple[TypeVar, ...]]", - } + new_annotations = {} new_field_c_attrs = {} for field_name, field_type in xtyping.get_type_hints(datamodel_cls).items(): @@ -1353,8 +1331,16 @@ def _make_concrete_with_cache( "__module__": module if module else datamodel_cls.__module__, **new_field_c_attrs, } - concrete_cls = type(class_name, (datamodel_cls,), namespace) + + # Update the tuple of generic parameters in the new class, in case + # this is a partial concretization + assert hasattr(concrete_cls, "__parameters__") + concrete_cls.__parameters__ = tuple( + type_params_map[tp_var] + for tp_var in datamodel_cls.__parameters__ + if isinstance(type_params_map[tp_var], typing.TypeVar) + ) assert concrete_cls.__module__ == module or not module if MODEL_FIELD_DEFINITIONS_ATTR not in concrete_cls.__dict__: diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 17462a37ff..3ee447ca6c 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -493,7 +493,7 @@ def _patched_proto_hook(other): # type: ignore[no-untyped-def] if isinstance(_typing.Any, type): # Python >= 3.11 _ArtefactTypes = (*_ArtefactTypes, _typing.Any) -# `Any` is a class since typing_extensions >= 4.4 +# `Any` is a class since typing_extensions >= 4.4 and Python 3.11 if (typing_exts_any := getattr(_typing_extensions, "Any", None)) is not _typing.Any and isinstance( typing_exts_any, type ): @@ -504,11 +504,13 @@ def is_actual_type(obj: Any) -> TypeGuard[Type]: """Check if an object has an actual type and instead of a typing artefact like ``GenericAlias`` or ``Any``. This is needed because since Python 3.9: - ``isinstance(types.GenericAlias(), type) is True`` + ``isinstance(types.GenericAlias(), type) is True`` and since Python 3.11: - ``isinstance(typing.Any, type) is True`` + ``isinstance(typing.Any, type) is True`` """ - return isinstance(obj, type) and type(obj) not in _ArtefactTypes + return ( + isinstance(obj, type) and (obj not in _ArtefactTypes) and (type(obj) not in _ArtefactTypes) + ) if hasattr(_typing_extensions, "Any") and _typing.Any is not _typing_extensions.Any: # type: ignore[attr-defined] # _typing_extensions.Any only from >= 4.4 @@ -641,9 +643,12 @@ def get_partial_type_hints( resolved_hints = get_type_hints( # type: ignore[call-arg] # Python 3.8 does not define `include-extras` obj, globalns=globalns, localns=localns, include_extras=include_extras ) - hints.update(resolved_hints) + hints[name] = resolved_hints[name] except NameError as error: if isinstance(hint, str): + # This conversion could be probably skipped in Python versions containing + # the fix applied in bpo-41370. Check: + # https://github.com/python/cpython/commit/b465b606049f6f7dd0711cb031fdaa251818741a#diff-ddb987fca5f5df0c9a2f5521ed687919d70bb3d64eaeb8021f98833a2a716887R344 hints[name] = ForwardRef(hint) elif isinstance(hint, (ForwardRef, _typing.ForwardRef)): hints[name] = hint diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index 8fa9e02cb6..0abb893dd4 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -15,6 +15,7 @@ from __future__ import annotations import enum +import numbers import types import typing from typing import Set # noqa: F401 # imported but unused (used in exec() context) @@ -1150,66 +1151,80 @@ class PartialGenericModel(datamodels.GenericDataModel, Generic[T]): with pytest.raises(TypeError, match="'PartialGenericModel__int.value'"): PartialGenericModel__int(value=["1"]) - def test_partial_specialization(self): - class PartialGenericModel(datamodels.GenericDataModel, Generic[T, U]): + def test_partial_concretization(self): + class BaseGenericModel(datamodels.GenericDataModel, Generic[T, U]): value: List[Tuple[T, U]] - PartialGenericModel(value=[]) - PartialGenericModel(value=[("value", 3)]) - PartialGenericModel(value=[(1, "value")]) - PartialGenericModel(value=[(-1.0, "value")]) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=1) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=(1, 2)) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=[()]) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=[(1,)]) + assert len(BaseGenericModel.__parameters__) == 2 + + BaseGenericModel(value=[]) + BaseGenericModel(value=[("value", 3)]) + BaseGenericModel(value=[(1, "value")]) + BaseGenericModel(value=[(-1.0, "value")]) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=1) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=(1, 2)) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=[()]) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=[(1,)]) + + PartiallyConcretizedGenericModel = BaseGenericModel[int, U] + + assert len(PartiallyConcretizedGenericModel.__parameters__) == 1 + + PartiallyConcretizedGenericModel(value=[]) + PartiallyConcretizedGenericModel(value=[(1, 2)]) + PartiallyConcretizedGenericModel(value=[(1, "value")]) + PartiallyConcretizedGenericModel(value=[(1, (11, 12))]) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=1) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=(1, 2)) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=[1.0]) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=["1"]) - print(f"{PartialGenericModel.__parameters__=}") - print(f"{hasattr(PartialGenericModel ,'__args__')=}") + FullyConcretizedGenericModel = PartiallyConcretizedGenericModel[str] - PartiallySpecializedGenericModel = PartialGenericModel[int, U] - print(f"{PartiallySpecializedGenericModel.__datamodel_fields__=}") - print(f"{PartiallySpecializedGenericModel.__parameters__=}") - print(f"{PartiallySpecializedGenericModel.__args__=}") + assert len(FullyConcretizedGenericModel.__parameters__) == 0 - PartiallySpecializedGenericModel(value=[]) - PartiallySpecializedGenericModel(value=[(1, 2)]) - PartiallySpecializedGenericModel(value=[(1, "value")]) - PartiallySpecializedGenericModel(value=[(1, (11, 12))]) + FullyConcretizedGenericModel(value=[]) + FullyConcretizedGenericModel(value=[(1, "value")]) + with pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=1) + with pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=(1, 2)) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=1) + FullyConcretizedGenericModel(value=[1.0]) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=(1, 2)) + FullyConcretizedGenericModel(value=["1"]) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=[1.0]) + FullyConcretizedGenericModel(value=1) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=["1"]) - - # TODO(egparedes): after fixing partial nested datamodel specialization - # noqa: e800 FullySpecializedGenericModel = PartiallySpecializedGenericModel[str] - # noqa: e800 print(f"{FullySpecializedGenericModel.__datamodel_fields__=}") - # noqa: e800 print(f"{FullySpecializedGenericModel.__parameters__=}") - # noqa: e800 print(f"{FullySpecializedGenericModel.__args__=}") - - # noqa: e800 FullySpecializedGenericModel(value=[]) - # noqa: e800 FullySpecializedGenericModel(value=[(1, "value")]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=1) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=(1, 2)) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[1.0]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=["1"]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=1) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[(1, 2)]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[(1, (11, 12))]) + FullyConcretizedGenericModel(value=[(1, 2)]) + with pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=[(1, (11, 12))]) + + def test_partial_concretization_with_typevar(self): + class PartialGenericModel(datamodels.GenericDataModel, Generic[T]): + a: T + values: List[T] + + B = TypeVar("B", bound=numbers.Number) + PartiallyConcretizedGenericModel = PartialGenericModel[B] + + PartiallyConcretizedGenericModel(a=1, values=[2, 3]) + PartiallyConcretizedGenericModel(a=-1.32, values=[2.2, 3j]) + + with pytest.raises(TypeError, match=".a'"): + PartiallyConcretizedGenericModel(a="1", values=[2, 3]) + with pytest.raises(TypeError, match=".values'"): + PartiallyConcretizedGenericModel(a=1, values=[1, "2"]) + with pytest.raises(TypeError, match=".values'"): + PartiallyConcretizedGenericModel(a=1, values=(1, 2)) # Reuse sample_type_data from test_field_type_hint @pytest.mark.parametrize(["type_hint", "valid_values", "wrong_values"], SAMPLE_TYPE_DATA) diff --git a/tests/eve_tests/unit_tests/test_type_validation.py b/tests/eve_tests/unit_tests/test_type_validation.py index 70ef033ff0..d9977f0d3a 100644 --- a/tests/eve_tests/unit_tests/test_type_validation.py +++ b/tests/eve_tests/unit_tests/test_type_validation.py @@ -28,6 +28,7 @@ ) from gt4py.eve.extended_typing import ( Any, + Callable, Dict, Final, ForwardRef, @@ -41,8 +42,8 @@ ) -VALIDATORS: Final = [type_val.simple_type_validator] -FACTORIES: Final = [type_val.simple_type_validator_factory] +VALIDATORS: Final[list[Callable]] = [type_val.simple_type_validator] +FACTORIES: Final[list[Callable]] = [type_val.simple_type_validator_factory] class SampleEnum(enum.Enum): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 130f6bd29c..5bd255f80f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass +import dataclasses import numpy as np import pytest @@ -201,22 +201,26 @@ def test_setup(fieldview_backend): grid_type=common.GridType.UNSTRUCTURED, ) - @dataclass(frozen=True) + @dataclasses.dataclass(frozen=True) class setup: - case: cases.Case = test_case - cell_size = case.default_sizes[Cell] - k_size = case.default_sizes[KDim] - z_alpha = case.as_field( + case: cases.Case = dataclasses.field(default_factory=lambda: test_case) + cell_size = test_case.default_sizes[Cell] + k_size = test_case.default_sizes[KDim] + z_alpha = test_case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1)) ) - z_beta = case.as_field( + z_beta = test_case.as_field( + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) + ) + z_q = test_case.as_field( + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) + ) + w = test_case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) ) - z_q = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) - w = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) - dummy = case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) - z_q_out = case.as_field([Cell, KDim], np.zeros((cell_size, k_size))) + dummy = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) + z_q_out = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size))) return setup() diff --git a/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py b/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py index 052f272d22..ea1cdb82a6 100644 --- a/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py +++ b/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py @@ -108,7 +108,10 @@ def test_unpacking_swap(): lines = ast.unparse(ssa_ast).split("\n") assert lines[0] == f"a{SEP}0 = 5" assert lines[1] == f"b{SEP}0 = 1" - assert lines[2] == f"(b{SEP}1, a{SEP}1) = (a{SEP}0, b{SEP}0)" + assert lines[2] in [ + f"(b{SEP}1, a{SEP}1) = (a{SEP}0, b{SEP}0)", + f"b{SEP}1, a{SEP}1 = (a{SEP}0, b{SEP}0)", + ] # unparse produces different parentheses in different Python versions def test_annotated_assign(): diff --git a/tox.ini b/tox.ini index 44dc912c8a..817f721f71 100644 --- a/tox.ini +++ b/tox.ini @@ -11,21 +11,24 @@ envlist = # docs labels = test-cartesian-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, \ - cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu + cartesian-py311-internal-cpu, cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \ + cartesian-py311-dace-cpu - test-eve-cpu = eve-py38, eve-py39, eve-py310 + test-eve-cpu = eve-py38, eve-py39, eve-py310, eve-py311 - test-next-cpu = next-py310-nomesh, next-py310-atlas + test-next-cpu = next-py310-nomesh, next-py311-nomesh, next-py310-atlas, next-py311-atlas test-storage-cpu = storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, \ - storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu + storage-py311-internal-cpu, storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, \ + storage-py311-dace-cpu test-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, \ - cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \ - eve-py38, eve-py39, eve-py310, \ - next-py310-nomesh, next-py310-atlas, \ - storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, \ - storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu + cartesian-py311-internal-cpu, cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \ + cartesian-py311-dace-cpu, \ + eve-py38, eve-py39, eve-py310, eve-py311, \ + next-py310-nomesh, next-py311-nomesh, next-py310-atlas, next-py311-atlas, \ + storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, storage-py311-internal-cpu, \ + storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu [testenv] deps = -r {tox_root}{/}{env:ENV_REQUIREMENTS_FILE:requirements-dev.txt} @@ -44,7 +47,7 @@ pass_env = NUM_PROCESSES set_env = PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning} -[testenv:cartesian-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] +[testenv:cartesian-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.cartesian' tests pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH, CXX, CC, OPENMP_CPPFLAGS, OPENMP_LDFLAGS, PIP_USER, PYTHONUSERBASE allowlist_externals = @@ -65,13 +68,13 @@ commands = ; coverage json --rcfile=setup.cfg ; coverage html --rcfile=setup.cfg --show-contexts -[testenv:eve-py{38,39,310}] +[testenv:eve-py{38,39,310,311}] description = Run 'gt4py.eve' tests commands = python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} {posargs} tests{/}eve_tests python -m pytest --doctest-modules src{/}gt4py{/}eve -[testenv:next-py{310}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}] +[testenv:next-py{310,311}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.next' tests pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH deps = @@ -87,14 +90,14 @@ commands = # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist pytest --doctest-modules src{/}gt4py{/}next -[testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] +[testenv:storage-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.storage' tests commands = cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_gpu" {posargs} tests{/}storage_tests {cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_gpu" {posargs} tests{/}storage_tests #pytest doctest-modules {posargs} src{/}gt4py{/}storage -[testenv:linters-py{38,39,310}] +[testenv:linters-py{38,39,310,311}] description = Run linters commands = flake8 .{/}src @@ -134,11 +137,13 @@ description = py38: Update requirements for testing a specific python version py39: Update requirements for testing a specific python version py310: Update requirements for testing a specific python version + py311: Update requirements for testing a specific python version base_python = common: py38 py38: py38 py39: py39 py310: py310 + py311: py311 deps = cogapp>=3.3 pip-tools>=6.10 @@ -178,7 +183,7 @@ commands = # Run cog to update .pre-commit-config.yaml with new versions common: cog -r -P .pre-commit-config.yaml -[testenv:dev-py{38,39,310}{-atlas,}] +[testenv:dev-py{38,39,310,311}{-atlas,}] description = Initialize development environment for gt4py deps = -r {tox_root}{/}requirements-dev.txt From e20294efe31f911567713b62ac70e161cc49255f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Jan 2024 10:21:46 +0100 Subject: [PATCH 09/11] bug[next]: Improve error message on invalid call to field operator and program (#1323) After #1275 most of the error message given to the user when calling a field operator or program with invalid arguments was only available in verbose mode. This PR shows this information again. ```python @field_operator def foo(x: IField): return x @field_operator def testee(a: IField, b: IField, c: IField) -> IField: return foo(1) ``` ``` gt4py.next.errors.exceptions.DSLError: Invalid argument types in call to `foo`. E Invalid call to function of type `FieldOperatorType(definition=FunctionType(pos_only_args=[], pos_or_kw_args={'x': FieldType(dims=[Dimension(value='IDim', kind=)], dtype=ScalarType(kind=, shape=None))}, kw_only_args={}, returns=FieldType(dims=[Dimension(value='IDim', kind=)], dtype=ScalarType(kind=, shape=None))))`: E - Expected argument `x` to be of type `Field[[IDim], int32]`, but got `int32`. E File ".../gt4py_functional/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py", line 113 E return foo(1) ``` --- src/gt4py/next/ffront/decorator.py | 4 +++- src/gt4py/next/ffront/foast_passes/type_deduction.py | 2 +- src/gt4py/next/ffront/past_passes/type_deduction.py | 2 +- .../feature_tests/ffront_tests/test_program.py | 3 ++- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 05cbe1c882..9f8537f59b 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -344,7 +344,9 @@ def _validate_args(self, *args, **kwargs) -> None: raise_exception=True, ) except ValueError as err: - raise TypeError(f"Invalid argument types in call to '{self.past_node.id}'.") from err + raise errors.DSLError( + None, f"Invalid argument types in call to '{self.past_node.id}'.\n{err}" + ) from err def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: self._validate_args(*args, **kwargs) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 639e5ff009..5e289af664 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -694,7 +694,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: ) except ValueError as err: raise errors.DSLError( - node.location, f"Invalid argument types in call to '{new_func}'." + node.location, f"Invalid argument types in call to '{new_func}'.\n{err}" ) from err return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index fc353d64e4..af8f5e8368 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -229,7 +229,7 @@ def visit_Call(self, node: past.Call, **kwargs): ) except ValueError as ex: - raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.") from ex + raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.\n{ex}") from ex return past.Call( func=new_func, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index c86881ab7c..938c69fb52 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -20,6 +20,7 @@ import pytest import gt4py.next as gtx +from gt4py.next import errors from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, Ioff, JDim, cartesian_case, fieldview_backend @@ -222,7 +223,7 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): inp = cartesian_case.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],))) out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))() - with pytest.raises(TypeError) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: # program is defined on Field[[IDim], ...], but we call with # Field[[JDim], ...] copy_program(inp, out, offset_provider={}) From e6f41605ee7ae5a791d8a5d3557b4f45e0f511a8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Jan 2024 18:35:57 +0100 Subject: [PATCH 10/11] Update AUTHORS.md --- AUTHORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS.md b/AUTHORS.md index 89aafb9971..0fd0098fc4 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -9,6 +9,7 @@ - Deconinck, Florian. SSAI/NASA-GSFC - Ehrengruber, Till. ETH Zurich - CSCS - Elbert, Oliver D. NOAA-GFDL +- Faghih-Naini, Sara - ECMWF - Farabullini, Nicoletta. ETH Zurich - C2SM - George, Rhea. Allen Institute for AI - González Paredes, Enrique. ETH Zurich - CSCS From 8bd5a41e9d27409472442c6ad8a8a7908953b265 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Jan 2024 18:37:04 +0100 Subject: [PATCH 11/11] Update AUTHORS.md --- AUTHORS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AUTHORS.md b/AUTHORS.md index 0fd0098fc4..6c76e5759e 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -9,7 +9,7 @@ - Deconinck, Florian. SSAI/NASA-GSFC - Ehrengruber, Till. ETH Zurich - CSCS - Elbert, Oliver D. NOAA-GFDL -- Faghih-Naini, Sara - ECMWF +- Faghih-Naini, Sara. ECMWF - Farabullini, Nicoletta. ETH Zurich - C2SM - George, Rhea. Allen Institute for AI - González Paredes, Enrique. ETH Zurich - CSCS