diff --git a/dace/transformation/dataflow/otf_map_fusion.py b/dace/transformation/dataflow/otf_map_fusion.py index b2e5710942..f41e3b4e0b 100644 --- a/dace/transformation/dataflow/otf_map_fusion.py +++ b/dace/transformation/dataflow/otf_map_fusion.py @@ -289,14 +289,17 @@ def apply(self, graph: SDFGState, sdfg: SDFG): for edge in graph.edges_between(first_map_entry, node): memlet = copy.deepcopy(edge.data) - in_connector = edge.src_conn.replace("OUT", "IN") - if in_connector in connector_mapping: - out_connector = connector_mapping[in_connector].replace("IN", "OUT") + if edge.src_conn is not None: + in_connector = edge.src_conn.replace("OUT", "IN") + if in_connector in connector_mapping: + out_connector = connector_mapping[in_connector].replace("IN", "OUT") + else: + out_connector = edge.src_conn + + if out_connector not in self.second_map_entry.out_connectors: + self.second_map_entry.add_out_connector(out_connector) else: - out_connector = edge.src_conn - - if out_connector not in self.second_map_entry.out_connectors: - self.second_map_entry.add_out_connector(out_connector) + out_connector = None graph.add_edge(self.second_map_entry, out_connector, node, edge.dst_conn, memlet) graph.remove_edge(edge) diff --git a/tests/transformations/otf_map_fusion_test.py b/tests/transformations/otf_map_fusion_test.py index eb871566d1..4786901887 100644 --- a/tests/transformations/otf_map_fusion_test.py +++ b/tests/transformations/otf_map_fusion_test.py @@ -330,6 +330,36 @@ def test_trivial_fusion_nested_sdfg(): assert (res == res_fused).all() +@dace.program +def trivial_fusion_none_connectors(B: dace.float64[10, 20]): + tmp = dace.define_local([10, 20], dtype=B.dtype) + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + b >> tmp[i, j] + b = 0 + + for i, j in dace.map[0:10, 0:20]: + with dace.tasklet: + a << tmp[i, j] + b >> B[i, j] + b = a + 2 + + +def test_trivial_fusion_none_connectors(): + sdfg = trivial_fusion_none_connectors.to_sdfg() + sdfg.simplify() + assert count_maps(sdfg) == 2 + + sdfg.apply_transformations(OTFMapFusion) + assert count_maps(sdfg) == 1 + + B = np.zeros((10, 20)) + ref = np.zeros((10, 20)) + 2 + + sdfg(B=B) + assert np.allclose(B, ref) + + @dace.program def undefined_subset(A: dace.float64[10], B: dace.float64[10]): tmp = dace.define_local([10], dtype=A.dtype) @@ -703,6 +733,7 @@ def test_hdiff(): test_trivial_fusion_permute() test_trivial_fusion_not_remove_map() test_trivial_fusion_nested_sdfg() + test_trivial_fusion_none_connectors() # Defined subsets test_undefined_subset()