diff --git a/dace/transformation/dataflow/redundant_array.py b/dace/transformation/dataflow/redundant_array.py index 680936dc70..bb9a9dea8d 100644 --- a/dace/transformation/dataflow/redundant_array.py +++ b/dace/transformation/dataflow/redundant_array.py @@ -230,7 +230,7 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False): if not permissive: # Make sure the memlet covers the removed array - subset = copy.deepcopy(e1.data.subset) + subset = copy.deepcopy(a1_subset) subset.squeeze() shape = [sz for sz in in_desc.shape if sz != 1] if any(m != a for m, a in zip(subset.size(), shape)): @@ -552,6 +552,81 @@ def apply(self, graph, sdfg): self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop) return in_array + + # Special case of a reshaping Memelt, i.e. the Memlet between `in_array` and `out_array` + # performs reshaping. We only handle a special case in which the source has a single + # input. This is a fix for [issue 1595](https://github.com/spcl/dace/issues/1595). + # For a test see 'tests/transformations/redundant_copy_test.py::test_reshaping_with_redundant_arrays' + # as 'Case 1'. + # Furthermore, we require that `in_array` only has a single predecessor, which is an access node. + # This makes our life simpler, because we do not have to rename stuff, however, we should + # lift this restriction later. + in_full_read_out_full_written = all( + all(sssize == arraysize for sssize, arraysize in zip(subset.size(), shape)) + for subset, shape in zip([b_subset, a1_subset], [out_desc.shape, in_desc.shape]) + ) + in_out_array_access_nodes = all(isinstance(node, nodes.AccessNode) for node in (in_array, out_array)) + in_out_both_arrays = all( + isinstance(desc, (data.Scalar, data.Array)) and not isinstance(desc, data.View) + for desc in (in_desc, out_desc) + ) + is_reshaping_memlet = ( + in_full_read_out_full_written + and in_out_array_access_nodes + and in_out_both_arrays + and out_desc.shape != in_desc.shape + and out_desc.total_size == in_desc.total_size + and not e1.data.wcr + and not e1.data.wcr_nonatomic + ) + + # Not part of a reshaping Memlet but our particular implementation. + single_access_node_predecessor = ( + graph.in_degree(in_array) == 1 + and isinstance(graph.in_edges(in_array)[0].src, nodes.AccessNode) + ) + + if is_reshaping_memlet and single_access_node_predecessor: + in_edge = graph.in_edges(in_array)[0] + new_src = in_edge.src + new_src_conn = in_edge.src_conn + old_src_subset = in_edge.data.src_subset + new_dst = out_array + new_dst_conn = e1.dst_conn + old_dst_subset = b_subset + + if in_edge.data.data == in_array.data: + new_data = out_array.data + new_subset = old_dst_subset + new_other_subset = old_src_subset + else: + new_data = in_edge.data.data + new_subset = old_src_subset + new_other_subset = old_dst_subset + + graph.add_edge( + new_src, + new_src_conn, + new_dst, + new_dst_conn, + mm.Memlet( + data=new_data, + subset=new_subset, + other_subset=new_other_subset, + wcr=in_edge.data.wcr, + wcr_nonatomic=in_edge.data.wcr_nonatomic + ) + ) + + # Finally, remove in_array node + graph.remove_node(in_array) + try: + if in_array.data in sdfg.arrays: + sdfg.remove_data(in_array.data) + except ValueError: # Already in use (e.g., with Views) + pass + return + # 2. Iterate over the e2 edges and traverse the memlet tree for e2 in graph.in_edges(in_array): path = graph.memlet_tree(e2)