Skip to content

Commit

Permalink
OTFMapFusion: Bugfix for tasklets with None connectors
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Oct 31, 2023
1 parent 3ddd2cc commit b01f160
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
17 changes: 10 additions & 7 deletions dace/transformation/dataflow/otf_map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,14 +289,17 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
for edge in graph.edges_between(first_map_entry, node):
memlet = copy.deepcopy(edge.data)

in_connector = edge.src_conn.replace("OUT", "IN")
if in_connector in connector_mapping:
out_connector = connector_mapping[in_connector].replace("IN", "OUT")
if edge.src_conn is not None:
in_connector = edge.src_conn.replace("OUT", "IN")
if in_connector in connector_mapping:
out_connector = connector_mapping[in_connector].replace("IN", "OUT")
else:
out_connector = edge.src_conn

if out_connector not in self.second_map_entry.out_connectors:
self.second_map_entry.add_out_connector(out_connector)
else:
out_connector = edge.src_conn

if out_connector not in self.second_map_entry.out_connectors:
self.second_map_entry.add_out_connector(out_connector)
out_connector = None

graph.add_edge(self.second_map_entry, out_connector, node, edge.dst_conn, memlet)
graph.remove_edge(edge)
Expand Down
31 changes: 31 additions & 0 deletions tests/transformations/otf_map_fusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,36 @@ def test_trivial_fusion_nested_sdfg():
assert (res == res_fused).all()


@dace.program
def trivial_fusion_none_connectors(B: dace.float64[10, 20]):
tmp = dace.define_local([10, 20], dtype=B.dtype)
for i, j in dace.map[0:10, 0:20]:
with dace.tasklet:
b >> tmp[i, j]
b = 0

for i, j in dace.map[0:10, 0:20]:
with dace.tasklet:
a << tmp[i, j]
b >> B[i, j]
b = a + 2


def test_trivial_fusion_none_connectors():
sdfg = trivial_fusion_none_connectors.to_sdfg()
sdfg.simplify()
assert count_maps(sdfg) == 2

sdfg.apply_transformations(OTFMapFusion)
assert count_maps(sdfg) == 1

B = np.zeros((10, 20))
ref = np.zeros((10, 20)) + 2

sdfg(B=B)
assert np.allclose(B, ref)


@dace.program
def undefined_subset(A: dace.float64[10], B: dace.float64[10]):
tmp = dace.define_local([10], dtype=A.dtype)
Expand Down Expand Up @@ -703,6 +733,7 @@ def test_hdiff():
test_trivial_fusion_permute()
test_trivial_fusion_not_remove_map()
test_trivial_fusion_nested_sdfg()
test_trivial_fusion_none_connectors()

# Defined subsets
test_undefined_subset()
Expand Down

0 comments on commit b01f160

Please sign in to comment.