Skip to content

Commit

Permalink
Test fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
phschaad committed Nov 3, 2023
1 parent 900f7b3 commit 47d1e68
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 27 deletions.
8 changes: 8 additions & 0 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2021,6 +2021,7 @@ def add_loop(
condition_expr: str,
increment_expr: str,
loop_end_state=None,
as_block=False,
):
"""
Helper function that adds a looping state machine around a
Expand Down Expand Up @@ -2050,6 +2051,8 @@ def add_loop(
state where the loop iteration ends.
If None, sets the end state to
``loop_state`` as well.
:param as_block: If True, add the loop as a separate loop block. Defaults to False, in which case the loop is
added as a traditional state machine loop.
:return: A 3-tuple of (``before_state``, generated loop guard state,
``after_state``).
"""
Expand Down Expand Up @@ -2150,6 +2153,7 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG':

# Importing these outside creates an import loop
from dace.codegen import codegen, compiler
from dace.sdfg import utils as sdutils

# Compute build folder path before running codegen
build_folder = self.build_folder
Expand All @@ -2170,6 +2174,10 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG':
# if the codegen modifies the SDFG (thereby changing its hash)
sdfg.build_folder = build_folder

# Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops.
# TODO (later): Adapt codegen to deal with hierarchical CFGs instead.
sdutils.inline_loop_blocks(sdfg)

# Rename SDFG to avoid runtime issues with clashing names
index = 0
while sdfg.is_loaded():
Expand Down
2 changes: 1 addition & 1 deletion dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
if not dtypes.validate_name(sdfg.name):
raise InvalidSDFGError("Invalid name", sdfg, None)

all_blocks = set(sdfg.all_control_flow_blocks_recursive())
all_blocks = set(sdfg.all_control_flow_blocks())
if len(all_blocks) != len(set([s.label for s in all_blocks])):
raise InvalidSDFGError('Found multiple blocks with the same name', sdfg, None)

Expand Down
105 changes: 84 additions & 21 deletions tests/sdfg/loop_region_test.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,100 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import numpy as np
from dace.sdfg.state import LoopRegion


def test_loop_regular_for():
sdfg = dace.SDFG('inlining')
sdfg = dace.SDFG('regular_for')
state0 = sdfg.add_state('state0', is_start_block=True)
loop1 = LoopRegion(label='loop1', condition_expr='i < 10', loop_var='i', initialize_expr='i = 0',
update_expr='i = i + 1', inverted=False)
sdfg.add_node(loop1)
sdfg.add_symbol('i', dace.int32)
sdfg.add_array('A', [10], dace.float32)
state1 = loop1.add_state('state1', is_start_block=True)
state2 = loop1.add_state('state2')
loop1.add_edge(state1, state2, dace.InterstateEdge())
acc_a = state1.add_access('A')
t1 = state1.add_tasklet('t1', None, {'a'}, 'a = i')
state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[i]'))
state3 = sdfg.add_state('state3')
sdfg.add_edge(state0, loop1, dace.InterstateEdge())
sdfg.add_edge(loop1, state3, dace.InterstateEdge())

assert sdfg.is_valid()

a_validation = np.zeros([10], dtype=np.float32)
a_test = np.zeros([10], dtype=np.float32)
sdfg(A=a_test)
for i in range(10):
a_validation[i] = i
assert np.allclose(a_validation, a_test)

def test_loop_inlining_regular_while():
sdfg = dace.SDFG('inlining')

def test_loop_regular_while():
sdfg = dace.SDFG('regular_while')
state0 = sdfg.add_state('state0', is_start_block=True)
loop1 = LoopRegion(label='loop1', condition_expr='i < 10')
sdfg.add_array('A', [10], dace.float32)
sdfg.add_node(loop1)
state1 = loop1.add_state('state1', is_start_block=True)
state2 = loop1.add_state('state2')
loop1.add_edge(state1, state2, dace.InterstateEdge())
acc_a = state1.add_access('A')
t1 = state1.add_tasklet('t1', None, {'a'}, 'a = i')
state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[i]'))
sdfg.add_symbol('i', dace.int32)
loop1.add_edge(state1, state2, dace.InterstateEdge(assignments={'i': 'i + 1'}))
state3 = sdfg.add_state('state3')
sdfg.add_edge(state0, loop1, dace.InterstateEdge())
sdfg.add_edge(state0, loop1, dace.InterstateEdge(assignments={'i': '0'}))
sdfg.add_edge(loop1, state3, dace.InterstateEdge())

