Skip to content

Commit

Permalink
TaskletFusion: Minor bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Nov 16, 2023
1 parent d345c59 commit 276348f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
4 changes: 3 additions & 1 deletion dace/transformation/dataflow/tasklet_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,5 +267,7 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
graph.remove_node(t1)
if data is not None:
graph.remove_node(data)
sdfg.remove_data(data.data, True)

graph.remove_node(t2)

sdfg.validate()
29 changes: 29 additions & 0 deletions tests/transformations/tasklet_fusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dace
from dace import dtypes
from dace.transformation.dataflow import TaskletFusion, MapFusion
from dace.transformation.optimizer import Optimizer
import pytest

datatype = dace.float32
Expand Down Expand Up @@ -257,6 +258,33 @@ def sdfg_none_connector(A: dace.float32[32], B: dace.float32[32]):
assert sdfg.start_state.out_degree(map_entry) == 1
assert len([edge.src_conn for edge in sdfg.start_state.out_edges(map_entry) if edge.src_conn is None]) == 0


def test_intermediate_transients():
@dace.program
def sdfg_intermediate_transients(A: dace.float32[10], B: dace.float32[10]):
tmp = dace.define_local_scalar(dace.float32)

# Use tmp twice to test removal of data
tmp = A[0] + 1
tmp = tmp * 2
B[0] = tmp


sdfg = sdfg_intermediate_transients.to_sdfg(simplify=True)
assert len([node for node in sdfg.start_state.data_nodes() if node.data == "tmp"]) == 2

xforms = Optimizer(sdfg=sdfg).get_pattern_matches(patterns=(TaskletFusion,))
applied = False
for xform in xforms:
if xform.data.data == "tmp":
xform.apply(sdfg.start_state, sdfg)
applied = True
break

assert applied
assert len([node for node in sdfg.start_state.data_nodes() if node.data == "tmp"]) == 1
assert "tmp" in sdfg.arrays

if __name__ == '__main__':
test_basic()
test_same_name()
Expand All @@ -268,3 +296,4 @@ def sdfg_none_connector(A: dace.float32[32], B: dace.float32[32]):
test_map_with_tasklets(language='CPP', with_data=False)
test_map_with_tasklets(language='CPP', with_data=True)
test_none_connector()
test_intermediate_transients()

0 comments on commit 276348f

Please sign in to comment.