Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loops as primary SDFG objects #1354

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2299,10 +2299,10 @@ def visit_For(self, node: ast.For):
# Add loop to SDFG
loop_cond = '>' if ((pystr_to_symbolic(ranges[0][2]) < 0) == True) else '<'
incr = {indices[0]: '%s + %s' % (indices[0], astutils.unparse(ast_ranges[0][2]))}
_, loop_guard, loop_end = self.sdfg.add_loop(
_, loop_guard, loop_end, sdfg_loop = self.sdfg.add_loop(
laststate, first_loop_state, end_loop_state, indices[0], astutils.unparse(ast_ranges[0][0]),
'%s %s %s' % (indices[0], loop_cond, astutils.unparse(ast_ranges[0][1])), incr[indices[0]],
last_loop_state)
last_loop_state, False)

# Handle else clause
if node.orelse:
Expand All @@ -2324,19 +2324,26 @@ def visit_For(self, node: ast.For):
out_edges = self.sdfg.out_edges(next_state)
for e in out_edges:
self.sdfg.remove_edge(e)
self.sdfg.add_edge(next_state, loop_guard, dace.InterstateEdge(assignments=incr))
continue_edge = dace.InterstateEdge(assignments=incr)
sdfg_loop.continue_edges.append(continue_edge)
self.sdfg.add_edge(next_state, loop_guard, continue_edge)
break_states = self.break_states.pop()
while break_states:
next_state = break_states.pop()
out_edges = self.sdfg.out_edges(next_state)
for e in out_edges:
self.sdfg.remove_edge(e)
self.sdfg.add_edge(next_state, loop_end, dace.InterstateEdge())
break_edge = dace.InterstateEdge()
sdfg_loop.break_edges.append(break_edge)
self.sdfg.add_edge(next_state, loop_end, break_edge)
self.loop_idx -= 1

for state in body_states:
if not nx.has_path(self.sdfg.nx, loop_guard, state):
self.sdfg.remove_node(state)
else:
sdfg_loop.states.append(state)
sdfg_loop.states.append(loop_guard)
else:
raise DaceSyntaxError(self, node, 'Unsupported for-loop iterator "%s"' % iterator)

Expand Down Expand Up @@ -2407,8 +2414,8 @@ def visit_While(self, node: ast.While):
self.sdfg.add_symbol(astr, atom.dtype)

# Add loop to SDFG
_, loop_guard, loop_end = self.sdfg.add_loop(laststate, first_loop_state, end_loop_state, None, None, loop_cond,
None, last_loop_state)
_, loop_guard, loop_end, sdfg_loop = self.sdfg.add_loop(laststate, first_loop_state, end_loop_state, None, None,
loop_cond, None, last_loop_state, False)

# Connect the correct while-guard state
# Current state:
Expand Down Expand Up @@ -2442,19 +2449,26 @@ def visit_While(self, node: ast.While):
out_edges = self.sdfg.out_edges(next_state)
for e in out_edges:
self.sdfg.remove_edge(e)
self.sdfg.add_edge(next_state, begin_guard, dace.InterstateEdge())
continue_edge = dace.InterstateEdge()
sdfg_loop.continue_edges.append(continue_edge)
self.sdfg.add_edge(next_state, begin_guard, continue_edge)
break_states = self.break_states.pop()
while break_states:
next_state = break_states.pop()
out_edges = self.sdfg.out_edges(next_state)
for e in out_edges:
self.sdfg.remove_edge(e)
self.sdfg.add_edge(next_state, loop_end, dace.InterstateEdge())
break_edge = dace.InterstateEdge()
sdfg_loop.break_edges.append(break_edge)
self.sdfg.add_edge(next_state, loop_end, break_edge)
self.loop_idx -= 1

for state in body_states:
if not nx.has_path(self.sdfg.nx, end_guard, state):
self.sdfg.remove_node(state)
else:
sdfg_loop.states.append(state)
sdfg_loop.states.append(begin_guard)

def visit_Break(self, node: ast.Break):
if self.loop_idx < 0:
Expand Down
2 changes: 1 addition & 1 deletion dace/sdfg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
from dace.sdfg.sdfg import SDFG, InterstateEdge, LogicalGroup
from dace.sdfg.sdfg import SDFG, InterstateEdge, LogicalGroup, SDFGLoop

from dace.sdfg.state import SDFGState

Expand Down
78 changes: 67 additions & 11 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
from dace.distr_types import ProcessGrid, SubArray, RedistrArray
from dace.dtypes import validate_name
from dace.properties import (DebugInfoProperty, EnumProperty, ListProperty, make_properties, Property, CodeProperty,
TransformationHistProperty, OptionalSDFGReferenceProperty, DictProperty, CodeBlock)
TransformationHistProperty, OptionalSDFGReferenceProperty, DictProperty, CodeBlock,
SetProperty)
from typing import BinaryIO

