Skip to content

Commit

Permalink
fix[next][dace]: Fix for nested SDFG outer data descriptor (#1726)
Browse files Browse the repository at this point in the history
Fixes a bug in lowering of let-statements. The lambda expression of a
let-statement is lowered to a nested SDFG. The result data produced in
the nested SDFG is written to temporary data allocated in the parent
SDFG. The previous lowering was directly using a copy of the inner data
descriptor for the outer data. The bug is that some symbols for array
shape and strides might not be available in the parent SDFG, so we have
to apply the symbol mapping on the outer data descriptor.

The test case `test_gtir_let_lambda_with_cond` was modified to trigger
this bug and verify the fix.
  • Loading branch information
edopao authored Nov 6, 2024
1 parent 60bb7b1 commit 6873a0e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -709,49 +709,61 @@ def _flatten_tuples(

head_state.add_edge(src_node, None, nsdfg_node, connector, memlet)

def make_temps(
output_data: gtir_builtin_translators.FieldopData,
def construct_output_for_nested_sdfg(
inner_data: gtir_builtin_translators.FieldopData,
) -> gtir_builtin_translators.FieldopData:
"""
This function will be called while traversing the result of the lambda
dataflow to setup the intermediate data nodes in the parent SDFG and
the data edges from the nested-SDFG output connectors.
This function makes a data container that lives inside a nested SDFG, denoted by `inner_data`,
available in the parent SDFG.
In order to achieve this, the data container inside the nested SDFG is marked as non-transient
(in other words, externally allocated - a requirement of the SDFG IR) and a new data container
is created within the parent SDFG, with the same properties (shape, stride, etc.) as `inner_data`
but appropriately remapped using the symbol mapping table.
For lambda arguments that are simply returned by the lambda, the `inner_data` was already mapped
to a parent SDFG data container, therefore it can be directly accessed in the parent SDFG.
The same happens to symbols available in the lambda context but not explicitly passed as lambda
arguments, that are simply returned by the lambda: they can be directly accessed in the parent SDFG.
"""
desc = output_data.dc_node.desc(nsdfg)
if desc.transient:
# Transient nodes actually contain some result produced by the dataflow
# itself, therefore these nodes are changed to non-transient and an output
# edge will write the result from the nested-SDFG to a new intermediate
# data node in the parent context.
desc.transient = False
temp, _ = sdfg.add_temp_transient_like(desc)
connector = output_data.dc_node.data
dst_node = head_state.add_access(temp)
inner_desc = inner_data.dc_node.desc(nsdfg)
if inner_desc.transient:
# Transient data nodes only exist within the nested SDFG. In order to return some result data,
# the corresponding data container inside the nested SDFG has to be changed to non-transient,
# that is externally allocated, as required by the SDFG IR. An output edge will write the result
# from the nested-SDFG to a new intermediate data container allocated in the parent SDFG.
inner_desc.transient = False
outer, outer_desc = sdfg.add_temp_transient_like(inner_desc)
# We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping.
dace.symbolic.safe_replace(
nsdfg_symbols_mapping,
lambda m: dace.sdfg.replace_properties_dict(outer_desc, m),
)
connector = inner_data.dc_node.data
outer_node = head_state.add_access(outer)
head_state.add_edge(
nsdfg_node, connector, dst_node, None, sdfg.make_array_memlet(temp)
nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer)
)
temp_field = gtir_builtin_translators.FieldopData(
dst_node, output_data.gt_dtype, output_data.local_offset
outer_data = gtir_builtin_translators.FieldopData(
outer_node, inner_data.gt_dtype, inner_data.local_offset
)
elif output_data.dc_node.data in lambda_arg_nodes:
elif inner_data.dc_node.data in lambda_arg_nodes:
# This if branch and the next one handle the non-transient result nodes.
# Non-transient nodes are just input nodes that are immediately returned
# by the lambda expression. Therefore, these nodes are already available
# in the parent context and can be directly accessed there.
temp_field = lambda_arg_nodes[output_data.dc_node.data]
outer_data = lambda_arg_nodes[inner_data.dc_node.data]
else:
dc_node = head_state.add_access(output_data.dc_node.data)
temp_field = gtir_builtin_translators.FieldopData(
dc_node, output_data.gt_dtype, output_data.local_offset
outer_node = head_state.add_access(inner_data.dc_node.data)
outer_data = gtir_builtin_translators.FieldopData(
outer_node, inner_data.gt_dtype, inner_data.local_offset
)
# Isolated access nodes will make validation fail.
# Isolated access nodes can be found in the join-state of an if-expression
# or in lambda expressions that just construct tuples from input arguments.
if nstate.degree(output_data.dc_node) == 0:
nstate.remove_node(output_data.dc_node)
return temp_field
if nstate.degree(inner_data.dc_node) == 0:
nstate.remove_node(inner_data.dc_node)
return outer_data

return gtx_utils.tree_map(make_temps)(lambda_result)
return gtx_utils.tree_map(construct_output_for_nested_sdfg)(lambda_result)

def visit_Literal(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def test_gtir_cond():
expr=im.op_as_fieldop("plus", domain)(
"x",
im.if_(
im.greater(gtir.SymRef(id="s1"), gtir.SymRef(id="s2")),
im.greater("s1", "s2"),
im.op_as_fieldop("plus", domain)("y", "scalar"),
im.op_as_fieldop("plus", domain)("w", "scalar"),
),
Expand Down Expand Up @@ -663,7 +663,7 @@ def test_gtir_cond_with_tuple_return():
expr=im.tuple_get(
0,
im.if_(
gtir.SymRef(id="pred"),
"pred",
im.make_tuple(im.make_tuple("x", "y"), "w"),
im.make_tuple(im.make_tuple("y", "x"), "w"),
),
Expand Down Expand Up @@ -703,10 +703,10 @@ def test_gtir_cond_nested():
body=[
gtir.SetAt(
expr=im.if_(
gtir.SymRef(id="pred_1"),
"pred_1",
im.op_as_fieldop("plus", domain)("x", 1.0),
im.if_(
gtir.SymRef(id="pred_2"),
"pred_2",
im.op_as_fieldop("plus", domain)("x", 2.0),
im.op_as_fieldop("plus", domain)("x", 3.0),
),
Expand Down Expand Up @@ -1534,7 +1534,7 @@ def test_gtir_reduce_with_cond_neighbors():
vertex_domain,
)(
im.if_(
gtir.SymRef(id="pred"),
"pred",
im.as_fieldop_neighbors("V2E_FULL", "edges", vertex_domain),
im.as_fieldop_neighbors("V2E", "edges", vertex_domain),
)
Expand Down Expand Up @@ -1756,11 +1756,7 @@ def test_gtir_let_lambda_with_cond():
gtir.SetAt(
expr=im.let("x1", "x")(
im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))(
im.if_(
gtir.SymRef(id="pred"),
im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x1"),
im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x2"),
)
im.if_("pred", "x1", "x2")
)
),
domain=domain,
Expand Down

0 comments on commit 6873a0e

Please sign in to comment.