assert sdfg.is_valid()

a_validation = np.zeros([10], dtype=np.float32)
a_test = np.zeros([10], dtype=np.float32)
sdfg(A=a_test)
for i in range(10):
a_validation[i] = i
assert np.allclose(a_validation, a_test)


def test_loop_inlining_do_while():
sdfg = dace.SDFG('inlining')
def test_loop_do_while():
sdfg = dace.SDFG('do_while')
sdfg.add_symbol('i', dace.int32)
state0 = sdfg.add_state('state0', is_start_block=True)
loop1 = LoopRegion(label='loop1', condition_expr='i < 10', inverted=True)
sdfg.add_node(loop1)
sdfg.add_array('A', [10], dace.float32)
state1 = loop1.add_state('state1', is_start_block=True)
state2 = loop1.add_state('state2')
loop1.add_edge(state1, state2, dace.InterstateEdge())
acc_a = state1.add_access('A')
t1 = state1.add_tasklet('t1', None, {'a'}, 'a = i')
state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[i]'))
loop1.add_edge(state1, state2, dace.InterstateEdge(assignments={'i': 'i + 1'}))
state3 = sdfg.add_state('state3')
sdfg.add_edge(state0, loop1, dace.InterstateEdge())
sdfg.add_edge(state0, loop1, dace.InterstateEdge(assignments={'i': '10'}))
sdfg.add_edge(loop1, state3, dace.InterstateEdge())

assert sdfg.is_valid()

a_validation = np.zeros([11], dtype=np.float32)
a_test = np.zeros([11], dtype=np.float32)
a_validation[10] = 10
sdfg(A=a_test)
assert np.allclose(a_validation, a_test)


def test_loop_inlining_do_for():
sdfg = dace.SDFG('inlining')
def test_loop_do_for():
sdfg = dace.SDFG('do_for')
sdfg.add_symbol('i', dace.int32)
sdfg.add_array('A', [10], dace.float32)
state0 = sdfg.add_state('state0', is_start_block=True)
loop1 = LoopRegion(label='loop1', condition_expr='i < 10', loop_var='i', initialize_expr='i = 0',
update_expr='i = i + 1', inverted=True)
sdfg.add_node(loop1)
state1 = loop1.add_state('state1', is_start_block=True)
acc_a = state1.add_access('A')
t1 = state1.add_tasklet('t1', None, {'a'}, 'a = i')
state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[i]'))
state2 = loop1.add_state('state2')
loop1.add_edge(state1, state2, dace.InterstateEdge())
state3 = sdfg.add_state('state3')
Expand All @@ -64,9 +103,19 @@ def test_loop_inlining_do_for():

assert sdfg.is_valid()

a_validation = np.zeros([10], dtype=np.float32)
a_test = np.zeros([10], dtype=np.float32)
sdfg(A=a_test)
for i in range(10):
a_validation[i] = i
assert np.allclose(a_validation, a_test)


def test_tripple_nested_for():
sdfg = dace.SDFG('gemm')
sdfg.add_symbol('i', dace.int32)
sdfg.add_symbol('j', dace.int32)
sdfg.add_symbol('k', dace.int32)
N = dace.symbol('N')
M = dace.symbol('M')
K = dace.symbol('K')
Expand All @@ -87,23 +136,37 @@ def test_tripple_nested_for():
anode = comp_state.add_access('A')
bnode = comp_state.add_access('B')
tmpnode = comp_state.add_access('tmp')
tasklet = comp_state.add_tasklet('comp', {'a', 'b'}, {'tmp'}, 'tmp = a * b')
tasklet = comp_state.add_tasklet('comp', {'a', 'b'}, {'t'}, 't = a * b')
comp_state.add_memlet_path(anode, tasklet, dst_conn='a', memlet=dace.Memlet.simple('A', 'i, k'))
comp_state.add_memlet_path(bnode, tasklet, dst_conn='b', memlet=dace.Memlet.simple('B', 'k, j'))
comp_state.add_memlet_path(tasklet, tmpnode, src_conn='tmp', memlet=dace.Memlet.simple('tmp', 'i, j, k'))
comp_state.add_memlet_path(tasklet, tmpnode, src_conn='t', memlet=dace.Memlet.simple('tmp', 'i, j, k'))

