From 5ef99e05997c4295ec9991707968d727a6562e09 Mon Sep 17 00:00:00 2001 From: Lukas Truemper Date: Thu, 16 Nov 2023 22:36:43 +0100 Subject: [PATCH] TaskletFusion: Minor bugfix --- .../transformation/dataflow/tasklet_fusion.py | 2 +- tests/transformations/tasklet_fusion_test.py | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/tasklet_fusion.py b/dace/transformation/dataflow/tasklet_fusion.py index d6b4a3039b..29bb014263 100644 --- a/dace/transformation/dataflow/tasklet_fusion.py +++ b/dace/transformation/dataflow/tasklet_fusion.py @@ -267,5 +267,5 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): graph.remove_node(t1) if data is not None: graph.remove_node(data) - sdfg.remove_data(data.data, True) + graph.remove_node(t2) diff --git a/tests/transformations/tasklet_fusion_test.py b/tests/transformations/tasklet_fusion_test.py index 743010e8c9..59a7e8b36b 100644 --- a/tests/transformations/tasklet_fusion_test.py +++ b/tests/transformations/tasklet_fusion_test.py @@ -3,6 +3,7 @@ import dace from dace import dtypes from dace.transformation.dataflow import TaskletFusion, MapFusion +from dace.transformation.optimizer import Optimizer import pytest datatype = dace.float32 @@ -257,6 +258,33 @@ def sdfg_none_connector(A: dace.float32[32], B: dace.float32[32]): assert sdfg.start_state.out_degree(map_entry) == 1 assert len([edge.src_conn for edge in sdfg.start_state.out_edges(map_entry) if edge.src_conn is None]) == 0 + +def test_intermediate_transients(): + @dace.program + def sdfg_intermediate_transients(A: dace.float32[10], B: dace.float32[10]): + tmp = dace.define_local_scalar(dace.float32) + + # Use tmp twice to test removal of data + tmp = A[0] + 1 + tmp = tmp * 2 + B[0] = tmp + + + sdfg = sdfg_intermediate_transients.to_sdfg(simplify=True) + assert len([node for node in sdfg.start_state.data_nodes() if node.data == "tmp"]) == 2 + + xforms = Optimizer(sdfg=sdfg).get_pattern_matches(patterns=(TaskletFusion,)) + applied = False + for xform in xforms: + if xform.data.data == "tmp": + xform.apply(sdfg.start_state, sdfg) + applied = True + break + + assert applied + assert len([node for node in sdfg.start_state.data_nodes() if node.data == "tmp"]) == 1 + assert "tmp" in sdfg.arrays + if __name__ == '__main__': test_basic() test_same_name() @@ -268,3 +296,4 @@ def sdfg_none_connector(A: dace.float32[32], B: dace.float32[32]): test_map_with_tasklets(language='CPP', with_data=False) test_map_with_tasklets(language='CPP', with_data=True) test_none_connector() + test_intermediate_transients()