Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed a bug in the map fusion transformation. #1535

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions dace/transformation/dataflow/map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None)
for other_edge in out_edges:
if other_edge is not edge:
graph.remove_edge(other_edge)
mem = Memlet(data=local_name, other_subset=other_edge.data.dst_subset)
mem = Memlet(data=local_name, subset="0", other_subset=other_edge.data.dst_subset)
graph.add_edge(local_node, src_connector, other_edge.dst, other_edge.dst_conn, mem)
else:
local_node = edge.src
Expand All @@ -490,11 +490,17 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None)
else:
local_node_out = local_node
connector_out = src_connector
assert local_node_out != local_node # Ensures that no cycles are introduced by the call below.
graph.add_edge(local_node, src_connector, local_node_out, connector_out,
Memlet.from_array(local_name, sdfg.arrays[local_name]))
Memlet(data=local_name, subset='0'))
graph.add_edge(local_node_out, connector_out, new_dst, new_dst_conn, dcpy(edge.data))
for e in other_edges:
graph.add_edge(local_node_out, connector_out, e.dst, e.dst_conn, dcpy(edge.data))

for other_edge in other_edges:
if other_edge is not edge:
mem = dcpy(edge.data)
mem.subset = "0"
mem.data = local_name
graph.add_edge(local_node_out, connector_out, other_edge.dst, other_edge.dst_conn, mem)
else:
# Add edge that leads to the second node
graph.add_edge(local_node, src_connector, new_dst, new_dst_conn, dcpy(edge.data))
Expand Down
31 changes: 31 additions & 0 deletions tests/transformations/mapfusion_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import numpy
import numpy as np
import os
import dace
Expand Down Expand Up @@ -163,6 +164,35 @@ def test_fusion_with_transient():
assert np.allclose(A, expected)


# Program under test for the scalar-transient MapFusion case. It is left
# undecorated so the test can both call it as plain Python (to obtain a
# reference result) and wrap it with dace.program to build an SDFG.
def fusion_with_transient_scalar(
    A: dace.float32[1, 1],
    B: dace.float32[1],
):
    # Intermediate computed outside the explicit map below.
    # NOTE(review): presumably this becomes the transient that MapFusion has
    # to handle once the two maps are fused -- confirm against the SDFG.
    tmp1 = A * 2

    # Single-iteration 2-D map (i and j each range over 0:1). The body reads
    # the same element tmp1[i, j] three separate times.
    for i, j in dace.map[0:1, 0:1]:
        tmp2 = tmp1[i, j] + 4
        tmp3 = tmp1[i, j] * 5
        tmp4 = tmp1[i, j] + 1
        # The three partial results are summed into the only element of B.
        B[0] = tmp2 + tmp3 + tmp4


def test_fusion_with_transient_scalar():
    """Check that MapFusion keeps the result of the scalar-transient program.

    The undecorated Python function is executed first to produce a reference
    value for ``B``. The same function is then turned into an SDFG, simplified,
    fused with ``MapFusion``, compiled and run; its output must match the
    reference.
    """
    a_in = np.ones((1, 1), dtype=np.float32)
    b_out = np.zeros(1, dtype=np.float32)

    # Reference run: plain Python execution of the undecorated function.
    fusion_with_transient_scalar(a_in, b_out)
    reference = b_out.copy()
    # Poison the output buffer so a compiled program that never writes B
    # cannot pass the comparison below by accident.
    b_out.fill(np.nan)

    # Build and sanity-check the SDFG before applying the transformation.
    program = dace.program(fusion_with_transient_scalar)
    sdfg = program.to_sdfg()
    sdfg.simplify()
    assert sdfg.is_valid()

    sdfg.apply_transformations_repeated(MapFusion)
    compiled = sdfg.compile()
    compiled(A=a_in, B=b_out)
    assert np.allclose(b_out, reference)


def test_fusion_with_inverted_indices():

@dace.program
Expand Down Expand Up @@ -278,6 +308,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3
test_multiple_fusions()
test_fusion_chain()
test_fusion_with_transient()
test_fusion_with_transient_scalar()
test_fusion_with_inverted_indices()
test_fusion_with_empty_memlet()
test_fusion_with_nested_sdfg_0()
Expand Down
Loading