From f6c219bd989e3c5325da1173bade4bff2ac9e650 Mon Sep 17 00:00:00 2001 From: SF-N Date: Tue, 26 Nov 2024 15:59:58 +0100 Subject: [PATCH 01/13] bug[next]: Fix SetAt type inference for ts.DeferredType (#1747) Fix to correctly handle tuples of ts.DeferredType. --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/type_system/inference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 987eb0f308..249019769b 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -509,7 +509,10 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: # the target can have fewer elements than the expr in which case the output from the # expression is simply discarded. expr_type = functools.reduce( - lambda tuple_type, i: tuple_type.types[i], # type: ignore[attr-defined] # format ensured by primitive_constituents + lambda tuple_type, i: tuple_type.types[i] # type: ignore[attr-defined] # format ensured by primitive_constituents + # `ts.DeferredType` only occurs for scans returning a tuple + if not isinstance(tuple_type, ts.DeferredType) + else ts.DeferredType(constraint=None), path, node.expr.type, ) From f6c0498dbffd85a80a32281e5a53bfb35e00e745 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 27 Nov 2024 09:55:46 +0100 Subject: [PATCH 02/13] feat[next][dace]: Lowering to SDFG of index builtin (#1751) Implements the lowering to SDFG of the GTIR index builtin. --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 14 ++++ .../gtir_builtin_translators.py | 83 ++++++++++++++++--- .../runners/dace_fieldview/gtir_sdfg.py | 2 + tests/next_tests/definitions.py | 1 - .../dace_tests/test_gtir_to_sdfg.py | 50 ++++++++++- 5 files changed, 134 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 2864c7f727..a4e111e785 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -519,6 +519,20 @@ def _impl(it: itir.Expr) -> itir.FunCall: return _impl +def index(dim: common.Dimension) -> itir.FunCall: + """ + Create a call to the `index` builtin, shorthand for `call("index")(axis)`, + after converting the given dimension to `itir.AxisLiteral`. + + Args: + dim: the dimension corresponding to the index axis. + + Returns: + A function that constructs a Field of indices in the given dimension. + """ + return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind)) + + def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 60dcd8ddc9..94ab3a6f76 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,7 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -277,20 +277,31 @@ def extract_domain(node: gtir.Node) -> FieldopDomain: the corresponding lower and upper bounds. The returned lower bound is inclusive, the upper bound is exclusive: [lower_bound, upper_bound[ """ - assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) domain = [] - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - axis = named_range.args[0] - assert isinstance(axis, gtir.AxisLiteral) - lower_bound, upper_bound = ( - dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg)) - for arg in named_range.args[1:3] - ) - dim = gtx_common.Dimension(axis.value, axis.kind) - domain.append((dim, lower_bound, upper_bound)) + + def parse_range_boundary(expr: gtir.Expr) -> str: + return dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(expr)) + + if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, gtir.AxisLiteral) + lower_bound, upper_bound = (parse_range_boundary(arg) for arg in named_range.args[1:3]) + dim = gtx_common.Dimension(axis.value, axis.kind) + domain.append((dim, lower_bound, upper_bound)) + + elif isinstance(node, domain_utils.SymbolicDomain): + assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"} + for dim, drange in node.ranges.items(): + domain.append( + (dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop)) + ) + + else: + raise ValueError(f"Invalid domain {node}.") return domain @@ -545,6 +556,51 @@ def construct_output(inner_data: FieldopData) -> FieldopData: return result_temps +def translate_index( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """ + Lowers the `index` builtin function to a mapped tasklet that writes the dimension + index values to a transient array. The extent of the index range is taken from + the domain information that should be present in the node annex. + """ + assert "domain" in node.annex + domain = extract_domain(node.annex.domain) + assert len(domain) == 1 + dim, lower_bound, upper_bound = domain[0] + dim_index = dace_gtir_utils.get_map_variable(dim) + + field_dims, field_offset, field_shape = _get_field_layout(domain) + field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) + + output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) + output_node = state.add_access(output) + + sdfg_builder.add_mapped_tasklet( + "index", + state, + map_ranges={ + dim_index: f"{lower_bound}:{upper_bound}", + }, + inputs={}, + code=f"__val = {dim_index}", + outputs={ + "__val": dace.Memlet( + data=output_node.data, + subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), + ) + }, + input_nodes={}, + output_nodes={output_node.data: output_node}, + external_edges=True, + ) + + return FieldopData(output_node, field_type, field_offset) + + def _get_data_nodes( sdfg: dace.SDFG, state: dace.SDFGState, @@ -777,6 +833,7 @@ def translate_symbol_ref( translate_as_fieldop, translate_broadcast_scalar, translate_if, + translate_index, translate_literal, translate_make_tuple, translate_tuple_get, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index f15287e64c..6b5e164458 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -568,6 +568,8 @@ def visit_FunCall( # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node, "if_"): return gtir_builtin_translators.translate_if(node, sdfg, head_state, self) + elif cpm.is_call_to(node, "index"): + return gtir_builtin_translators.translate_index(node, sdfg, head_state, self) elif cpm.is_call_to(node, "make_tuple"): return gtir_builtin_translators.translate_make_tuple(node, sdfg, head_state, self) elif cpm.is_call_to(node, "tuple_get"): diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 01fd18897d..349d3e9f70 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -154,7 +154,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ - (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index f5191fbaaa..c7466b853f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -12,15 +12,15 @@ Note: this test module covers the fieldview flavour of ITIR. """ -import copy import functools import numpy as np import pytest -from gt4py.next import common as gtx_common, constructors +from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import infer_domain from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -1973,3 +1973,49 @@ def test_gtir_if_values(): sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, np.where(a < b, a, b)) + + +def test_gtir_index(): + MARGIN = 2 + assert MARGIN < N + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + subdomain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + + testee = gtir.Program( + id="gtir_cast", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("i", im.index(IDim))( + im.op_as_fieldop("plus", domain)( + "i", + im.as_fieldop( + im.lambda_("a")(im.deref(im.shift(IDim.value, 1)("a"))), subdomain + )("i"), + ) + ), + domain=subdomain, + target=gtir.SymRef(id="x"), + ) + ], + ) + + v = np.empty(N, dtype=np.int32) + + # we need to run domain inference in order to add the domain annex information to the index node. + testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + ref = np.concatenate( + (v[:MARGIN], np.arange(MARGIN, N - MARGIN, dtype=np.int32), v[N - MARGIN :]) + ) + + sdfg(v, **FSYMBOLS) + np.allclose(v, ref) From 3ece412f0d78f32893d8f01ed0e74c8b38388854 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 28 Nov 2024 13:13:55 -0500 Subject: [PATCH 03/13] fix[cartesian]: Deactivate K offset write in `gt:gpu` (#1755) Following the issue logged as https://github.com/GridTools/gt4py/issues/1754 we are deactivating the K-offset write feature until we can figure out why it's failing. I will monitor any activity on the ticket if users are hit by this. --------- Co-authored-by: Hannes Vogt --- src/gt4py/cartesian/frontend/gtscript_frontend.py | 7 +++++++ .../multi_feature_tests/test_code_generation.py | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index ade05921ef..f155ea6209 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1460,6 +1460,13 @@ def visit_Assign(self, node: ast.Assign) -> list: loc=nodes.Location.from_ast_node(t), ) + if self.backend_name in ["gt:gpu"]: + raise GTScriptSyntaxError( + message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} as an unsolved bug remains." + "Please refer to https://github.com/GridTools/gt4py/issues/1754.", + loc=nodes.Location.from_ast_node(t), + ) + if not self._is_known(name): if name in self.temp_decls: field_decl = self.temp_decls[name] diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index c4d07d7337..7c4956b3ef 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -667,6 +667,10 @@ def test_K_offset_write_conditional(backend): pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) From 886058496c1ebcb90ba530a796213d1fec7c7095 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 29 Nov 2024 08:46:06 +0100 Subject: [PATCH 04/13] refact[next][dace]: Helper function for field operator constructor (#1743) Includes refactoring of the code for construction of field operators, in order to make it usable by the three lowering functions that construct fields: `translate_as_fieldop()`, `translate_broadcast_scalar()`, and `translate_index()`. --- .../gtir_builtin_translators.py | 242 +++++++----------- 1 file changed, 94 insertions(+), 148 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 94ab3a6f76..ff011c4193 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,11 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -229,40 +233,75 @@ def _get_field_layout( return list(domain_dims), list(domain_lbs), domain_sizes -def _create_temporary_field( +def _create_field_operator( sdfg: dace.SDFG, state: dace.SDFGState, domain: FieldopDomain, node_type: ts.FieldType, - dataflow_output: gtir_dataflow.DataflowOutputEdge, + sdfg_builder: gtir_sdfg.SDFGBuilder, + input_edges: Sequence[gtir_dataflow.DataflowInputEdge], + output_edge: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: - """Helper method to allocate a temporary field where to write the output of a field operator.""" + """ + Helper method to allocate a temporary field to store the output of a field operator. + + Args: + sdfg: The SDFG that represents the scope of the field data. + state: The SDFG state where to create an access node to the field data. + domain: The domain of the field operator that computes the field. + node_type: The GT4Py type of the IR node that produces this field. + sdfg_builder: The object used to build the map scope in the provided SDFG. + input_edges: List of edges to pass input data into the dataflow. + output_edge: Edge representing the dataflow output data. + + Returns: + The field data descriptor, which includes the field access node in the given `state` + and the field domain offset. + """ field_dims, field_offset, field_shape = _get_field_layout(domain) + field_indices = _get_domain_indices(field_dims, field_offset) + + dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - output_desc = dataflow_output.result.dc_node.desc(sdfg) - if isinstance(output_desc, dace.data.Array): + field_subset = sbs.Range.from_indices(field_indices) + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): + assert output_edge.result.gt_dtype == node_type.dtype + assert isinstance(dataflow_output_desc, dace.data.Scalar) + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) + field_dtype = output_edge.result.gt_dtype + else: assert isinstance(node_type.dtype, itir_ts.ListType) - assert isinstance(node_type.dtype.element_type, ts.ScalarType) - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) + assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + field_dtype = output_edge.result.gt_dtype.element_type # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) - field_offset.extend(output_desc.offset) - field_shape.extend(output_desc.shape) - elif isinstance(output_desc, dace.data.Scalar): - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) - else: - raise ValueError(f"Cannot create field for dace type {output_desc}.") + assert output_edge.result.gt_dtype.offset_type is not None + field_dims.append(output_edge.result.gt_dtype.offset_type) + field_shape.extend(dataflow_output_desc.shape) + field_offset.extend(dataflow_output_desc.offset) + field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) # allocate local temporary storage - temp_name, _ = sdfg.add_temp_transient(field_shape, output_desc.dtype) - field_node = state.add_access(temp_name) + field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + field_node = state.add_access(field_name) - if isinstance(dataflow_output.result.gt_dtype, ts.ScalarType): - field_dtype = dataflow_output.result.gt_dtype - else: - assert isinstance(dataflow_output.result.gt_dtype.element_type, ts.ScalarType) - field_dtype = dataflow_output.result.gt_dtype.element_type - assert dataflow_output.result.gt_dtype.offset_type is not None - field_dims.append(dataflow_output.result.gt_dtype.offset_type) + # create map range corresponding to the field operator domain + me, mx = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + }, + ) + + # here we setup the edges passing through the map entry node + for edge in input_edges: + edge.connect(me) + + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(mx, field_node, field_subset) return FieldopData( field_node, @@ -341,7 +380,8 @@ def translate_as_fieldop( # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. - return translate_broadcast_scalar(node, sdfg, state, sdfg_builder) + stencil_expr = im.lambda_("a")(im.deref("a")) + stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] else: raise NotImplementedError( f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." @@ -349,117 +389,18 @@ def translate_as_fieldop( # parse the domain of the field operator domain = extract_domain(domain_expr) - domain_dims, domain_offsets, _ = zip(*domain) - domain_indices = _get_domain_indices(domain_dims, domain_offsets) # visit the list of arguments to be passed to the lambda expression stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) - output_desc = output.result.dc_node.desc(sdfg) - - if isinstance(node.type.dtype, itir_ts.ListType): - assert isinstance(output_desc, dace.data.Array) - # additional local dimension for neighbors - # TODO(phimuell): Investigate if we should swap the two. - output_subset = sbs.Range.from_indices(domain_indices) + sbs.Range.from_array(output_desc) - else: - assert isinstance(output_desc, dace.data.Scalar) - output_subset = sbs.Range.from_indices(domain_indices) - - # create map range corresponding to the field operator domain - me, mx = sdfg_builder.add_map( - "fieldop", - state, - ndrange={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - ) - - # allocate local temporary storage for the result field - result_field = _create_temporary_field(sdfg, state, domain, node.type, output) - - # here we setup the edges from the map entry node - for edge in input_edges: - edge.connect(me) - - # and here the edge writing the result data through the map exit node - output.connect(mx, result_field.dc_node, output_subset) - - return result_field - + input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) -def translate_broadcast_scalar( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, -) -> FieldopResult: - """ - Generates the dataflow subgraph for the 'as_fieldop' builtin function for the - special case where the argument to 'as_fieldop' is a 'deref' scalar expression, - rather than a lambda function. This case corresponds to broadcasting the scalar - value over the field domain. Therefore, it is lowered to a mapped tasklet that - just writes the scalar value out to all elements of the result field. - """ - assert isinstance(node, gtir.FunCall) - assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) - - fun_node = node.fun - assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args - assert cpm.is_ref_to(stencil_expr, "deref") - - domain = extract_domain(domain_expr) - output_dims, output_offset, output_shape = _get_field_layout(domain) - output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset)) - - assert len(node.args) == 1 - scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) - - if isinstance(node.args[0].type, ts.ScalarType): - assert isinstance(scalar_expr, (gtir_dataflow.MemletExpr, gtir_dataflow.ValueExpr)) - input_subset = ( - str(scalar_expr.subset) if isinstance(scalar_expr, gtir_dataflow.MemletExpr) else "0" - ) - input_node = scalar_expr.dc_node - gt_dtype = node.args[0].type - elif isinstance(node.args[0].type, ts.FieldType): - assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) - if len(node.args[0].type.dims) == 0: # zero-dimensional field - input_subset = "0" - else: - input_subset = scalar_expr.get_memlet_subset(sdfg) - - input_node = scalar_expr.field - gt_dtype = node.args[0].type.dtype - else: - raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") - - output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( - "broadcast", - state, - map_ranges={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, - code="__val = __inp", - outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)}, - input_nodes={input_node.data: input_node}, - output_nodes={output_node.data: output_node}, - external_edges=True, + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge ) - return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset) - def translate_if( node: gtir.Node, @@ -567,38 +508,44 @@ def translate_index( index values to a transient array. The extent of the index range is taken from the domain information that should be present in the node annex. """ + assert cpm.is_call_to(node, "index") + assert isinstance(node.type, ts.FieldType) + assert "domain" in node.annex domain = extract_domain(node.annex.domain) assert len(domain) == 1 - dim, lower_bound, upper_bound = domain[0] + dim, _, _ = domain[0] dim_index = dace_gtir_utils.get_map_variable(dim) - field_dims, field_offset, field_shape = _get_field_layout(domain) - field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) - - output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( + index_data = sdfg.temp_data_name() + sdfg.add_scalar(index_data, INDEX_DTYPE, transient=True) + index_node = state.add_access(index_data) + index_value = gtir_dataflow.ValueExpr( + dc_node=index_node, + gt_dtype=dace_utils.as_itir_type(INDEX_DTYPE), + ) + index_write_tasklet = sdfg_builder.add_tasklet( "index", state, - map_ranges={ - dim_index: f"{lower_bound}:{upper_bound}", - }, inputs={}, + outputs={"__val"}, code=f"__val = {dim_index}", - outputs={ - "__val": dace.Memlet( - data=output_node.data, - subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), - ) - }, - input_nodes={}, - output_nodes={output_node.data: output_node}, - external_edges=True, + ) + state.add_edge( + index_write_tasklet, + "__val", + index_node, + None, + dace.Memlet(data=index_data, subset="0"), ) - return FieldopData(output_node, field_type, field_offset) + input_edges = [ + gtir_dataflow.EmptyInputEdge(state, index_write_tasklet), + ] + output_edge = gtir_dataflow.DataflowOutputEdge(state, index_value) + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge + ) def _get_data_nodes( @@ -831,7 +778,6 @@ def translate_symbol_ref( # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ translate_as_fieldop, - translate_broadcast_scalar, translate_if, translate_index, translate_literal, From d9b38f476ee5df1995d27b7497037f3f19c9b6e6 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 29 Nov 2024 02:50:43 -0500 Subject: [PATCH 05/13] hotfix[cartesian]: Fixing k offset write utest deactivate (#1757) Missed a utest in #1755 --- .../multi_feature_tests/test_code_generation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 7c4956b3ef..e51b3ef09d 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -582,13 +582,17 @@ def test_K_offset_write(backend): # Cuda generates bad code for the K offset if backend == "cuda": pytest.skip("cuda K-offset write generates bad code") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) @@ -660,7 +664,7 @@ def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): def test_K_offset_write_conditional(backend): if backend == "cuda": pytest.skip("Cuda backend is not capable of K offset write") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: From 791f67d031127872fc6375819267f59faeaf85ba Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 29 Nov 2024 10:02:34 +0100 Subject: [PATCH 06/13] test[next]: Fix flaky failure in GTIR to SDFG tests (#1759) The SDFG name has to be unique to avoid issues with parallel build in CI tests. --- .../runners_tests/dace_tests/test_gtir_to_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index c7466b853f..b1ba4ccf22 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1984,7 +1984,7 @@ def test_gtir_index(): ) testee = gtir.Program( - id="gtir_cast", + id="gtir_index", function_definitions=[], params=[ gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), From 04513ba859d5ed55ea99999f6fd826a2a542a627 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 29 Nov 2024 13:57:10 +0100 Subject: [PATCH 07/13] fix[next]: use current working directory as default cache folder root (#1744) Change the root folder of the gt4py cache directory from the system temp folder to the current working directory, which is more visible and also avoids polluting shared filesystems in hpc clusters. --------- Co-authored-by: Hannes Vogt --- src/gt4py/next/config.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index ed244c2932..7a19f3eb9d 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -11,7 +11,6 @@ import enum import os import pathlib -import tempfile from typing import Final @@ -51,25 +50,22 @@ def env_flag_to_bool(name: str, default: bool) -> bool: ) -_PREFIX: Final[str] = "GT4PY" - #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) +DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False) #: Verbose flag for DSL compilation errors VERBOSE_EXCEPTIONS: bool = env_flag_to_bool( - f"{_PREFIX}_VERBOSE_EXCEPTIONS", default=True if DEBUG else False + "GT4PY_VERBOSE_EXCEPTIONS", default=True if DEBUG else False ) #: Where generated code projects should be persisted. #: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT BUILD_CACHE_DIR: pathlib.Path = ( - pathlib.Path(os.environ.get(f"{_PREFIX}_BUILD_CACHE_DIR", tempfile.gettempdir())) - / "gt4py_cache" + pathlib.Path(os.environ.get("GT4PY_BUILD_CACHE_DIR", pathlib.Path.cwd())) / ".gt4py_cache" ) @@ -77,11 +73,11 @@ def env_flag_to_bool(name: str, default: bool) -> bool: #: - SESSION: generated code projects get destroyed when the interpreter shuts down #: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs BUILD_CACHE_LIFETIME: BuildCacheLifetime = BuildCacheLifetime[ - os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() + os.environ.get("GT4PY_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() ] #: Build type to be used when CMake is used to compile generated code. #: Might have no effect when CMake is not used as part of the toolchain. CMAKE_BUILD_TYPE: CMakeBuildType = CMakeBuildType[ - os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() + os.environ.get("GT4PY_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() ] From d581060e5c6e8b6f64b72cce041d539956ca4727 Mon Sep 17 00:00:00 2001 From: SF-N Date: Sat, 30 Nov 2024 09:39:26 +0100 Subject: [PATCH 08/13] bug[next]: ConstantFolding after create_global_tmps (#1756) Do `ConstantFolding` within `domain_union` to avoid nested minima and maxima by `create_global_tmps` --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index f5625b509c..4a023f7535 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -16,6 +16,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: @@ -168,6 +169,8 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), [domain.ranges[dim].stop for domain in domains], ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) From a26d91f409ea5d67f168bbbc4a2157df2ed1080b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 21:31:13 +0100 Subject: [PATCH 09/13] fix[next]: Fix annex & type preservation in inline_lambdas (#1760) Co-authored-by: SF-N --- src/gt4py/next/iterator/transforms/inline_lambdas.py | 11 +++++------ src/gt4py/next/iterator/transforms/remap_symbols.py | 5 ++++- src/gt4py/next/iterator/type_system/inference.py | 7 +++++-- .../transforms_tests/test_inline_lambdas.py | 7 +++++++ 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 5ec9ec5d0b..9053214b39 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -97,7 +97,6 @@ def new_name(name): if all(eligible_params): new_expr.location = node.location - return new_expr else: new_expr = ir.FunCall( fun=ir.Lambda( @@ -111,11 +110,11 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) - for attr in ("type", "recorded_shifts", "domain"): - if hasattr(node.annex, attr): - setattr(new_expr.annex, attr, getattr(node.annex, attr)) - itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) - return new_expr + for attr in ("type", "recorded_shifts", "domain"): + if hasattr(node.annex, attr): + setattr(new_expr.annex, attr, getattr(node.annex, attr)) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 08d896121d..fb909dc5d0 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -10,6 +10,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir +from gt4py.next.iterator.type_system import inference as type_inference class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): @@ -46,7 +47,9 @@ def visit_SymRef( self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.SymRef(id=name_map.get(node.id, node.id)) + new_ref = ir.SymRef(id=name_map.get(node.id, node.id)) + type_inference.copy_type(from_=node, to=new_ref, allow_untyped=True) + return new_ref return node def generic_visit( # type: ignore[override] diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 249019769b..ffca6cc7a7 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -95,14 +95,17 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None: +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped: bool = False) -> None: """ Copy type from one node to another. This function mainly exists for readability reasons. """ assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) - _set_node_type(to, from_.type) # type: ignore[arg-type] + if from_.type is None: + assert allow_untyped + return + _set_node_type(to, from_.type) def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 2e0a83d33b..c10d48ad06 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -84,3 +84,10 @@ def test_inline_lambda_args(): ) inlined = InlineLambdas.apply(testee, opcount_preserving=True, force_inline_lambda_args=True) assert inlined == expected + + +def test_type_preservation(): + testee = im.let("a", "b")("a") + testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) + inlined = InlineLambdas.apply(testee) + assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32) From 99c53004663b0b58c7ce8335bcc30e347d3686b5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 22:08:39 +0100 Subject: [PATCH 10/13] refactor[next]: Use `set_at` & `as_fieldop` instead of `closure` in iterator tests (#1691) --- .../test_cartesian_offset_provider.py | 12 +++--- .../iterator_tests/test_conditional.py | 2 +- .../test_strided_offset_provider.py | 7 ++-- .../iterator_tests/test_trivial.py | 10 ++--- .../iterator_tests/test_tuple.py | 28 +++++-------- .../iterator_tests/test_anton_toy.py | 21 +++++----- .../iterator_tests/test_fvm_nabla.py | 40 ++++++++----------- .../iterator_tests/test_hdiff.py | 10 ++--- 8 files changed, 55 insertions(+), 75 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py index 2ebcd0c033..fedfd83fd2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py @@ -10,7 +10,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import double_roundtrip, roundtrip @@ -27,16 +27,14 @@ def foo(inp): @fendef(offset_provider={"I": I_loc, "J": J_loc}) def fencil(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) @fendef(offset_provider={"I": J_loc, "J": I_loc}) def fencil_swapped(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) def test_cartesian_offset_provider(): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index 551c567e61..eae66d425b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 7bde55bfd2..68e5f9d532 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -10,8 +10,8 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain, as_fieldop +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.unit_tests.conftest import program_processor, run_processor from gt4py.next.iterator.embedded import StridedConnectivityField @@ -36,7 +36,8 @@ def foo(inp): @fendef(offset_provider={"O": LocA2LocAB_offset_provider}) def fencil(size, out, inp): - closure(unstructured_domain(named_range(LocA, 0, size)), foo, out, [inp]) + domain = unstructured_domain(named_range(LocA, 0, size)) + set_at(as_fieldop(foo, domain)(inp), domain, out) @pytest.mark.uses_strided_neighbor_offset diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 5f1c70a6b3..fe89fe7c9d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -12,7 +12,7 @@ import gt4py.next as gtx from gt4py.next.iterator import transforms from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -94,12 +94,8 @@ def test_shifted_arg_to_lift(program_processor): @fendef def fen_direct_deref(i_size, j_size, out, inp): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)), - deref, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)) + set_at(as_fieldop(deref, domain)(inp), domain, out) def test_direct_deref(program_processor): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 2d84439c93..39d0bd69c3 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor @@ -114,16 +114,10 @@ def test_tuple_of_field_output_constructed_inside(program_processor, stencil): @fendef def fencil(size0, size1, size2, inp1, inp2, out1, out2): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, - make_tuple(out1, out2), - [inp1, inp2], + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) ) + set_at(as_fieldop(stencil, domain)(inp1, inp2), domain, make_tuple(out1, out2)) shape = [5, 7, 9] rng = np.random.default_rng() @@ -159,15 +153,13 @@ def stencil(inp1, inp2, inp3): @fendef def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) + ) + set_at( + as_fieldop(stencil, domain)(inp1, inp2, inp3), + domain, make_tuple(make_tuple(out1, out2), out3), - [inp1, inp2, inp3], ) shape = [5, 7, 9] diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 3ce9d6b470..d0a1601816 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -10,8 +10,15 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import ( + cartesian_domain, + deref, + lift, + named_range, + shift, + as_fieldop, +) +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.unit_tests.conftest import program_processor, run_processor @@ -85,14 +92,10 @@ def test_anton_toy(stencil, program_processor): @fendef(offset_provider={"i": IDim, "j": JDim}) def fencil(x, y, z, out, inp): - closure( - cartesian_domain( - named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) - ), - stencil, - out, - [inp], + domain = cartesian_domain( + named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) ) + set_at(as_fieldop(stencil, domain)(inp), domain, out) shape = [5, 7, 9] rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 4487681abf..22b4d8b3c5 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -28,8 +28,9 @@ reduce, tuple_get, unstructured_domain, + as_fieldop, ) -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, @@ -55,7 +56,8 @@ def compute_zavgS(pp, S_M): @fendef def compute_zavgS_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS, domain)(pp, S_M), domain, out) @fundef @@ -100,12 +102,8 @@ def compute_pnabla2(pp, S_M, sign, vol): @fendef def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - pnabla, - out, - [pp, S_MXX, S_MYY, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(pnabla, domain)(pp, S_MXX, S_MYY, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -145,7 +143,8 @@ def test_compute_zavgS(program_processor): @fendef def compute_zavgS2_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS2, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS2, domain)(pp, S_M), domain, out) @pytest.mark.requires_atlas @@ -212,12 +211,8 @@ def test_nabla(program_processor): @fendef def nabla2(n_nodes, out, pp, S, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla2, - out, - [pp, S, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(compute_pnabla2, domain)(pp, S, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -276,17 +271,16 @@ def compute_pnabla_sign(pp, S_M, vol, node_index, is_pole_edge): @fendef def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_pole_edge): # TODO replace by single stencil which returns tuple - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MXX, vol, node_index, is_pole_edge), + domain, out_MXX, - [pp, S_MXX, vol, node_index, is_pole_edge], ) - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MYY, vol, node_index, is_pole_edge), + domain, out_MYY, - [pp, S_MYY, vol, node_index, is_pole_edge], ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 45793b1d3e..e44e92013f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim @@ -57,12 +57,8 @@ def hdiff_sten(inp, coeff): @fendef(offset_provider={"I": IDim, "J": JDim}) def hdiff(inp, coeff, out, x, y): - closure( - cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)), - hdiff_sten, - out, - [inp, coeff], - ) + domain = cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)) + set_at(as_fieldop(hdiff_sten, domain)(inp, coeff), domain, out) @pytest.mark.uses_origin From 6f49699f00ceb9e466fa4448bab779bc061df047 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 2 Dec 2024 13:09:47 +0100 Subject: [PATCH 11/13] style[eve]: remove unused imports and fix typos (#1748) Small cleanup PR in the eve framework: - Removes a stale `.gitignore` file. As far as I understood from the git history, earlier versions of this codebase had many `.gitignore` files in many places. Looks like this one is a leftover from a previous time. - Remove a couple of stale includes. The language server marked them as unused and since tests still pass, I guess we really don't need them anymore. - Fixed a couple of typos in comments - Fixed two typos in the github PR template --- .github/pull_request_template.md | 4 ++-- src/gt4py/eve/.gitignore | 1 - src/gt4py/eve/__init__.py | 14 ++------------ src/gt4py/eve/codegen.py | 6 +++--- src/gt4py/eve/datamodels/__init__.py | 4 ++-- src/gt4py/eve/datamodels/core.py | 16 ++++++++-------- src/gt4py/eve/extended_typing.py | 4 ---- src/gt4py/eve/trees.py | 8 -------- src/gt4py/eve/type_validation.py | 2 +- src/gt4py/eve/utils.py | 2 +- src/gt4py/next/ffront/decorator.py | 2 +- 11 files changed, 20 insertions(+), 43 deletions(-) delete mode 100644 src/gt4py/eve/.gitignore diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7284a7df04..83304a9c62 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -15,7 +15,7 @@ Delete this comment and add a proper description of the changes contained in thi - test: Adding missing tests or correcting existing tests : cartesian | eve | next | storage - # ONLY if changes are limited to a specific subsytem + # ONLY if changes are limited to a specific subsystem - PR Description: @@ -27,7 +27,7 @@ Delete this comment and add a proper description of the changes contained in thi ## Requirements - [ ] All fixes and/or new features come with corresponding tests. -- [ ] Important design decisions have been documented in the approriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. +- [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. If this PR contains code authored by new contributors please make sure: diff --git a/src/gt4py/eve/.gitignore b/src/gt4py/eve/.gitignore deleted file mode 100644 index 050cda3ca5..0000000000 --- a/src/gt4py/eve/.gitignore +++ /dev/null @@ -1 +0,0 @@ -_version.py diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 0b8cfa7d62..5adac47da3 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -24,8 +24,7 @@ """ -from __future__ import annotations # isort:skip - +from __future__ import annotations from .concepts import ( AnnexManager, @@ -89,15 +88,6 @@ "SymbolRef", "VType", "register_annex_user", - "# datamodels" "Coerced", - "DataModel", - "FrozenModel", - "GenericDataModel", - "Unchecked", - "concretize", - "datamodel", - "field", - "frozenmodel", # datamodels "Coerced", "DataModel", @@ -122,7 +112,7 @@ "pre_walk_values", "walk_items", "walk_values", - "# type_definition", + # type_definitions "NOTHING", "ConstrainedStr", "Enum", diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 15fda4f3b4..3869ff313b 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -347,7 +347,7 @@ def __str__(self) -> str: class Template(Protocol): """Protocol (abstract base class) defining the Template interface. - Direct subclassess of this base class only need to implement the + Direct subclasses of this base class only need to implement the abstract methods to adapt different template engines to this interface. @@ -654,8 +654,8 @@ def apply( # redefinition of symbol Args: root: An IR node. - node_templates (optiona): see :class:`NodeDumper`. - dump_function (optiona): see :class:`NodeDumper`. + node_templates (optional): see :class:`NodeDumper`. + dump_function (optional): see :class:`NodeDumper`. ``**kwargs`` (optional): custom extra parameters forwarded to `visit_NODE_TYPE_NAME()`. Returns: diff --git a/src/gt4py/eve/datamodels/__init__.py b/src/gt4py/eve/datamodels/__init__.py index 68ddea2510..6fd9c7bb21 100644 --- a/src/gt4py/eve/datamodels/__init__.py +++ b/src/gt4py/eve/datamodels/__init__.py @@ -11,7 +11,7 @@ Data Models can be considered as enhanced `attrs `_ / `dataclasses `_ providing additional features like automatic run-time type validation. Values assigned to fields -at initialization can be validated with automatic type checkings using the +at initialization can be validated with automatic type checking using the field type definition. Custom field validation methods can also be added with the :func:`validator` decorator, and global instance validation methods with :func:`root_validator`. @@ -33,7 +33,7 @@ 1. ``__init__()``. a. If a custom ``__init__`` already exists in the class, it will not be overwritten. - It is your responsability to call ``__auto_init__`` from there to obtain + It is your responsibility to call ``__auto_init__`` from there to obtain the described behavior. b. If there is not custom ``__init__``, the one generated by datamodels will be called first. diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index d596f59cfb..1b0e995156 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -24,7 +24,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz @@ -270,7 +270,7 @@ def datamodel( @overload -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Type[_T], /, *, @@ -289,7 +289,7 @@ def datamodel( # redefinion of unused symbol # TODO(egparedes): Use @dataclass_transform(eq_default=True, field_specifiers=("field",)) -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Optional[Type[_T]] = None, /, *, @@ -867,7 +867,7 @@ def _substitute_typevars( def _make_counting_attr_from_attribute( field_attrib: Attribute, *, include_type: bool = False, **kwargs: Any -) -> Any: # attr.s lies a bit in some typing definitons +) -> Any: # attr.s lies a bit in some typing definitions args = [ "default", "validator", @@ -965,7 +965,7 @@ def _type_converter(value: Any) -> _T: return value if isinstance(value, type_annotation) else type_annotation(value) except Exception as error: raise TypeError( - f"Error during coertion of given value '{value}' for field '{name}'." + f"Error during coercion of given value '{value}' for field '{name}'." ) from error return _type_converter @@ -996,7 +996,7 @@ def _type_converter(value: Any) -> _T: return _make_type_converter(origin_type, name) raise exceptions.EveTypeError( - f"Automatic type coertion for {type_annotation} types is not supported." + f"Automatic type coercion for {type_annotation} types is not supported." ) @@ -1085,7 +1085,7 @@ def _make_datamodel( ) else: - # Create field converter if automatic coertion is enabled + # Create field converter if automatic coercion is enabled converter: TypeConverter = cast( TypeConverter, _make_type_converter(type_hint, qualified_field_name) if coerce_field else None, @@ -1099,7 +1099,7 @@ def _make_datamodel( if isinstance(attr_value_in_cls, _KNOWN_MUTABLE_TYPES): warnings.warn( f"'{attr_value_in_cls.__class__.__name__}' value used as default in '{cls.__name__}.{key}'.\n" - "Mutable types should not defbe normally used as field defaults (use 'default_factory' instead).", + "Mutable types should not be used as field defaults (use 'default_factory' instead).", stacklevel=_stacklevel_offset + 2, ) setattr( diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index e276f3bccf..bf44824b49 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -14,12 +14,8 @@ from __future__ import annotations -import abc as _abc import array as _array -import collections.abc as _collections_abc -import ctypes as _ctypes import dataclasses as _dataclasses -import enum as _enum import functools as _functools import inspect as _inspect import mmap as _mmap diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index c8e8658413..8a3cc30f4b 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -31,14 +31,6 @@ from .type_definitions import Enum -try: - # For performance reasons, try to use cytoolz when possible (using cython) - import cytoolz as toolz -except ModuleNotFoundError: - # Fall back to pure Python toolz - import toolz # noqa: F401 [unused-import] - - TreeKey = Union[int, str] diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 613eca40b2..e150832295 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -311,7 +311,7 @@ def __call__( # ... # # Since this can be an arbitrary type (not something regular like a collection) there is - # no way to check if the type parameter is verifed in the actual instance. + # no way to check if the type parameter is verified in the actual instance. # The only check can be done at run-time is to verify that the value is an instance of # the original type, completely ignoring the annotation. Ideally, the static type checker # can do a better job to try figure out if the type parameter is ok ... diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 8cb68845d7..2c66d39290 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -69,7 +69,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 9ce07d01bb..61756f30c9 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -230,7 +230,7 @@ def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: if self.backend is None: warnings.warn( UserWarning( - f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a perfomance backend." + f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." ), stacklevel=2, ) From f57d6e916e17ee2ff574ba6096ccc21911d27533 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 2 Dec 2024 20:02:44 +0100 Subject: [PATCH 12/13] fix[next]: Guard diskcache creation by file lock (#1745) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The disk cache used to cache compilation in the gtfn backend has a race condition manifesting itself in `sqlite3.OperationalError: database is locked` errors if multiple python processes try to initialize the `diskcache.Cache` object concurrently. This PR fixes this by guarding the object creation by a file-based lock in the same directory as the database. While this issue occurred frequently and was observed to be fixed on distributed file systems, the lock does not guarantee correct behavior in particular for accesses to the cache (beyond opening) since the underlying SQLite database is unreliable when stored on an NFS based file system. It does however ensure correctness of concurrent cache accesses on a local file system. See more information here: https://grantjenks.com/docs/diskcache/tutorial.html#settings https://www.sqlite.org/faq.html#q5 https://github.com/tox-dev/filelock/issues/73 NFS safe locking: https://gitlab.com/warsaw/flufl.lock [Barry Warsaw / FLUFL Lock ยท GitLab](https://gitlab.com/warsaw/flufl.lock) --- .pre-commit-config.yaml | 1 + constraints.txt | 8 ++--- min-extra-requirements-test.txt | 1 + min-requirements-test.txt | 1 + pyproject.toml | 1 + requirements-dev.txt | 8 ++--- .../next/program_processors/runners/gtfn.py | 32 ++++++++++++++++--- 7 files changed, 40 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c3b6e693f..7e1870c67f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -102,6 +102,7 @@ repos: - devtools==0.12.2 - diskcache==5.6.3 - factory-boy==3.3.1 + - filelock==3.16.1 - frozendict==2.4.6 - gridtools-cpp==2.3.8 - importlib-resources==6.4.5 diff --git a/constraints.txt b/constraints.txt index b4b8bc00d4..f039fa2125 100644 --- a/constraints.txt +++ b/constraints.txt @@ -49,7 +49,7 @@ executing==2.1.0 # via devtools, stack-data factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy faker==33.0.0 # via factory-boy fastjsonschema==2.20.0 # via nbformat -filelock==3.16.1 # via tox, virtualenv +filelock==3.16.1 # via gt4py (pyproject.toml), tox, virtualenv fonttools==4.55.0 # via matplotlib fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) @@ -113,8 +113,8 @@ psutil==6.1.0 # via -r requirements-dev.in, ipykernel, pytest-xdist ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data pybind11==2.13.6 # via gt4py (pyproject.toml) -pydantic==2.9.2 # via bump-my-version, pydantic-settings -pydantic-core==2.23.4 # via pydantic +pydantic==2.10.0 # via bump-my-version, pydantic-settings +pydantic-core==2.27.0 # via pydantic pydantic-settings==2.6.1 # via bump-my-version pydot==3.0.2 # via tach pygments==2.18.0 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx @@ -159,7 +159,7 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.13.3 # via dace tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.3 # via -r requirements-dev.in +tach==0.14.4 # via -r requirements-dev.in tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 57c0d3969d..d7679a1f0f 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -67,6 +67,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 +filelock==3.0.0 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 81a1c2dea3..cf505e88d6 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -63,6 +63,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 +filelock==3.0.0 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/pyproject.toml b/pyproject.toml index 02d301957c..1e24094fa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ 'devtools>=0.6', 'diskcache>=5.6.3', 'factory-boy>=3.3.0', + 'filelock>=3.0.0', 'frozendict>=2.3', 'gridtools-cpp>=2.3.8,==2.*', "importlib-resources>=5.0;python_version<'3.9'", diff --git a/requirements-dev.txt b/requirements-dev.txt index 9f95779fd5..6542be36f1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -49,7 +49,7 @@ executing==2.1.0 # via -c constraints.txt, devtools, stack-data factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy faker==33.0.0 # via -c constraints.txt, factory-boy fastjsonschema==2.20.0 # via -c constraints.txt, nbformat -filelock==3.16.1 # via -c constraints.txt, tox, virtualenv +filelock==3.16.1 # via -c constraints.txt, gt4py (pyproject.toml), tox, virtualenv fonttools==4.55.0 # via -c constraints.txt, matplotlib fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) @@ -113,8 +113,8 @@ psutil==6.1.0 # via -c constraints.txt, -r requirements-dev.in, ipyk ptyprocess==0.7.0 # via -c constraints.txt, pexpect pure-eval==0.2.3 # via -c constraints.txt, stack-data pybind11==2.13.6 # via -c constraints.txt, gt4py (pyproject.toml) -pydantic==2.9.2 # via -c constraints.txt, bump-my-version, pydantic-settings -pydantic-core==2.23.4 # via -c constraints.txt, pydantic +pydantic==2.10.0 # via -c constraints.txt, bump-my-version, pydantic-settings +pydantic-core==2.27.0 # via -c constraints.txt, pydantic pydantic-settings==2.6.1 # via -c constraints.txt, bump-my-version pydot==3.0.2 # via -c constraints.txt, tach pygments==2.18.0 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx @@ -158,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.13.3 # via -c constraints.txt, dace tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in +tach==0.14.4 # via -c constraints.txt, -r requirements-dev.in tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 1f3778f227..55f479c665 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -7,11 +7,14 @@ # SPDX-License-Identifier: BSD-3-Clause import functools +import pathlib +import tempfile import warnings from typing import Any, Optional import diskcache import factory +import filelock import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators @@ -139,13 +142,34 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: class FileCache(diskcache.Cache): """ - This class extends `diskcache.Cache` to ensure the cache is closed upon deletion, - i.e. it ensures that any resources associated with the cache are properly - released when the instance is garbage collected. + This class extends `diskcache.Cache` to ensure the cache is properly + - opened when accessed by multiple processes using a file lock. This guards the creating of the + cache object, which has been reported to cause `sqlite3.OperationalError: database is locked` + errors and slow startup times when multiple processes access the cache concurrently. While this + issue occurred frequently and was observed to be fixed on distributed file systems, the lock + does not guarantee correct behavior in particular for accesses to the cache (beyond opening) + since the underlying SQLite database is unreliable when stored on an NFS based file system. + It does however ensure correctness of concurrent cache accesses on a local file system. See + #1745 for more details. + - closed upon deletion, i.e. it ensures that any resources associated with the cache are + properly released when the instance is garbage collected. """ + def __init__(self, directory: Optional[str | pathlib.Path] = None, **settings: Any) -> None: + if directory: + lock_dir = pathlib.Path(directory).parent + else: + lock_dir = pathlib.Path(tempfile.gettempdir()) + + lock = filelock.FileLock(lock_dir / "file_cache.lock") + with lock: + super().__init__(directory=directory, **settings) + + self._init_complete = True + def __del__(self) -> None: - self.close() + if getattr(self, "_init_complete", False): # skip if `__init__` didn't finished + self.close() class GTFNCompileWorkflowFactory(factory.Factory): From e5abcd20839e35c5480b512e1c2ef9b6f01c60e4 Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:55:53 +0100 Subject: [PATCH 13/13] bug[next]: Fix codegen in gtfn for unused vertical offset provider (#1746) Providing an offest provider for a vertical dimension without using that dimension in a program, e.g. no arguments are fields defined on K, resulted in erroneous C++ code. --- .../codegens/gtfn/itir_to_gtfn_ir.py | 3 +++ tests/next_tests/integration_tests/cases.py | 10 +++++++++- .../ffront_tests/test_execution.py | 15 +++++++++++++++ .../ffront_tests/test_gt4py_builtins.py | 17 ++++++++++------- 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 129d81d6f9..dc0012b041 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -198,6 +198,9 @@ def _collect_offset_definitions( "Mapping an offset to a horizontal dimension in unstructured is not allowed." ) # create alias from vertical offset to vertical dimension + offset_definitions[dim.value] = TagDefinition( + name=Sym(id=dim.value), alias=_vertical_dimension + ) offset_definitions[offset_name] = TagDefinition( name=Sym(id=offset_name), alias=SymRef(id=dim.value) ) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 9fb7850666..759cd1cf1f 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -499,13 +499,21 @@ def unstructured_case( Vertex: mesh_descriptor.num_vertices, Edge: mesh_descriptor.num_edges, Cell: mesh_descriptor.num_cells, - KDim: 10, }, grid_type=common.GridType.UNSTRUCTURED, allocator=exec_alloc_descriptor.allocator, ) +@pytest.fixture +def unstructured_case_3d(unstructured_case): + return dataclasses.replace( + unstructured_case, + default_sizes={**unstructured_case.default_sizes, KDim: 10}, + offset_provider={**unstructured_case.offset_provider, "KOff": KDim}, + ) + + def _allocate_from_type( case: Case, arg_type: ts.TypeSpec, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 1a51e3667d..0d994d1b22 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -41,6 +41,7 @@ Edge, cartesian_case, unstructured_case, + unstructured_case_3d, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -93,6 +94,20 @@ def testee(a: cases.VField) -> cases.EField: ) +def test_horizontal_only_with_3d_mesh(unstructured_case_3d): + # test field operator operating only on horizontal fields while using an offset provider + # including a vertical dimension. + @gtx.field_operator + def testee(a: cases.VField) -> cases.VField: + return a + + cases.verify_with_default_data( + unstructured_case_3d, + testee, + ref=lambda a: a, + ) + + @pytest.mark.uses_unstructured_shift def test_composed_unstructured_shift(unstructured_case): @gtx.field_operator diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 7648d34db7..ab1c625fef 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -29,6 +29,7 @@ Vertex, cartesian_case, unstructured_case, + unstructured_case_3d, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -105,10 +106,10 @@ def reduction_ke_field( @pytest.mark.parametrize( "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ ) -def test_neighbor_sum(unstructured_case, fop): - v2e_table = unstructured_case.offset_provider["V2E"].ndarray +def test_neighbor_sum(unstructured_case_3d, fop): + v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray - edge_f = cases.allocate(unstructured_case, fop, "edge_f")() + edge_f = cases.allocate(unstructured_case_3d, fop, "edge_f")() local_dim_idx = edge_f.domain.dims.index(Edge) + 1 adv_indexing = tuple( @@ -131,10 +132,10 @@ def test_neighbor_sum(unstructured_case, fop): where=broadcasted_table != common._DEFAULT_SKIP_VALUE, ) cases.verify( - unstructured_case, + unstructured_case_3d, fop, edge_f, - out=cases.allocate(unstructured_case, fop, cases.RETURN)(), + out=cases.allocate(unstructured_case_3d, fop, cases.RETURN)(), ref=ref, ) @@ -463,11 +464,13 @@ def conditional_program( ) -def test_promotion(unstructured_case): +def test_promotion(unstructured_case_3d): @gtx.field_operator def promotion( inp1: gtx.Field[[Edge, KDim], float64], inp2: gtx.Field[[KDim], float64] ) -> gtx.Field[[Edge, KDim], float64]: return inp1 / inp2 - cases.verify_with_default_data(unstructured_case, promotion, ref=lambda inp1, inp2: inp1 / inp2) + cases.verify_with_default_data( + unstructured_case_3d, promotion, ref=lambda inp1, inp2: inp1 / inp2 + )