From 6873a0e7f87e43ed717cc22611793a81afce0c93 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 6 Nov 2024 09:54:22 +0100 Subject: [PATCH] fix[next][dace]: Fix for nested SDFG outer data descriptor (#1726) Fixes a bug in lowering of let-statements. The lambda expression of a let-statement is lowered to a nested SDFG. The result data produced in the nested SDFG is written to temporary data allocated in the parent SDFG. The previous lowering was directly using a copy of the inner data descriptor for the outer data. The bug is that some symbols for array shape and strides might not be available in the parent SDFG, so we have to apply the symbol mapping on the outer data descriptor. The test case `test_gtir_let_lambda_with_cond` was modified to trigger this bug and verify the fix. --- .../runners/dace_fieldview/gtir_sdfg.py | 66 +++++++++++-------- .../dace_tests/test_gtir_to_sdfg.py | 16 ++--- 2 files changed, 45 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index f19f78d9d2..da940e883c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -709,49 +709,61 @@ def _flatten_tuples( head_state.add_edge(src_node, None, nsdfg_node, connector, memlet) - def make_temps( - output_data: gtir_builtin_translators.FieldopData, + def construct_output_for_nested_sdfg( + inner_data: gtir_builtin_translators.FieldopData, ) -> gtir_builtin_translators.FieldopData: """ - This function will be called while traversing the result of the lambda - dataflow to setup the intermediate data nodes in the parent SDFG and - the data edges from the nested-SDFG output connectors. + This function makes a data container that lives inside a nested SDFG, denoted by `inner_data`, + available in the parent SDFG. + In order to achieve this, the data container inside the nested SDFG is marked as non-transient + (in other words, externally allocated - a requirement of the SDFG IR) and a new data container + is created within the parent SDFG, with the same properties (shape, stride, etc.) of `inner_data` + but appropriatly remapped using the symbol mapping table. + For lambda arguments that are simply returned by the lambda, the `inner_data` was already mapped + to a parent SDFG data container, therefore it can be directly accessed in the parent SDFG. + The same happens to symbols available in the lambda context but not explicitly passed as lambda + arguments, that are simply returned by the lambda: it can be directly accessed in the parent SDFG. """ - desc = output_data.dc_node.desc(nsdfg) - if desc.transient: - # Transient nodes actually contain some result produced by the dataflow - # itself, therefore these nodes are changed to non-transient and an output - # edge will write the result from the nested-SDFG to a new intermediate - # data node in the parent context. - desc.transient = False - temp, _ = sdfg.add_temp_transient_like(desc) - connector = output_data.dc_node.data - dst_node = head_state.add_access(temp) + inner_desc = inner_data.dc_node.desc(nsdfg) + if inner_desc.transient: + # Transient data nodes only exist within the nested SDFG. In order to return some result data, + # the corresponding data container inside the nested SDFG has to be changed to non-transient, + # that is externally allocated, as required by the SDFG IR. An output edge will write the result + # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. + inner_desc.transient = False + outer, outer_desc = sdfg.add_temp_transient_like(inner_desc) + # We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping. + dace.symbolic.safe_replace( + nsdfg_symbols_mapping, + lambda m: dace.sdfg.replace_properties_dict(outer_desc, m), + ) + connector = inner_data.dc_node.data + outer_node = head_state.add_access(outer) head_state.add_edge( - nsdfg_node, connector, dst_node, None, sdfg.make_array_memlet(temp) + nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) ) - temp_field = gtir_builtin_translators.FieldopData( - dst_node, output_data.gt_dtype, output_data.local_offset + outer_data = gtir_builtin_translators.FieldopData( + outer_node, inner_data.gt_dtype, inner_data.local_offset ) - elif output_data.dc_node.data in lambda_arg_nodes: + elif inner_data.dc_node.data in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. # Non-transient nodes are just input nodes that are immediately returned # by the lambda expression. Therefore, these nodes are already available # in the parent context and can be directly accessed there. - temp_field = lambda_arg_nodes[output_data.dc_node.data] + outer_data = lambda_arg_nodes[inner_data.dc_node.data] else: - dc_node = head_state.add_access(output_data.dc_node.data) - temp_field = gtir_builtin_translators.FieldopData( - dc_node, output_data.gt_dtype, output_data.local_offset + outer_node = head_state.add_access(inner_data.dc_node.data) + outer_data = gtir_builtin_translators.FieldopData( + outer_node, inner_data.gt_dtype, inner_data.local_offset ) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. - if nstate.degree(output_data.dc_node) == 0: - nstate.remove_node(output_data.dc_node) - return temp_field + if nstate.degree(inner_data.dc_node) == 0: + nstate.remove_node(inner_data.dc_node) + return outer_data - return gtx_utils.tree_map(make_temps)(lambda_result) + return gtx_utils.tree_map(construct_output_for_nested_sdfg)(lambda_result) def visit_Literal( self, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index dea9f2879b..cc72adae4f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -621,7 +621,7 @@ def test_gtir_cond(): expr=im.op_as_fieldop("plus", domain)( "x", im.if_( - im.greater(gtir.SymRef(id="s1"), gtir.SymRef(id="s2")), + im.greater("s1", "s2"), im.op_as_fieldop("plus", domain)("y", "scalar"), im.op_as_fieldop("plus", domain)("w", "scalar"), ), @@ -663,7 +663,7 @@ def test_gtir_cond_with_tuple_return(): expr=im.tuple_get( 0, im.if_( - gtir.SymRef(id="pred"), + "pred", im.make_tuple(im.make_tuple("x", "y"), "w"), im.make_tuple(im.make_tuple("y", "x"), "w"), ), @@ -703,10 +703,10 @@ def test_gtir_cond_nested(): body=[ gtir.SetAt( expr=im.if_( - gtir.SymRef(id="pred_1"), + "pred_1", im.op_as_fieldop("plus", domain)("x", 1.0), im.if_( - gtir.SymRef(id="pred_2"), + "pred_2", im.op_as_fieldop("plus", domain)("x", 2.0), im.op_as_fieldop("plus", domain)("x", 3.0), ), @@ -1534,7 +1534,7 @@ def test_gtir_reduce_with_cond_neighbors(): vertex_domain, )( im.if_( - gtir.SymRef(id="pred"), + "pred", im.as_fieldop_neighbors("V2E_FULL", "edges", vertex_domain), im.as_fieldop_neighbors("V2E", "edges", vertex_domain), ) @@ -1756,11 +1756,7 @@ def test_gtir_let_lambda_with_cond(): gtir.SetAt( expr=im.let("x1", "x")( im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( - im.if_( - gtir.SymRef(id="pred"), - im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x1"), - im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x2"), - ) + im.if_("pred", "x1", "x2") ) ), domain=domain,