From 3ece412f0d78f32893d8f01ed0e74c8b38388854 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 28 Nov 2024 13:13:55 -0500 Subject: [PATCH 1/8] fix[cartesian]: Deactivate K offset write in `gt:gpu` (#1755) Following the issue logged as https://github.com/GridTools/gt4py/issues/1754 we are deactivating the K-offset write feature until we can figure out why it's failing. I will monitor any activity on the ticket if users are hit by this. --------- Co-authored-by: Hannes Vogt --- src/gt4py/cartesian/frontend/gtscript_frontend.py | 7 +++++++ .../multi_feature_tests/test_code_generation.py | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index ade05921ef..f155ea6209 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1460,6 +1460,13 @@ def visit_Assign(self, node: ast.Assign) -> list: loc=nodes.Location.from_ast_node(t), ) + if self.backend_name in ["gt:gpu"]: + raise GTScriptSyntaxError( + message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} as an unsolved bug remains." + "Please refer to https://github.com/GridTools/gt4py/issues/1754.", + loc=nodes.Location.from_ast_node(t), + ) + if not self._is_known(name): if name in self.temp_decls: field_decl = self.temp_decls[name] diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index c4d07d7337..7c4956b3ef 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -667,6 +667,10 @@ def test_K_offset_write_conditional(backend): pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) From 886058496c1ebcb90ba530a796213d1fec7c7095 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 29 Nov 2024 08:46:06 +0100 Subject: [PATCH 2/8] refact[next][dace]: Helper function for field operator constructor (#1743) Includes refactoring of the code for construction of field operators, in order to make it usable by the three lowering functions that construct fields: `translate_as_fieldop()`, `translate_broadcast_scalar()`, and `translate_index()`. --- .../gtir_builtin_translators.py | 242 +++++++----------- 1 file changed, 94 insertions(+), 148 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 94ab3a6f76..ff011c4193 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,11 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -229,40 +233,75 @@ def _get_field_layout( return list(domain_dims), list(domain_lbs), domain_sizes -def _create_temporary_field( +def _create_field_operator( sdfg: dace.SDFG, state: dace.SDFGState, domain: FieldopDomain, node_type: ts.FieldType, - dataflow_output: gtir_dataflow.DataflowOutputEdge, + sdfg_builder: gtir_sdfg.SDFGBuilder, + input_edges: Sequence[gtir_dataflow.DataflowInputEdge], + output_edge: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: - """Helper method to allocate a temporary field where to write the output of a field operator.""" + """ + Helper method to allocate a temporary field to store the output of a field operator. + + Args: + sdfg: The SDFG that represents the scope of the field data. + state: The SDFG state where to create an access node to the field data. + domain: The domain of the field operator that computes the field. + node_type: The GT4Py type of the IR node that produces this field. + sdfg_builder: The object used to build the map scope in the provided SDFG. + input_edges: List of edges to pass input data into the dataflow. + output_edge: Edge representing the dataflow output data. + + Returns: + The field data descriptor, which includes the field access node in the given `state` + and the field domain offset. + """ field_dims, field_offset, field_shape = _get_field_layout(domain) + field_indices = _get_domain_indices(field_dims, field_offset) + + dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - output_desc = dataflow_output.result.dc_node.desc(sdfg) - if isinstance(output_desc, dace.data.Array): + field_subset = sbs.Range.from_indices(field_indices) + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): + assert output_edge.result.gt_dtype == node_type.dtype + assert isinstance(dataflow_output_desc, dace.data.Scalar) + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) + field_dtype = output_edge.result.gt_dtype + else: assert isinstance(node_type.dtype, itir_ts.ListType) - assert isinstance(node_type.dtype.element_type, ts.ScalarType) - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) + assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + field_dtype = output_edge.result.gt_dtype.element_type # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) - field_offset.extend(output_desc.offset) - field_shape.extend(output_desc.shape) - elif isinstance(output_desc, dace.data.Scalar): - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) - else: - raise ValueError(f"Cannot create field for dace type {output_desc}.") + assert output_edge.result.gt_dtype.offset_type is not None + field_dims.append(output_edge.result.gt_dtype.offset_type) + field_shape.extend(dataflow_output_desc.shape) + field_offset.extend(dataflow_output_desc.offset) + field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) # allocate local temporary storage - temp_name, _ = sdfg.add_temp_transient(field_shape, output_desc.dtype) - field_node = state.add_access(temp_name) + field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + field_node = state.add_access(field_name) - if isinstance(dataflow_output.result.gt_dtype, ts.ScalarType): - field_dtype = dataflow_output.result.gt_dtype - else: - assert isinstance(dataflow_output.result.gt_dtype.element_type, ts.ScalarType) - field_dtype = dataflow_output.result.gt_dtype.element_type - assert dataflow_output.result.gt_dtype.offset_type is not None - field_dims.append(dataflow_output.result.gt_dtype.offset_type) + # create map range corresponding to the field operator domain + me, mx = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + }, + ) + + # here we setup the edges passing through the map entry node + for edge in input_edges: + edge.connect(me) + + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(mx, field_node, field_subset) return FieldopData( field_node, @@ -341,7 +380,8 @@ def translate_as_fieldop( # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. - return translate_broadcast_scalar(node, sdfg, state, sdfg_builder) + stencil_expr = im.lambda_("a")(im.deref("a")) + stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] else: raise NotImplementedError( f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." @@ -349,117 +389,18 @@ def translate_as_fieldop( # parse the domain of the field operator domain = extract_domain(domain_expr) - domain_dims, domain_offsets, _ = zip(*domain) - domain_indices = _get_domain_indices(domain_dims, domain_offsets) # visit the list of arguments to be passed to the lambda expression stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) - output_desc = output.result.dc_node.desc(sdfg) - - if isinstance(node.type.dtype, itir_ts.ListType): - assert isinstance(output_desc, dace.data.Array) - # additional local dimension for neighbors - # TODO(phimuell): Investigate if we should swap the two. - output_subset = sbs.Range.from_indices(domain_indices) + sbs.Range.from_array(output_desc) - else: - assert isinstance(output_desc, dace.data.Scalar) - output_subset = sbs.Range.from_indices(domain_indices) - - # create map range corresponding to the field operator domain - me, mx = sdfg_builder.add_map( - "fieldop", - state, - ndrange={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - ) - - # allocate local temporary storage for the result field - result_field = _create_temporary_field(sdfg, state, domain, node.type, output) - - # here we setup the edges from the map entry node - for edge in input_edges: - edge.connect(me) - - # and here the edge writing the result data through the map exit node - output.connect(mx, result_field.dc_node, output_subset) - - return result_field - + input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) -def translate_broadcast_scalar( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, -) -> FieldopResult: - """ - Generates the dataflow subgraph for the 'as_fieldop' builtin function for the - special case where the argument to 'as_fieldop' is a 'deref' scalar expression, - rather than a lambda function. This case corresponds to broadcasting the scalar - value over the field domain. Therefore, it is lowered to a mapped tasklet that - just writes the scalar value out to all elements of the result field. - """ - assert isinstance(node, gtir.FunCall) - assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) - - fun_node = node.fun - assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args - assert cpm.is_ref_to(stencil_expr, "deref") - - domain = extract_domain(domain_expr) - output_dims, output_offset, output_shape = _get_field_layout(domain) - output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset)) - - assert len(node.args) == 1 - scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) - - if isinstance(node.args[0].type, ts.ScalarType): - assert isinstance(scalar_expr, (gtir_dataflow.MemletExpr, gtir_dataflow.ValueExpr)) - input_subset = ( - str(scalar_expr.subset) if isinstance(scalar_expr, gtir_dataflow.MemletExpr) else "0" - ) - input_node = scalar_expr.dc_node - gt_dtype = node.args[0].type - elif isinstance(node.args[0].type, ts.FieldType): - assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) - if len(node.args[0].type.dims) == 0: # zero-dimensional field - input_subset = "0" - else: - input_subset = scalar_expr.get_memlet_subset(sdfg) - - input_node = scalar_expr.field - gt_dtype = node.args[0].type.dtype - else: - raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") - - output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( - "broadcast", - state, - map_ranges={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, - code="__val = __inp", - outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)}, - input_nodes={input_node.data: input_node}, - output_nodes={output_node.data: output_node}, - external_edges=True, + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge ) - return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset) - def translate_if( node: gtir.Node, @@ -567,38 +508,44 @@ def translate_index( index values to a transient array. The extent of the index range is taken from the domain information that should be present in the node annex. """ + assert cpm.is_call_to(node, "index") + assert isinstance(node.type, ts.FieldType) + assert "domain" in node.annex domain = extract_domain(node.annex.domain) assert len(domain) == 1 - dim, lower_bound, upper_bound = domain[0] + dim, _, _ = domain[0] dim_index = dace_gtir_utils.get_map_variable(dim) - field_dims, field_offset, field_shape = _get_field_layout(domain) - field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) - - output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( + index_data = sdfg.temp_data_name() + sdfg.add_scalar(index_data, INDEX_DTYPE, transient=True) + index_node = state.add_access(index_data) + index_value = gtir_dataflow.ValueExpr( + dc_node=index_node, + gt_dtype=dace_utils.as_itir_type(INDEX_DTYPE), + ) + index_write_tasklet = sdfg_builder.add_tasklet( "index", state, - map_ranges={ - dim_index: f"{lower_bound}:{upper_bound}", - }, inputs={}, + outputs={"__val"}, code=f"__val = {dim_index}", - outputs={ - "__val": dace.Memlet( - data=output_node.data, - subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), - ) - }, - input_nodes={}, - output_nodes={output_node.data: output_node}, - external_edges=True, + ) + state.add_edge( + index_write_tasklet, + "__val", + index_node, + None, + dace.Memlet(data=index_data, subset="0"), ) - return FieldopData(output_node, field_type, field_offset) + input_edges = [ + gtir_dataflow.EmptyInputEdge(state, index_write_tasklet), + ] + output_edge = gtir_dataflow.DataflowOutputEdge(state, index_value) + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge + ) def _get_data_nodes( @@ -831,7 +778,6 @@ def translate_symbol_ref( # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ translate_as_fieldop, - translate_broadcast_scalar, translate_if, translate_index, translate_literal, From d9b38f476ee5df1995d27b7497037f3f19c9b6e6 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 29 Nov 2024 02:50:43 -0500 Subject: [PATCH 3/8] hotfix[cartesian]: Fixing k offset write utest deactivate (#1757) Missed a utest in #1755 --- .../multi_feature_tests/test_code_generation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 7c4956b3ef..e51b3ef09d 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -582,13 +582,17 @@ def test_K_offset_write(backend): # Cuda generates bad code for the K offset if backend == "cuda": pytest.skip("cuda K-offset write generates bad code") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) @@ -660,7 +664,7 @@ def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): def test_K_offset_write_conditional(backend): if backend == "cuda": pytest.skip("Cuda backend is not capable of K offset write") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: From 791f67d031127872fc6375819267f59faeaf85ba Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 29 Nov 2024 10:02:34 +0100 Subject: [PATCH 4/8] test[next]: Fix flaky failure in GTIR to SDFG tests (#1759) The SDFG name has to be unique to avoid issues with parallel build in CI tests. --- .../runners_tests/dace_tests/test_gtir_to_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index c7466b853f..b1ba4ccf22 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1984,7 +1984,7 @@ def test_gtir_index(): ) testee = gtir.Program( - id="gtir_cast", + id="gtir_index", function_definitions=[], params=[ gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), From 04513ba859d5ed55ea99999f6fd826a2a542a627 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 29 Nov 2024 13:57:10 +0100 Subject: [PATCH 5/8] fix[next]: use current working directory as default cache folder root (#1744) Change the root folder of the gt4py cache directory from the system temp folder to the current working directory, which is more visible and also avoids polluting shared filesystems in hpc clusters. --------- Co-authored-by: Hannes Vogt --- src/gt4py/next/config.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index ed244c2932..7a19f3eb9d 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -11,7 +11,6 @@ import enum import os import pathlib -import tempfile from typing import Final @@ -51,25 +50,22 @@ def env_flag_to_bool(name: str, default: bool) -> bool: ) -_PREFIX: Final[str] = "GT4PY" - #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) +DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False) #: Verbose flag for DSL compilation errors VERBOSE_EXCEPTIONS: bool = env_flag_to_bool( - f"{_PREFIX}_VERBOSE_EXCEPTIONS", default=True if DEBUG else False + "GT4PY_VERBOSE_EXCEPTIONS", default=True if DEBUG else False ) #: Where generated code projects should be persisted. #: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT BUILD_CACHE_DIR: pathlib.Path = ( - pathlib.Path(os.environ.get(f"{_PREFIX}_BUILD_CACHE_DIR", tempfile.gettempdir())) - / "gt4py_cache" + pathlib.Path(os.environ.get("GT4PY_BUILD_CACHE_DIR", pathlib.Path.cwd())) / ".gt4py_cache" ) @@ -77,11 +73,11 @@ def env_flag_to_bool(name: str, default: bool) -> bool: #: - SESSION: generated code projects get destroyed when the interpreter shuts down #: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs BUILD_CACHE_LIFETIME: BuildCacheLifetime = BuildCacheLifetime[ - os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() + os.environ.get("GT4PY_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() ] #: Build type to be used when CMake is used to compile generated code. #: Might have no effect when CMake is not used as part of the toolchain. CMAKE_BUILD_TYPE: CMakeBuildType = CMakeBuildType[ - os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() + os.environ.get("GT4PY_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() ] From d581060e5c6e8b6f64b72cce041d539956ca4727 Mon Sep 17 00:00:00 2001 From: SF-N Date: Sat, 30 Nov 2024 09:39:26 +0100 Subject: [PATCH 6/8] bug[next]: ConstantFolding after create_global_tmps (#1756) Do `ConstantFolding` within `domain_union` to avoid nested minima and maxima by `create_global_tmps` --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index f5625b509c..4a023f7535 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -16,6 +16,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: @@ -168,6 +169,8 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), [domain.ranges[dim].stop for domain in domains], ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) From a26d91f409ea5d67f168bbbc4a2157df2ed1080b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 21:31:13 +0100 Subject: [PATCH 7/8] fix[next]: Fix annex & type preservation in inline_lambdas (#1760) Co-authored-by: SF-N --- src/gt4py/next/iterator/transforms/inline_lambdas.py | 11 +++++------ src/gt4py/next/iterator/transforms/remap_symbols.py | 5 ++++- src/gt4py/next/iterator/type_system/inference.py | 7 +++++-- .../transforms_tests/test_inline_lambdas.py | 7 +++++++ 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 5ec9ec5d0b..9053214b39 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -97,7 +97,6 @@ def new_name(name): if all(eligible_params): new_expr.location = node.location - return new_expr else: new_expr = ir.FunCall( fun=ir.Lambda( @@ -111,11 +110,11 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) - for attr in ("type", "recorded_shifts", "domain"): - if hasattr(node.annex, attr): - setattr(new_expr.annex, attr, getattr(node.annex, attr)) - itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) - return new_expr + for attr in ("type", "recorded_shifts", "domain"): + if hasattr(node.annex, attr): + setattr(new_expr.annex, attr, getattr(node.annex, attr)) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 08d896121d..fb909dc5d0 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -10,6 +10,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir +from gt4py.next.iterator.type_system import inference as type_inference class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): @@ -46,7 +47,9 @@ def visit_SymRef( self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.SymRef(id=name_map.get(node.id, node.id)) + new_ref = ir.SymRef(id=name_map.get(node.id, node.id)) + type_inference.copy_type(from_=node, to=new_ref, allow_untyped=True) + return new_ref return node def generic_visit( # type: ignore[override] diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 249019769b..ffca6cc7a7 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -95,14 +95,17 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None: +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped: bool = False) -> None: """ Copy type from one node to another. This function mainly exists for readability reasons. """ assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) - _set_node_type(to, from_.type) # type: ignore[arg-type] + if from_.type is None: + assert allow_untyped + return + _set_node_type(to, from_.type) def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 2e0a83d33b..c10d48ad06 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -84,3 +84,10 @@ def test_inline_lambda_args(): ) inlined = InlineLambdas.apply(testee, opcount_preserving=True, force_inline_lambda_args=True) assert inlined == expected + + +def test_type_preservation(): + testee = im.let("a", "b")("a") + testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) + inlined = InlineLambdas.apply(testee) + assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32) From 99c53004663b0b58c7ce8335bcc30e347d3686b5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 22:08:39 +0100 Subject: [PATCH 8/8] refactor[next]: Use `set_at` & `as_fieldop` instead of `closure` in iterator tests (#1691) --- .../test_cartesian_offset_provider.py | 12 +++--- .../iterator_tests/test_conditional.py | 2 +- .../test_strided_offset_provider.py | 7 ++-- .../iterator_tests/test_trivial.py | 10 ++--- .../iterator_tests/test_tuple.py | 28 +++++-------- .../iterator_tests/test_anton_toy.py | 21 +++++----- .../iterator_tests/test_fvm_nabla.py | 40 ++++++++----------- .../iterator_tests/test_hdiff.py | 10 ++--- 8 files changed, 55 insertions(+), 75 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py index 2ebcd0c033..fedfd83fd2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py @@ -10,7 +10,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import double_roundtrip, roundtrip @@ -27,16 +27,14 @@ def foo(inp): @fendef(offset_provider={"I": I_loc, "J": J_loc}) def fencil(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) @fendef(offset_provider={"I": J_loc, "J": I_loc}) def fencil_swapped(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) def test_cartesian_offset_provider(): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index 551c567e61..eae66d425b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 7bde55bfd2..68e5f9d532 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -10,8 +10,8 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain, as_fieldop +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.unit_tests.conftest import program_processor, run_processor from gt4py.next.iterator.embedded import StridedConnectivityField @@ -36,7 +36,8 @@ def foo(inp): @fendef(offset_provider={"O": LocA2LocAB_offset_provider}) def fencil(size, out, inp): - closure(unstructured_domain(named_range(LocA, 0, size)), foo, out, [inp]) + domain = unstructured_domain(named_range(LocA, 0, size)) + set_at(as_fieldop(foo, domain)(inp), domain, out) @pytest.mark.uses_strided_neighbor_offset diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 5f1c70a6b3..fe89fe7c9d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -12,7 +12,7 @@ import gt4py.next as gtx from gt4py.next.iterator import transforms from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -94,12 +94,8 @@ def test_shifted_arg_to_lift(program_processor): @fendef def fen_direct_deref(i_size, j_size, out, inp): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)), - deref, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)) + set_at(as_fieldop(deref, domain)(inp), domain, out) def test_direct_deref(program_processor): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 2d84439c93..39d0bd69c3 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor @@ -114,16 +114,10 @@ def test_tuple_of_field_output_constructed_inside(program_processor, stencil): @fendef def fencil(size0, size1, size2, inp1, inp2, out1, out2): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, - make_tuple(out1, out2), - [inp1, inp2], + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) ) + set_at(as_fieldop(stencil, domain)(inp1, inp2), domain, make_tuple(out1, out2)) shape = [5, 7, 9] rng = np.random.default_rng() @@ -159,15 +153,13 @@ def stencil(inp1, inp2, inp3): @fendef def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) + ) + set_at( + as_fieldop(stencil, domain)(inp1, inp2, inp3), + domain, make_tuple(make_tuple(out1, out2), out3), - [inp1, inp2, inp3], ) shape = [5, 7, 9] diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 3ce9d6b470..d0a1601816 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -10,8 +10,15 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import ( + cartesian_domain, + deref, + lift, + named_range, + shift, + as_fieldop, +) +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.unit_tests.conftest import program_processor, run_processor @@ -85,14 +92,10 @@ def test_anton_toy(stencil, program_processor): @fendef(offset_provider={"i": IDim, "j": JDim}) def fencil(x, y, z, out, inp): - closure( - cartesian_domain( - named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) - ), - stencil, - out, - [inp], + domain = cartesian_domain( + named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) ) + set_at(as_fieldop(stencil, domain)(inp), domain, out) shape = [5, 7, 9] rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 4487681abf..22b4d8b3c5 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -28,8 +28,9 @@ reduce, tuple_get, unstructured_domain, + as_fieldop, ) -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, @@ -55,7 +56,8 @@ def compute_zavgS(pp, S_M): @fendef def compute_zavgS_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS, domain)(pp, S_M), domain, out) @fundef @@ -100,12 +102,8 @@ def compute_pnabla2(pp, S_M, sign, vol): @fendef def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - pnabla, - out, - [pp, S_MXX, S_MYY, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(pnabla, domain)(pp, S_MXX, S_MYY, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -145,7 +143,8 @@ def test_compute_zavgS(program_processor): @fendef def compute_zavgS2_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS2, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS2, domain)(pp, S_M), domain, out) @pytest.mark.requires_atlas @@ -212,12 +211,8 @@ def test_nabla(program_processor): @fendef def nabla2(n_nodes, out, pp, S, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla2, - out, - [pp, S, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(compute_pnabla2, domain)(pp, S, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -276,17 +271,16 @@ def compute_pnabla_sign(pp, S_M, vol, node_index, is_pole_edge): @fendef def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_pole_edge): # TODO replace by single stencil which returns tuple - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MXX, vol, node_index, is_pole_edge), + domain, out_MXX, - [pp, S_MXX, vol, node_index, is_pole_edge], ) - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MYY, vol, node_index, is_pole_edge), + domain, out_MYY, - [pp, S_MYY, vol, node_index, is_pole_edge], ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 45793b1d3e..e44e92013f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim @@ -57,12 +57,8 @@ def hdiff_sten(inp, coeff): @fendef(offset_provider={"I": IDim, "J": JDim}) def hdiff(inp, coeff, out, x, y): - closure( - cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)), - hdiff_sten, - out, - [inp, coeff], - ) + domain = cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)) + set_at(as_fieldop(hdiff_sten, domain)(inp, coeff), domain, out) @pytest.mark.uses_origin