# NOTE: In shapes, we try to convert strings to integers. In ranks, a string should be interpreted as data (scalar).
Expand Down Expand Up @@ -360,6 +361,37 @@ def label(self):
return self.condition.as_string + '; ' + assignments


@make_properties
class SDFGLoop(object):

init_statement = CodeProperty(optional=True, allow_none=True, default=None)
loop_condition = CodeProperty(allow_none=True, default=None)
update_statement = CodeProperty(optional=True, allow_none=True, default=None)
states = ListProperty(SDFGState, desc='States in the loop')

guard_state = Property(dtype=SDFGState, allow_none=True, default=None)
inverted = Property(dtype=bool, default=False, desc='Whether the loop is inverted (do-while style)')

continue_edges = ListProperty(element_type=InterstateEdge)
break_edges = ListProperty(element_type=InterstateEdge)
#return_edges = ListProperty(element_type=InterstateEdge)

init_edge = Property(dtype=InterstateEdge)
update_edge = Property(dtype=InterstateEdge)
loop_edge = Property(dtype=InterstateEdge)
exit_edge = Property(dtype=InterstateEdge)

def __init__(self):
self.init_statement = None
self.loop_condition = CodeBlock('True')
self.update_statement = None
self.states = []
self.guard_state = None
self.continue_edges = []
self.break_edges = []
#self.return_edges = []


