Skip to content

Commit

Permalink
This _should_ fix the issue that is reported in #1595.
Browse files Browse the repository at this point in the history
It essentially creates a special case for it and applies then the correct way.
This is not very good but it works and I have no better solution.
  • Loading branch information
philip-paul-mueller committed Jun 19, 2024
1 parent 3f11fea commit 742568d
Showing 1 changed file with 76 additions and 1 deletion.
77 changes: 76 additions & 1 deletion dace/transformation/dataflow/redundant_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False):

if not permissive:
# Make sure the memlet covers the removed array
subset = copy.deepcopy(e1.data.subset)
subset = copy.deepcopy(a1_subset)
subset.squeeze()
shape = [sz for sz in in_desc.shape if sz != 1]
if any(m != a for m, a in zip(subset.size(), shape)):
Expand Down Expand Up @@ -552,6 +552,81 @@ def apply(self, graph, sdfg):
self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop)
return in_array


# Special case of a reshaping Memelt, i.e. the Memlet between `in_array` and `out_array`
# performs reshaping. We only handle a special case in which the source has a single
# input. This is a fix for [issue 1595](https://github.com/spcl/dace/issues/1595).
# For a test see 'tests/transformations/redundant_copy_test.py::test_reshaping_with_redundant_arrays'
# as 'Case 1'.
# Furthermore, we require that `in_array` only has a single predecessor, which is an access node.
# This makes our life simpler, because we do not have to rename stuff, however, we should
# lift this restriction later.
in_full_read_out_full_written = all(
all(sssize == arraysize for sssize, arraysize in zip(subset.size(), shape))
for subset, shape in zip([b_subset, a1_subset], [out_desc.shape, in_desc.shape])
)
in_out_array_access_nodes = all(isinstance(node, nodes.AccessNode) for node in (in_array, out_array))
in_out_both_arrays = all(
isinstance(desc, (data.Scalar, data.Array)) and not isinstance(desc, data.View)
for desc in (in_desc, out_desc)
)
is_reshaping_memlet = (
in_full_read_out_full_written
and in_out_array_access_nodes
and in_out_both_arrays
and out_desc.shape != in_desc.shape
and out_desc.total_size == in_desc.total_size
and not e1.data.wcr
and not e1.data.wcr_nonatomic
)

# Not part of a reshaping Memlet but our particular implementation.
single_access_node_predecessor = (
graph.in_degree(in_array) == 1
and isinstance(graph.in_edges(in_array)[0].src, nodes.AccessNode)
)

if is_reshaping_memlet and single_access_node_predecessor:
in_edge = graph.in_edges(in_array)[0]
new_src = in_edge.src
new_src_conn = in_edge.src_conn
old_src_subset = in_edge.data.src_subset
new_dst = out_array
new_dst_conn = e1.dst_conn
old_dst_subset = b_subset

if in_edge.data.data == in_array.data:
new_data = out_array.data
new_subset = old_dst_subset
new_other_subset = old_src_subset
else:
new_data = in_edge.data.data
new_subset = old_src_subset
new_other_subset = old_dst_subset

graph.add_edge(
new_src,
new_src_conn,
new_dst,
new_dst_conn,
mm.Memlet(
data=new_data,
subset=new_subset,
other_subset=new_other_subset,
wcr=in_edge.data.wcr,
wcr_nonatomic=in_edge.data.wcr_nonatomic
)
)

# Finally, remove in_array node
graph.remove_node(in_array)
try:
if in_array.data in sdfg.arrays:
sdfg.remove_data(in_array.data)
except ValueError: # Already in use (e.g., with Views)
pass
return

# 2. Iterate over the e2 edges and traverse the memlet tree
for e2 in graph.in_edges(in_array):
path = graph.memlet_tree(e2)
Expand Down

0 comments on commit 742568d

Please sign in to comment.