diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
index 64f052428d..79f1bd0201 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
@@ -59,6 +59,21 @@ def itir_type_as_dace_type(type_: next_typing.Type):
     raise NotImplementedError()
 
 
+def reduction_init_value(op_name_: str, type_: Any):
+    if op_name_ == "plus":
+        init_value = type_(0)
+    elif op_name_ == "multiplies":
+        init_value = type_(1)
+    elif op_name_ == "minimum":
+        init_value = type_("inf")
+    elif op_name_ == "maximum":
+        init_value = type_("-inf")
+    else:
+        raise NotImplementedError()
+
+    return init_value
+
+
 _MATH_BUILTINS_MAPPING = {
     "abs": "abs({})",
     "sin": "math.sin({})",
@@ -139,6 +154,7 @@ class Context:
     state: dace.SDFGState
     symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr]
     reduce_limit: int
+    reduce_wcr: Optional[str]
 
     def __init__(
         self,
@@ -150,6 +166,7 @@ def __init__(
        self.state = state
        self.symbol_map = symbol_map
        self.reduce_limit = 0
+        self.reduce_wcr = None
 
 
 def builtin_neighbors(
@@ -365,6 +382,8 @@ def visit_Lambda(
             value = IteratorExpr(field, indices, arg.dtype, arg.dimensions)
             symbol_map[param] = value
         context = Context(context_sdfg, context_state, symbol_map)
+        context.reduce_limit = prev_context.reduce_limit
+        context.reduce_wcr = prev_context.reduce_wcr
         self.context = context
 
         # Add input parameters as arrays
@@ -411,7 +430,11 @@ def visit_Lambda(
             self.context.body.add_scalar(result_name, result.dtype, transient=True)
             result_access = self.context.state.add_access(result_name)
             self.context.state.add_edge(
-                result.value, None, result_access, None, dace.Memlet(f"{result.value.data}[0]")
+                result.value,
+                None,
+                result_access,
+                None,
+                dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr),
             )
             result = ValueExpr(value=result_access, dtype=result.dtype)
         else:
@@ -770,21 +793,43 @@ def _visit_reduce(self, node: itir.FunCall):
         nreduce = self.context.body.arrays[neighbor_args[0].value.data].shape[0]
         nreduce_domain = {"__idx": f"0:{nreduce}"}
 
-        # set variable in context to enable dereference of neighbors in input fields
+        result_dtype = neighbor_args[0].dtype
+        self.context.body.add_scalar(result_name, result_dtype, transient=True)
+
+        assert isinstance(fun_node.expr, itir.FunCall)
+        op_name = fun_node.expr.fun
+        assert isinstance(op_name, itir.SymRef)
+
+        init_value = reduction_init_value(op_name.id, result_dtype)
+        init_state = self.context.body.add_state_before(self.context.state, "init")
+        init_tasklet = init_state.add_tasklet(
+            "init_reduce", {}, {"__out"}, f"__out = {init_value}"
+        )
+        init_state.add_edge(
+            init_tasklet,
+            "__out",
+            init_state.add_access(result_name),
+            None,
+            dace.Memlet.simple(result_name, "0"),
+        )
+
+        # set variable in context to enable dereference of neighbors in input fields and WCR on reduce tasklet
         self.context.reduce_limit = nreduce
+        self.context.reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format(
+            "x", "y"
+        )
+
         for i, node_arg in enumerate(node.args):
             if not args[i]:
                 args[i] = self.visit(node_arg)[0]
 
-        # clear context
-        self.context.reduce_limit = 0
-
-        result_dtype = neighbor_args[0].dtype
-        self.context.body.add_scalar(result_name, result_dtype, transient=True)
-
-        assert isinstance(fun_node.expr, itir.FunCall)
         lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:])
         lambda_context, inner_inputs, inner_outputs = self.visit(lambda_node, args=args)
 
+        # clear context
+        self.context.reduce_limit = 0
+        self.context.reduce_wcr = None
+
         # the connectivity arrays (neighbor tables) are not needed inside the lambda SDFG
         neighbor_tables = filter_neighbor_tables(self.offset_provider)
         for conn, _ in neighbor_tables:
@@ -797,11 +842,7 @@ def _visit_reduce(self, node: itir.FunCall):
         input_memlets = [
             create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args)
         ]
-
-        op_name = fun_node.expr.fun
-        assert isinstance(op_name, itir.SymRef)
-        wcr_str = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y")
-        output_memlet = dace.Memlet(data=result_name, subset="0", wcr=wcr_str)
+        output_memlet = dace.Memlet.simple(result_name, "0")
 
         input_mapping = {param: arg for (param, _), arg in zip(inner_inputs, input_memlets)}
         output_mapping = {inner_outputs[0].value.data: output_memlet}
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
index f3006fc18f..0b2e529bc8 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
@@ -92,8 +92,7 @@ def fencil(edge_f: cases.EField, out: cases.VField):
 
 def test_reduction_expression_in_call(unstructured_case):
     if unstructured_case.backend == dace_iterator.run_dace_iterator:
-        # -edge_f(V2E) * tmp_nbh * 2 gets inlined with the neighbor_sum operation in the reduction in itir,
-        # so in addition to the skipped reason, currently itir is a lambda instead of the 'plus' operation
+        # -edge_f(V2E) * tmp_nbh * 2 gets inlined with the neighbor_sum operation in the reduction in itir
         pytest.skip("Not supported in DaCe backend: Reductions not directly on a field.")
 
     @gtx.field_operator
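Note on the mechanism (illustrative sketch, not part of the patch): the change seeds the transient result scalar with the reduction identity from reduction_init_value in an "init" state, and then relies on DaCe write-conflict resolution (WCR) to combine the per-neighbor contributions, via the reduce_wcr string attached to the output memlets. The standalone example below shows that accumulation pattern in DaCe's explicit-dataflow frontend; the program name, the choice of a "plus" reduction with identity 0, and the array sizes are assumptions for illustration only.

import dace
import numpy as np

N = dace.symbol("N")


@dace.program
def wcr_sum(inp: dace.float64[N], out: dace.float64[1]):
    # Seed the accumulator with the identity of the reduction,
    # analogous to the patch's "init_reduce" tasklet.
    out[0] = 0.0
    for i in dace.map[0:N]:
        with dace.tasklet:
            x << inp[i]
            # WCR memlet: conflicting writes to out[0] are merged with the
            # given lambda, analogous to the reduce_wcr string the patch
            # builds from _MATH_BUILTINS_MAPPING.
            acc >> out(1, lambda a, b: a + b)[0]
            acc = x


values = np.arange(8, dtype=np.float64)
result = np.zeros(1)
wcr_sum(values, result)
assert result[0] == values.sum()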