tmpnode2 = reduce_state.add_access('tmp')
cnode = reduce_state.add_access('C')
red = reduce_state.add_reduce('lambda a, b: a + b', axes=[2])
reduce_state.add_edge(tmpnode2, None, red, None, dace.Memlet.simple('tmp', 'i, j, k'))
reduce_state.add_edge(red, None, cnode, None, dace.Memlet.simple('C', 'i, j'))
red = reduce_state.add_reduce('lambda a, b: a + b', (2,), 0)
reduce_state.add_edge(tmpnode2, None, red, None, dace.Memlet.simple('tmp', '0:N, 0:M, 0:K'))
reduce_state.add_edge(red, None, cnode, None, dace.Memlet.simple('C', '0:N, 0:M'))

assert sdfg.is_valid()

N = 5
M = 10
K = 8
A = np.random.rand(N, K).astype(np.float32)
B = np.random.rand(K, M).astype(np.float32)
C_test = np.random.rand(N, M).astype(np.float32)
C_validation = np.random.rand(N, M).astype(np.float32)

C_validation = A @ B

sdfg(A=A, B=B, C=C_test, N=N, M=M, K=K)

assert np.allclose(C_validation, C_test)


if __name__ == '__main__':
test_loop_regular_for()
test_loop_inlining_regular_while()
test_loop_inlining_do_while()
test_loop_inlining_do_for()
test_loop_regular_while()
test_loop_do_while()
test_loop_do_for()
test_tripple_nested_for()
10 changes: 5 additions & 5 deletions tests/transformations/control_flow_inline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,16 @@ def test_inline_tripple_nested_for():
anode = comp_state.add_access('A')
bnode = comp_state.add_access('B')
tmpnode = comp_state.add_access('tmp')
tasklet = comp_state.add_tasklet('comp', {'a', 'b'}, {'tmp'}, 'tmp = a * b')
tasklet = comp_state.add_tasklet('comp', {'a', 'b'}, {'t'}, 't = a * b')
comp_state.add_memlet_path(anode, tasklet, dst_conn='a', memlet=dace.Memlet.simple('A', 'i, k'))
comp_state.add_memlet_path(bnode, tasklet, dst_conn='b', memlet=dace.Memlet.simple('B', 'k, j'))
comp_state.add_memlet_path(tasklet, tmpnode, src_conn='tmp', memlet=dace.Memlet.simple('tmp', 'i, j, k'))
comp_state.add_memlet_path(tasklet, tmpnode, src_conn='t', memlet=dace.Memlet.simple('tmp', 'i, j, k'))

tmpnode2 = reduce_state.add_access('tmp')
cnode = reduce_state.add_access('C')
red = reduce_state.add_reduce('lambda a, b: a + b', axes=[2])
reduce_state.add_edge(tmpnode2, None, red, None, dace.Memlet.simple('tmp', 'i, j, k'))
reduce_state.add_edge(red, None, cnode, None, dace.Memlet.simple('C', 'i, j'))
red = reduce_state.add_reduce('lambda a, b: a + b', (2,), 0)
reduce_state.add_edge(tmpnode2, None, red, None, dace.Memlet.simple('tmp', '0:N, 0:M, 0:K'))
reduce_state.add_edge(red, None, cnode, None, dace.Memlet.simple('C', '0:N, 0:M'))

sdutils.inline_loop_blocks(sdfg)

Expand Down

0 comments on commit 47d1e68

Please sign in to comment.