From dff7343352ddeaa2ad2902da6f7a288e4e6de485 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 26 Feb 2024 11:05:50 +0100 Subject: [PATCH] Fixed a bug in the map fusion transformation. Essentially the bug happens because the transformation considers any array with one elements as scalar, including arrays such as `(1, 1)`. This commit allows the scalar branch of the transformation to handle this kind of situation. Another solution would be to "promote" it to an array. In addition this commit adds some test cases for these kind of events. --- dace/transformation/dataflow/map_fusion.py | 14 +++++++--- tests/transformations/mapfusion_test.py | 31 ++++++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 9a0dd0e313..3a143467af 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -475,7 +475,7 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None) for other_edge in out_edges: if other_edge is not edge: graph.remove_edge(other_edge) - mem = Memlet(data=local_name, other_subset=other_edge.data.dst_subset) + mem = Memlet(data=local_name, subset="0", other_subset=other_edge.data.dst_subset) graph.add_edge(local_node, src_connector, other_edge.dst, other_edge.dst_conn, mem) else: local_node = edge.src @@ -490,11 +490,17 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None) else: local_node_out = local_node connector_out = src_connector + assert local_node_out != local_node # Ensures that no cycles are introduced by the call below. graph.add_edge(local_node, src_connector, local_node_out, connector_out, - Memlet.from_array(local_name, sdfg.arrays[local_name])) + Memlet(data=local_name, subset='0')) graph.add_edge(local_node_out, connector_out, new_dst, new_dst_conn, dcpy(edge.data)) - for e in other_edges: - graph.add_edge(local_node_out, connector_out, e.dst, e.dst_conn, dcpy(edge.data)) + + for other_edge in other_edges: + if other_edge is not edge: + mem = dcpy(edge.data) + mem.subset = "0" + mem.data = local_name + graph.add_edge(local_node_out, connector_out, other_edge.dst, other_edge.dst_conn, mem) else: # Add edge that leads to the second node graph.add_edge(local_node, src_connector, new_dst, new_dst_conn, dcpy(edge.data)) diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 653fb9d120..5f3e15d46f 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -1,4 +1,5 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +import numpy import numpy as np import os import dace @@ -163,6 +164,35 @@ def test_fusion_with_transient(): assert np.allclose(A, expected) +def fusion_with_transient_scalar( + A: dace.float32[1, 1], + B: dace.float32[1], +): + tmp1 = A * 2 + + for i, j in dace.map[0:1, 0:1]: + tmp2 = tmp1[i, j] + 4 + tmp3 = tmp1[i, j] * 5 + tmp4 = tmp1[i, j] + 1 + B[0] = tmp2 + tmp3 + tmp4 + + +def test_fusion_with_transient_scalar(): + A = np.ones((1, 1)).astype(np.float32) + B = np.zeros(1).astype(np.float32) + fusion_with_transient_scalar_prog = dace.program(fusion_with_transient_scalar) + fusion_with_transient_scalar(A, B) + sdfg = fusion_with_transient_scalar_prog.to_sdfg() + sdfg.simplify() + assert sdfg.is_valid() + expected_B = B.copy() + B[:] = np.nan + sdfg.apply_transformations_repeated(MapFusion) + csdfg = sdfg.compile() + csdfg(A=A, B=B) + assert np.allclose(B, expected_B) + + def test_fusion_with_inverted_indices(): @dace.program @@ -278,6 +308,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 test_multiple_fusions() test_fusion_chain() test_fusion_with_transient() + test_fusion_with_transient_scalar() test_fusion_with_inverted_indices() test_fusion_with_empty_memlet() test_fusion_with_nested_sdfg_0()