Skip to content

Commit

Permalink
Merge pull request #1360 from spcl/tasklet-fusion-bugfix
Browse files Browse the repository at this point in the history
TaskletFusion: Fix additional edges in case of none-connectors
  • Loading branch information
lukastruemper authored Sep 4, 2023
2 parents c5ca99a + c34de8e commit 9a8279d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
3 changes: 3 additions & 0 deletions dace/transformation/dataflow/tasklet_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
t1.language)

for in_edge in graph.in_edges(t1):
if in_edge.src_conn is None and isinstance(in_edge.src, dace.nodes.EntryNode):
if len(new_tasklet.in_connectors) > 0:
continue
graph.add_edge(in_edge.src, in_edge.src_conn, new_tasklet, in_edge.dst_conn, in_edge.data)

for in_edge in graph.in_edges(t2):
Expand Down
44 changes: 44 additions & 0 deletions tests/transformations/tasklet_fusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,49 @@ def test_map_with_tasklets(language: str, with_data: bool):
ref = map_with_tasklets.f(A, B)
assert (np.allclose(C, ref))

def test_none_connector():
@dace.program
def sdfg_none_connector(A: dace.float32[32], B: dace.float32[32]):
tmp = dace.define_local([32], dace.float32)
for i in dace.map[0:32]:
with dace.tasklet:
a >> tmp[i]
a = 0

tmp2 = dace.define_local([32], dace.float32)
for i in dace.map[0:32]:
with dace.tasklet:
a << A[i]
b >> tmp2[i]
b = a + 1


for i in dace.map[0:32]:
with dace.tasklet:
a << tmp[i]
b << tmp2[i]
c >> B[i]
c = a + b

sdfg = sdfg_none_connector.to_sdfg()
sdfg.simplify()
applied = sdfg.apply_transformations_repeated(MapFusion)
assert applied == 2

map_entry = None
for node in sdfg.start_state.nodes():
if isinstance(node, dace.nodes.MapEntry):
map_entry = node
break

assert map_entry is not None
assert len([edge.src_conn for edge in sdfg.start_state.out_edges(map_entry) if edge.src_conn is None]) == 1

applied = sdfg.apply_transformations_repeated(TaskletFusion)
assert applied == 2

assert sdfg.start_state.out_degree(map_entry) == 1
assert len([edge.src_conn for edge in sdfg.start_state.out_edges(map_entry) if edge.src_conn is None]) == 0

if __name__ == '__main__':
test_basic()
Expand All @@ -224,3 +267,4 @@ def test_map_with_tasklets(language: str, with_data: bool):
test_map_with_tasklets(language='Python', with_data=True)
test_map_with_tasklets(language='CPP', with_data=False)
test_map_with_tasklets(language='CPP', with_data=True)
test_none_connector()

0 comments on commit 9a8279d

Please sign in to comment.