From b01f1605f9966709aa0431233bbe3ecb5ed4cf04 Mon Sep 17 00:00:00 2001
From: Lukas Truemper <lukas.truemper@outlook.de>
Date: Tue, 31 Oct 2023 09:32:04 +0100
Subject: [PATCH] OTFMapFusion: Bugfix for tasklets with None connectors

---
 .../transformation/dataflow/otf_map_fusion.py | 17 +++++-----
 tests/transformations/otf_map_fusion_test.py  | 31 +++++++++++++++++++
 2 files changed, 41 insertions(+), 7 deletions(-)

diff --git a/dace/transformation/dataflow/otf_map_fusion.py b/dace/transformation/dataflow/otf_map_fusion.py
index b2e5710942..f41e3b4e0b 100644
--- a/dace/transformation/dataflow/otf_map_fusion.py
+++ b/dace/transformation/dataflow/otf_map_fusion.py
@@ -289,14 +289,17 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
                     for edge in graph.edges_between(first_map_entry, node):
                         memlet = copy.deepcopy(edge.data)
 
-                        in_connector = edge.src_conn.replace("OUT", "IN")
-                        if in_connector in connector_mapping:
-                            out_connector = connector_mapping[in_connector].replace("IN", "OUT")
+                        if edge.src_conn is not None:
+                            in_connector = edge.src_conn.replace("OUT", "IN")
+                            if in_connector in connector_mapping:
+                                out_connector = connector_mapping[in_connector].replace("IN", "OUT")
+                            else:
+                                out_connector = edge.src_conn
+
+                            if out_connector not in self.second_map_entry.out_connectors:
+                                self.second_map_entry.add_out_connector(out_connector)
                         else:
-                            out_connector = edge.src_conn
-
-                        if out_connector not in self.second_map_entry.out_connectors:
-                            self.second_map_entry.add_out_connector(out_connector)
+                            out_connector = None
 
                         graph.add_edge(self.second_map_entry, out_connector, node, edge.dst_conn, memlet)
                         graph.remove_edge(edge)
diff --git a/tests/transformations/otf_map_fusion_test.py b/tests/transformations/otf_map_fusion_test.py
index eb871566d1..4786901887 100644
--- a/tests/transformations/otf_map_fusion_test.py
+++ b/tests/transformations/otf_map_fusion_test.py
@@ -330,6 +330,36 @@ def test_trivial_fusion_nested_sdfg():
     assert (res == res_fused).all()
 
 
+@dace.program
+def trivial_fusion_none_connectors(B: dace.float64[10, 20]):
+    tmp = dace.define_local([10, 20], dtype=B.dtype)
+    for i, j in dace.map[0:10, 0:20]:
+        with dace.tasklet:
+            b >> tmp[i, j]
+            b = 0
+
+    for i, j in dace.map[0:10, 0:20]:
+        with dace.tasklet:
+            a << tmp[i, j]
+            b >> B[i, j]
+            b = a + 2
+
+
+def test_trivial_fusion_none_connectors():
+    sdfg = trivial_fusion_none_connectors.to_sdfg()
+    sdfg.simplify()
+    assert count_maps(sdfg) == 2
+
+    sdfg.apply_transformations(OTFMapFusion)
+    assert count_maps(sdfg) == 1
+
+    B = np.zeros((10, 20))
+    ref = np.zeros((10, 20)) + 2
+
+    sdfg(B=B)
+    assert np.allclose(B, ref)
+
+
 @dace.program
 def undefined_subset(A: dace.float64[10], B: dace.float64[10]):
     tmp = dace.define_local([10], dtype=A.dtype)
@@ -703,6 +733,7 @@ def test_hdiff():
     test_trivial_fusion_permute()
     test_trivial_fusion_not_remove_map()
     test_trivial_fusion_nested_sdfg()
+    test_trivial_fusion_none_connectors()
 
     # Defined subsets
     test_undefined_subset()