@make_properties
class SDFG(OrderedDiGraph[SDFGState, InterstateEdge]):
""" The main intermediate representation of code in DaCe.
Expand All @@ -384,6 +416,7 @@ class SDFG(OrderedDiGraph[SDFGState, InterstateEdge]):
to_json=_arrays_to_json,
from_json=_arrays_from_json)
symbols = DictProperty(str, dtypes.typeclass, desc="Global symbols for this SDFG")
sdfg_loops = ListProperty(element_type=SDFGLoop)

instrument = EnumProperty(dtype=dtypes.InstrumentationType,
desc="Measure execution statistics with given method",
Expand Down Expand Up @@ -455,6 +488,7 @@ def __init__(self,
self._propagate = propagate
self._parent = parent
self.symbols = {}
self.sdfg_loops = []
self._parent_sdfg = None
self._parent_nsdfg_node = None
self._sdfg_list = [self]
Expand Down Expand Up @@ -2111,7 +2145,8 @@ def add_loop(
condition_expr: str,
increment_expr: str,
loop_end_state=None,
):
inverted=False,
) -> Tuple[SDFGState, SDFGState, SDFGState, 'SDFGLoop']:
"""
Helper function that adds a looping state machine around a
given state (or sequence of states).
Expand Down Expand Up @@ -2140,15 +2175,18 @@ def add_loop(
state where the loop iteration ends.
If None, sets the end state to
``loop_state`` as well.
:return: A 3-tuple of (``before_state``, generated loop guard state,
``after_state``).
:param inverted: If this loop is inverted, i.e., a do-while style.
:return: A 4-tuple of (``before_state``, generated loop guard state,
``after_state``, resulting ``SDFGLoop``).
"""
from dace.frontend.python.astutils import negate_expr # Avoid import loops

# Argument checks
if loop_var is None and (initialize_expr or increment_expr):
raise ValueError("Cannot initalize or increment an empty loop variable")

loop = SDFGLoop()

# Handling empty states
if loop_end_state is None:
loop_end_state = loop_state
Expand All @@ -2162,21 +2200,39 @@ def add_loop(

# Loop initialization
init = None if initialize_expr is None else {loop_var: initialize_expr}
self.add_edge(before_state, guard, InterstateEdge(assignments=init))
init_edge = InterstateEdge(assignments=init)
self.add_edge(before_state, guard, init_edge)

loop.init_edge = init_edge
loop.init_statement = CodeBlock('%s = %s' % (loop_var, initialize_expr)) if initialize_expr else None

# Loop condition
if condition_expr:
cond_ast = CodeBlock(condition_expr).code
cond_code_block = CodeBlock(condition_expr)
cond_ast = cond_code_block.code
else:
cond_ast = CodeBlock('True').code
self.add_edge(guard, loop_state, InterstateEdge(cond_ast))
self.add_edge(guard, after_state, InterstateEdge(negate_expr(cond_ast)))
cond_code_block = CodeBlock('True')
cond_ast = cond_code_block.code
loop.loop_condition = cond_code_block
loop_edge = InterstateEdge(cond_ast)
exit_edge = InterstateEdge(negate_expr(cond_ast))
self.add_edge(guard, loop_state, loop_edge)
self.add_edge(guard, after_state, exit_edge)
loop.loop_edge = loop_edge
loop.exit_edge = exit_edge

# Loop incrementation
incr = None if increment_expr is None else {loop_var: increment_expr}
self.add_edge(loop_end_state, guard, InterstateEdge(assignments=incr))
incr_edge = InterstateEdge(assignments=incr)
self.add_edge(loop_end_state, guard, incr_edge)
loop.update_statement = CodeBlock('%s = %s' % (loop_var, increment_expr)) if increment_expr else None
loop.update_edge = incr_edge

loop.guard_state = guard
loop.inverted = inverted

return before_state, guard, after_state
self.sdfg_loops.append(loop)
return before_state, guard, after_state, loop

# SDFG queries
##############################
Expand Down
6 changes: 5 additions & 1 deletion dace/transformation/dataflow/map_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def replace_param(param):
'%s < %s' % (loop_idx, replace_param(loop_to + 1)),
'%s + %s' % (loop_idx, replace_param(loop_step)))
# store as object fields for external access
self.before_state, self.guard, self.after_state = loop_result
self.before_state, self.guard, self.after_state, self.loop = loop_result
# Skip map in input edges
for edge in nstate.out_edges(map_entry):
src_node = nstate.memlet_path(edge)[0].src
Expand All @@ -96,6 +96,10 @@ def replace_param(param):
nstate.add_edge(edge.src, edge.src_conn, dst_node, None, edge.data)
nstate.remove_edge(edge)

for state in nsdfg.states():
if state != self.before_state and state != self.after_state:
self.loop.states.add(state)

# Remove nodes from dynamic map range
nstate.remove_nodes_from([e.src for e in dace.sdfg.dynamic_map_inputs(nstate, map_entry)])
# Remove scope nodes
Expand Down
5 changes: 4 additions & 1 deletion dace/transformation/interstate/move_loop_into_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def apply(self, _, sdfg: sd.SDFG):
nsdfg = helpers.nest_state_subgraph(sdfg, body, map_subgraph, full_data=True)

# replicate loop in nested sdfg
new_before, new_guard, new_after = nsdfg.sdfg.add_loop(
new_before, new_guard, new_after, new_loop = nsdfg.sdfg.add_loop(
before_state=None,
loop_state=nsdfg.sdfg.nodes()[0],
loop_end_state=None,
Expand All @@ -185,6 +185,9 @@ def apply(self, _, sdfg: sd.SDFG):
condition_expr=f'{itervar} <= {end}' if forward_loop else f'{itervar} >= {end}',
increment_expr=f'{itervar} + {step}' if forward_loop else f'{itervar} - {abs(step)}')

for state in nsdfg.sdfg.nodes():
new_loop.states.add(state)

# remove outer loop
before_guard_edge = nsdfg.sdfg.edges_between(new_before, new_guard)[0]
for e in nsdfg.sdfg.out_edges(new_guard):
Expand Down
3 changes: 2 additions & 1 deletion tests/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def test_ordered_multidigraph(self):
def test_dfs_edges(self):

sdfg = dace.SDFG('test_dfs_edges')
before, _, _ = sdfg.add_loop(sdfg.add_state(), sdfg.add_state(), sdfg.add_state(), 'i', '0', 'i < 10', 'i + 1')
before, _, _, _ = sdfg.add_loop(sdfg.add_state(), sdfg.add_state(), sdfg.add_state(), 'i', '0', 'i < 10',
'i + 1')

visited_edges = list(sdfg.dfs_edges(before))
assert len(visited_edges) == len(set(visited_edges))
Expand Down
4 changes: 2 additions & 2 deletions tests/transformations/loop_to_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def test_need_for_tasklet():
aname, _ = sdfg.add_array('A', (10, ), dace.int32)
bname, _ = sdfg.add_array('B', (10, ), dace.int32)
body = sdfg.add_state('body')
_, _, _ = sdfg.add_loop(None, body, None, 'i', '0', 'i < 10', 'i + 1', None)
_, _, _, _ = sdfg.add_loop(None, body, None, 'i', '0', 'i < 10', 'i + 1', None)
anode = body.add_access(aname)
bnode = body.add_access(bname)
body.add_nedge(anode, bnode, dace.Memlet(data=aname, subset='i', other_subset='9 - i'))
Expand All @@ -325,7 +325,7 @@ def test_need_for_transient():
aname, _ = sdfg.add_array('A', (10, 10), dace.int32)
bname, _ = sdfg.add_array('B', (10, 10), dace.int32)
body = sdfg.add_state('body')
_, _, _ = sdfg.add_loop(None, body, None, 'i', '0', 'i < 10', 'i + 1', None)
_, _, _, _ = sdfg.add_loop(None, body, None, 'i', '0', 'i < 10', 'i + 1', None)
anode = body.add_access(aname)
bnode = body.add_access(bname)
body.add_nedge(anode, bnode, dace.Memlet(data=aname, subset='0:10, i', other_subset='0:10, 9 - i'))
Expand Down