Skip to content

Commit

Permalink
Fix cpp.reshape_strides(), which currently cannot handle inputs tha…
Browse files Browse the repository at this point in the history
…t tries to increase the number of dimensions after the reshape. (#1692)

Fix the issue with cpp codegen, where it currently cannot handle inputs:
```c++
cpp.reshape_strides(Range([(0, 4, 1), (0, 5, 1)]), None, None, [2, 3, 5])
```
and crashes with an index error.

Also fixes #1690 where `RedundantArray` was producing a (valid) graph
that triggered this.
  • Loading branch information
pratyai authored Nov 8, 2024
1 parent fb7f56c commit 2911a45
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 1 deletion.
3 changes: 2 additions & 1 deletion dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,10 @@ def reshape_strides(subset, strides, original_strides, copy_shape):
dims = len(copy_shape)

reduced_tile_sizes = [ts for ts, s in zip(subset.tile_sizes, original_copy_shape) if s != 1]
reduced_tile_sizes += [1] * (dims - len(reduced_tile_sizes)) # Pad the remainder with 1s to maintain dimensions.

reshaped_copy = copy_shape + [ts for ts in subset.tile_sizes if ts != 1]
reshaped_copy[:len(copy_shape)] = [s / ts for s, ts in zip(copy_shape, reduced_tile_sizes)]
reshaped_copy[:len(copy_shape)] = [s // ts for s, ts in zip(copy_shape, reduced_tile_sizes)]

new_strides = [0] * len(reshaped_copy)
elements_remaining = functools.reduce(sp.Mul, copy_shape, 1)
Expand Down
221 changes: 221 additions & 0 deletions tests/codegen/cpp_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.

from functools import reduce
from operator import mul
from typing import Dict, Collection

import dace
from dace import SDFG, Memlet
from dace.codegen.targets import cpp
from dace.sdfg.state import SDFGState
from dace.subsets import Range
from dace.transformation.dataflow import RedundantArray


def _add_map_with_connectors(st: SDFGState, name: str, ndrange: Dict[str, str],
en_conn_bases: Collection[str] = None, ex_conn_bases: Collection[str] = None):
en, ex = st.add_map(name, ndrange)
if en_conn_bases:
for c in en_conn_bases:
en.add_in_connector(f"IN_{c}")
en.add_out_connector(f"OUT_{c}")
if ex_conn_bases:
for c in ex_conn_bases:
ex.add_in_connector(f"IN_{c}")
ex.add_out_connector(f"OUT_{c}")
return en, ex


def test_reshape_strides_multidim_array_all_dims_unit():
r = Range([(0, 0, 1), (0, 0, 1)])

# To smaller-sized shape
target_dims = [1]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == [1]
assert strides == [1]

# To equal-sized shape
target_dims = [1, 1]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == [1, 1]
assert strides == [1, 1]

# To larger-sized shape
target_dims = [1, 1, 1]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == [1, 1, 1]
assert strides == [1, 1, 1]


def test_reshape_strides_multidim_array_some_dims_unit():
r = Range([(0, 1, 1), (0, 0, 1)])

# To smaller-sized shape
target_dims = [2]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [1]

# To equal-sized shape
target_dims = [2, 1]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [1, 1]
# To equal-sized shape, but units first.
target_dims = [1, 2]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [2, 1]

# To larger-sized shape.
target_dims = [2, 1, 1]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [1, 1, 1]
# To larger-sized shape, but units first.
target_dims = [1, 1, 2]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [2, 2, 1]


def test_reshape_strides_multidim_array_different_shape():
r = Range([(0, 4, 1), (0, 5, 1)])

# To smaller-sized shape
target_dims = [30]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [1]

# To equal-sized shape
target_dims = [15, 2]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [2, 1]

# To larger-sized shape
target_dims = [3, 5, 2]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [10, 2, 1]


def test_reshape_strides_from_strided_range():
r = Range([(0, 4, 2), (0, 6, 2)])

# To smaller-sized shape
target_dims = [12]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [1]

# To equal-sized shape
target_dims = [4, 3]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [3, 1]

# To larger-sized shape
target_dims = [2, 3, 2]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [6, 2, 1]


def test_reshape_strides_from_strided_and_offset_range():
r = Range([(10, 14, 2), (10, 16, 2)])

# To smaller-sized shape
target_dims = [12]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [1]

# To equal-sized shape
target_dims = [4, 3]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [3, 1]

# To larger-sized shape
target_dims = [2, 3, 2]
assert reduce(mul, r.size_exact()) == reduce(mul, target_dims)
reshaped, strides = cpp.reshape_strides(r, None, None, target_dims)
assert reshaped == target_dims
assert strides == [6, 2, 1]


def redundant_array_crashes_codegen_test_original_graph():
g = SDFG('prog')
g.add_array('A', (5, 5), dace.float32)
g.add_array('b', (1,), dace.float32, transient=True)
g.add_array('c', (5, 5), dace.float32, transient=True)

st0 = g.add_state('st0', is_start_block=True)
st = st0

# Make a single map that copies A[i, j] to a transient "scalar" b, then copies that out to a transient array
# c[i, j], then finally back to A[i, j] again.
A = st.add_access('A')
en, ex = _add_map_with_connectors(st, 'm0', {'i': '0:1', 'j': '0:1'}, ['A'], ['A'])
st.add_edge(A, None, en, 'IN_A', Memlet(expr='A[0:1, 0:1]'))
b = st.add_access('b')
st.add_edge(en, 'OUT_A', b, None, Memlet(expr='A[i, j] -> b[0]'))
c = st.add_access('c')
st.add_nedge(b, c, Memlet(expr='b[0] -> c[i, j]'))
st.add_edge(c, None, ex, 'IN_A', Memlet(expr='c[i, j] -> A[i, j]'))
A = st.add_access('A')
st.add_edge(ex, 'OUT_A', A, None, Memlet(expr='A[0:1, 0:1]'))
st0.fill_scope_connectors()

g.validate()
g.compile()
return g


def test_redundant_array_does_not_crash_codegen_but_produces_bad_graph_now():
"""
This test demonstrates the bug in CPP Codegen that the [PR](https://github.com/spcl/dace/pull/1692) fixes.
"""
g = redundant_array_crashes_codegen_test_original_graph()
g.apply_transformations(RedundantArray)
g.validate()
g.compile()

# NOTE: The produced graph still has bug. So, let's test for its existence.
assert len(g.states()) == 1
st = g.states()[0]
assert len(st.source_nodes()) == 1
src = st.source_nodes()[0]
assert len(st.out_edges(src)) == 1
e = st.out_edges(src)[0]
# This is the wrong part. These symbols are not available in this scope.
assert e.data.free_symbols == {'i', 'j'}


if __name__ == '__main__':
test_reshape_strides_multidim_array_all_dims_unit()
test_reshape_strides_multidim_array_some_dims_unit()
test_reshape_strides_multidim_array_different_shape()
test_reshape_strides_from_strided_range()
test_reshape_strides_from_strided_and_offset_range()

test_redundant_array_does_not_crash_codegen_but_produces_bad_graph_now()

0 comments on commit 2911a45

Please sign in to comment.