diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index fdf8835c7e..ceda9a4d5a 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -2021,6 +2021,7 @@ def add_loop( condition_expr: str, increment_expr: str, loop_end_state=None, + as_block=False, ): """ Helper function that adds a looping state machine around a @@ -2050,6 +2051,8 @@ def add_loop( state where the loop iteration ends. If None, sets the end state to ``loop_state`` as well. + :param as_block: Add the loop as a separate loop block. False by default, in which case the loop is added + as a traditional state machine loop. :return: A 3-tuple of (``before_state``, generated loop guard state, ``after_state``). """ @@ -2150,6 +2153,7 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG': # Importing these outside creates an import loop from dace.codegen import codegen, compiler + from dace.sdfg import utils as sdutils # Compute build folder path before running codegen build_folder = self.build_folder @@ -2170,6 +2174,10 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG': # if the codegen modifies the SDFG (thereby changing its hash) sdfg.build_folder = build_folder + # Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops. + # TODO (later): Adapt codegen to deal with hierarchical CFGs instead. + sdutils.inline_loop_blocks(sdfg) + # Rename SDFG to avoid runtime issues with clashing names index = 0 while sdfg.is_loaded(): diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 3e402f3e25..45d38e33e2 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -201,7 +201,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context if not dtypes.validate_name(sdfg.name): raise InvalidSDFGError("Invalid name", sdfg, None) - all_blocks = set(sdfg.all_control_flow_blocks_recursive()) + all_blocks = set(sdfg.all_control_flow_blocks()) if len(all_blocks) != len(set([s.label for s in all_blocks])): raise InvalidSDFGError('Found multiple blocks with the same name', sdfg, None) diff --git a/tests/sdfg/loop_region_test.py b/tests/sdfg/loop_region_test.py index f17fb7fcd6..bf09de21f4 100644 --- a/tests/sdfg/loop_region_test.py +++ b/tests/sdfg/loop_region_test.py @@ -1,61 +1,100 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import dace +import numpy as np from dace.sdfg.state import LoopRegion def test_loop_regular_for(): - sdfg = dace.SDFG('inlining') + sdfg = dace.SDFG('regular_for') state0 = sdfg.add_state('state0', is_start_block=True) loop1 = LoopRegion(label='loop1', condition_expr='i < 10', loop_var='i', initialize_expr='i = 0', update_expr='i = i + 1', inverted=False) sdfg.add_node(loop1) + sdfg.add_symbol('i', dace.int32) + sdfg.add_array('A', [10], dace.float32) state1 = loop1.add_state('state1', is_start_block=True) - state2 = loop1.add_state('state2') - loop1.add_edge(state1, state2, dace.InterstateEdge()) + acc_a = state1.add_access('A') + t1 = state1.add_tasklet('t1', None, {'a'}, 'a = i') + state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[i]')) state3 = sdfg.add_state('state3') sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) assert sdfg.is_valid() + a_validation = np.zeros([10], dtype=np.float32) + a_test = np.zeros([10], dtype=np.float32) + sdfg(A=a_test) + for i in range(10): + a_validation[i] = i + assert np.allclose(a_validation, a_test) -def test_loop_inlining_regular_while(): - sdfg = dace.SDFG('inlining') + +def test_loop_regular_while(): + sdfg = dace.SDFG('regular_while') state0 = sdfg.add_state('state0', is_start_block=True) loop1 = LoopRegion(label='loop1', condition_expr='i < 10') + sdfg.add_array('A', [10], dace.float32) sdfg.add_node(loop1) state1 = loop1.add_state('state1', is_start_block=True) state2 = loop1.add_state('state2') - loop1.add_edge(state1, state2, dace.InterstateEdge()) + acc_a = state1.add_access('A') + t1 = state1.add_tasklet('t1', None, {'a'}, 'a = i') + state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[i]')) + sdfg.add_symbol('i', dace.int32) + loop1.add_edge(state1, state2, dace.InterstateEdge(assignments={'i': 'i + 1'})) state3 = sdfg.add_state('state3') - sdfg.add_edge(state0, loop1, dace.InterstateEdge()) + sdfg.add_edge(state0, loop1, dace.InterstateEdge(assignments={'i': '0'})) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) assert sdfg.is_valid() + a_validation = np.zeros([10], dtype=np.float32) + a_test = np.zeros([10], dtype=np.float32) + sdfg(A=a_test) + for i in range(10): + a_validation[i] = i + assert np.allclose(a_validation, a_test) + -def test_loop_inlining_do_while(): - sdfg = dace.SDFG('inlining') +def test_loop_do_while(): + sdfg = dace.SDFG('do_while') + sdfg.add_symbol('i', dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) loop1 = LoopRegion(label='loop1', condition_expr='i < 10', inverted=True) sdfg.add_node(loop1) + sdfg.add_array('A', [10], dace.float32) state1 = loop1.add_state('state1', is_start_block=True) state2 = loop1.add_state('state2') - loop1.add_edge(state1, state2, dace.InterstateEdge()) + acc_a = state1.add_access('A') + t1 = state1.add_tasklet('t1', None, {'a'}, 'a = i') + state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[i]')) + loop1.add_edge(state1, state2, dace.InterstateEdge(assignments={'i': 'i + 1'})) state3 = sdfg.add_state('state3') - sdfg.add_edge(state0, loop1, dace.InterstateEdge()) + sdfg.add_edge(state0, loop1, dace.InterstateEdge(assignments={'i': '10'})) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) assert sdfg.is_valid() + a_validation = np.zeros([11], dtype=np.float32) + a_test = np.zeros([11], dtype=np.float32) + a_validation[10] = 10 + sdfg(A=a_test) + assert np.allclose(a_validation, a_test) + -def test_loop_inlining_do_for(): - sdfg = dace.SDFG('inlining') +def test_loop_do_for(): + sdfg = dace.SDFG('do_for') + sdfg.add_symbol('i', dace.int32) + sdfg.add_array('A', [10], dace.float32) state0 = sdfg.add_state('state0', is_start_block=True) loop1 = LoopRegion(label='loop1', condition_expr='i < 10', loop_var='i', initialize_expr='i = 0', update_expr='i = i + 1', inverted=True) sdfg.add_node(loop1) state1 = loop1.add_state('state1', is_start_block=True) + acc_a = state1.add_access('A') + t1 = state1.add_tasklet('t1', None, {'a'}, 'a = i') + state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[i]')) state2 = loop1.add_state('state2') loop1.add_edge(state1, state2, dace.InterstateEdge()) state3 = sdfg.add_state('state3') @@ -64,9 +103,19 @@ def test_loop_inlining_do_for(): assert sdfg.is_valid() + a_validation = np.zeros([10], dtype=np.float32) + a_test = np.zeros([10], dtype=np.float32) + sdfg(A=a_test) + for i in range(10): + a_validation[i] = i + assert np.allclose(a_validation, a_test) + def test_tripple_nested_for(): sdfg = dace.SDFG('gemm') + sdfg.add_symbol('i', dace.int32) + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) N = dace.symbol('N') M = dace.symbol('M') K = dace.symbol('K') @@ -87,23 +136,37 @@ def test_tripple_nested_for(): anode = comp_state.add_access('A') bnode = comp_state.add_access('B') tmpnode = comp_state.add_access('tmp') - tasklet = comp_state.add_tasklet('comp', {'a', 'b'}, {'tmp'}, 'tmp = a * b') + tasklet = comp_state.add_tasklet('comp', {'a', 'b'}, {'t'}, 't = a * b') comp_state.add_memlet_path(anode, tasklet, dst_conn='a', memlet=dace.Memlet.simple('A', 'i, k')) comp_state.add_memlet_path(bnode, tasklet, dst_conn='b', memlet=dace.Memlet.simple('B', 'k, j')) - comp_state.add_memlet_path(tasklet, tmpnode, src_conn='tmp', memlet=dace.Memlet.simple('tmp', 'i, j, k')) + comp_state.add_memlet_path(tasklet, tmpnode, src_conn='t', memlet=dace.Memlet.simple('tmp', 'i, j, k')) tmpnode2 = reduce_state.add_access('tmp') cnode = reduce_state.add_access('C') - red = reduce_state.add_reduce('lambda a, b: a + b', axes=[2]) - reduce_state.add_edge(tmpnode2, None, red, None, dace.Memlet.simple('tmp', 'i, j, k')) - reduce_state.add_edge(red, None, cnode, None, dace.Memlet.simple('C', 'i, j')) + red = reduce_state.add_reduce('lambda a, b: a + b', (2,), 0) + reduce_state.add_edge(tmpnode2, None, red, None, dace.Memlet.simple('tmp', '0:N, 0:M, 0:K')) + reduce_state.add_edge(red, None, cnode, None, dace.Memlet.simple('C', '0:N, 0:M')) assert sdfg.is_valid() + N = 5 + M = 10 + K = 8 + A = np.random.rand(N, K).astype(np.float32) + B = np.random.rand(K, M).astype(np.float32) + C_test = np.random.rand(N, M).astype(np.float32) + C_validation = np.random.rand(N, M).astype(np.float32) + + C_validation = A @ B + + sdfg(A=A, B=B, C=C_test, N=N, M=M, K=K) + + assert np.allclose(C_validation, C_test) + if __name__ == '__main__': test_loop_regular_for() - test_loop_inlining_regular_while() - test_loop_inlining_do_while() - test_loop_inlining_do_for() + test_loop_regular_while() + test_loop_do_while() + test_loop_do_for() test_tripple_nested_for() diff --git a/tests/transformations/control_flow_inline_test.py b/tests/transformations/control_flow_inline_test.py index 1900445ca8..ea44a8baef 100644 --- a/tests/transformations/control_flow_inline_test.py +++ b/tests/transformations/control_flow_inline_test.py @@ -164,16 +164,16 @@ def test_inline_tripple_nested_for(): anode = comp_state.add_access('A') bnode = comp_state.add_access('B') tmpnode = comp_state.add_access('tmp') - tasklet = comp_state.add_tasklet('comp', {'a', 'b'}, {'tmp'}, 'tmp = a * b') + tasklet = comp_state.add_tasklet('comp', {'a', 'b'}, {'t'}, 't = a * b') comp_state.add_memlet_path(anode, tasklet, dst_conn='a', memlet=dace.Memlet.simple('A', 'i, k')) comp_state.add_memlet_path(bnode, tasklet, dst_conn='b', memlet=dace.Memlet.simple('B', 'k, j')) - comp_state.add_memlet_path(tasklet, tmpnode, src_conn='tmp', memlet=dace.Memlet.simple('tmp', 'i, j, k')) + comp_state.add_memlet_path(tasklet, tmpnode, src_conn='t', memlet=dace.Memlet.simple('tmp', 'i, j, k')) tmpnode2 = reduce_state.add_access('tmp') cnode = reduce_state.add_access('C') - red = reduce_state.add_reduce('lambda a, b: a + b', axes=[2]) - reduce_state.add_edge(tmpnode2, None, red, None, dace.Memlet.simple('tmp', 'i, j, k')) - reduce_state.add_edge(red, None, cnode, None, dace.Memlet.simple('C', 'i, j')) + red = reduce_state.add_reduce('lambda a, b: a + b', (2,), 0) + reduce_state.add_edge(tmpnode2, None, red, None, dace.Memlet.simple('tmp', '0:N, 0:M, 0:K')) + reduce_state.add_edge(red, None, cnode, None, dace.Memlet.simple('C', '0:N, 0:M')) sdutils.inline_loop_blocks(sdfg)