Skip to content

Commit

Permalink
add test case for sdfg transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Dec 13, 2024
1 parent 8f0e515 commit f01d291
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1008,9 +1008,10 @@ def map_strides(edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode) -
inner_desc.set_shape(inner_desc.shape, new_strides)

for stride in new_strides:
for sym in stride.free_symbols:
nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype)
nsdfg_node.symbol_mapping |= {str(sym): sym}
if isinstance(stride, dace.symbolic.symbol):
for sym in stride.free_symbols:
nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype)
nsdfg_node.symbol_mapping |= {str(sym): sym}

for inner_state in nsdfg_node.sdfg.states():
for inner_node in inner_state.data_nodes():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
import dace


def _make_test_data(names: list[str]) -> dict[str, np.ndarray]:
return {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in names}


def _make_test_sdfg(
output_name: str = "G",
input_name: str = "G",
Expand Down Expand Up @@ -262,3 +258,92 @@ def test_map_buffer_elimination_not_apply():
validate_all=True,
)
assert count == 0


def test_map_buffer_elimination_with_nested_sdfgs():
"""
After removing a transient connected to a nested SDFG node, ensure that the strides
are propagated to the arrays in nested SDFG.
"""

stride1, stride2, stride3 = [dace.symbol(f"stride{i}", dace.int32) for i in range(3)]

# top-level sdfg
sdfg = dace.SDFG(util.unique_name("map_buffer"))
inp, inp_desc = sdfg.add_array("__inp", (10,), dace.float64)
out, out_desc = sdfg.add_array(
"__out", (10, 10, 10), dace.float64, strides=(stride1, stride2, stride3)
)
tmp, _ = sdfg.add_temp_transient_like(out_desc)
state = sdfg.add_state()
tmp_node = state.add_access(tmp)

nsdfg1 = dace.SDFG(util.unique_name("map_buffer"))
inp1, inp1_desc = nsdfg1.add_array("__inp", (10,), dace.float64)
out1, out1_desc = nsdfg1.add_array("__out", (10, 10), dace.float64)
tmp1, _ = nsdfg1.add_temp_transient_like(out1_desc)
state1 = nsdfg1.add_state()
tmp1_node = state1.add_access(tmp1)

nsdfg2 = dace.SDFG(util.unique_name("map_buffer"))
inp2, _ = nsdfg2.add_array("__inp", (10,), dace.float64)
out2, out2_desc = nsdfg2.add_array("__out", (10,), dace.float64)
tmp2, _ = nsdfg2.add_temp_transient_like(out2_desc)
state2 = nsdfg2.add_state()
tmp2_node = state2.add_access(tmp2)

state2.add_mapped_tasklet(
"broadcast2",
map_ranges={"__i": "0:10"},
code="__oval = __ival + 1.0",
inputs={
"__ival": dace.Memlet(f"{inp2}[__i]"),
},
outputs={
"__oval": dace.Memlet(f"{tmp2}[__i]"),
},
output_nodes={tmp2_node},
external_edges=True,
)
state2.add_nedge(tmp2_node, state2.add_access(out2), dace.Memlet.from_array(out2, out2_desc))

nsdfg2_node = state1.add_nested_sdfg(nsdfg2, nsdfg1, inputs={"__inp"}, outputs={"__out"})
me1, mx1 = state1.add_map("broadcast1", ndrange={"__i": "0:10"})
state1.add_memlet_path(
state1.add_access(inp1),
me1,
nsdfg2_node,
dst_conn="__inp",
memlet=dace.Memlet.from_array(inp1, inp1_desc),
)
state1.add_memlet_path(
nsdfg2_node, mx1, tmp1_node, src_conn="__out", memlet=dace.Memlet(f"{tmp1}[__i, 0:10]")
)
state1.add_nedge(tmp1_node, state1.add_access(out1), dace.Memlet.from_array(out1, out1_desc))

nsdfg1_node = state.add_nested_sdfg(nsdfg1, sdfg, inputs={"__inp"}, outputs={"__out"})
me, mx = state.add_map("broadcast", ndrange={"__i": "0:10"})
state.add_memlet_path(
state.add_access(inp),
me,
nsdfg1_node,
dst_conn="__inp",
memlet=dace.Memlet.from_array(inp, inp_desc),
)
state.add_memlet_path(
nsdfg1_node, mx, tmp_node, src_conn="__out", memlet=dace.Memlet(f"{tmp}[__i, 0:10, 0:10]")
)
state.add_nedge(tmp_node, state.add_access(out), dace.Memlet.from_array(out, out_desc))

sdfg.validate()

count = sdfg.apply_transformations_repeated(
gtx_transformations.GT4PyMapBufferElimination(
assume_pointwise=False,
),
validate=True,
validate_all=True,
)
assert count == 3
assert out1_desc.strides == out_desc.strides[1:]
assert out2_desc.strides == out_desc.strides[2:]

0 comments on commit f01d291

Please sign in to comment.