Skip to content

Commit

Permalink
Eliminate extraneous branch-end gotos in code generation (#1355)
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun authored Aug 30, 2023
1 parent f99bbca commit c5ca99a
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 20 deletions.
77 changes: 58 additions & 19 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class ControlFlow:
# a string with its generated code.
dispatch_state: Callable[[SDFGState], str]

# The parent control flow block of this one, used to avoid generating extraneous ``goto``s
parent: Optional['ControlFlow']

@property
def first_state(self) -> SDFGState:
"""
Expand Down Expand Up @@ -222,11 +225,18 @@ def as_cpp(self, codegen, symbols) -> str:
out_edges = sdfg.out_edges(elem.state)
for j, e in enumerate(out_edges):
if e not in self.gotos_to_ignore:
# If this is the last generated edge and it leads
# to the next state, skip emitting goto
# Skip gotos to immediate successors
successor = None
if (j == (len(out_edges) - 1) and (i + 1) < len(self.elements)):
successor = self.elements[i + 1].first_state
# If this is the last generated edge
if j == (len(out_edges) - 1):
if (i + 1) < len(self.elements):
# If last edge leads to next state in block
successor = self.elements[i + 1].first_state
elif i == len(self.elements) - 1:
# If last edge leads to first state in next block
next_block = _find_next_block(self)
if next_block is not None:
successor = next_block.first_state

expr += elem.generate_transition(sdfg, e, successor)
else:
Expand Down Expand Up @@ -478,13 +488,14 @@ def children(self) -> List[ControlFlow]:

def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge],
leave_edge: Edge[InterstateEdge], back_edges: List[Edge[InterstateEdge]],
dispatch_state: Callable[[SDFGState], str]) -> Union[ForScope, WhileScope]:
dispatch_state: Callable[[SDFGState],
str], parent_block: GeneralBlock) -> Union[ForScope, WhileScope]:
"""
Helper method that constructs the correct structured loop construct from a
set of states. Can construct for or while loops.
"""

body = GeneralBlock(dispatch_state, [], [], [], [], [], True)
body = GeneralBlock(dispatch_state, parent_block, [], [], [], [], [], True)

guard_inedges = sdfg.in_edges(guard)
increment_edges = [e for e in guard_inedges if e in back_edges]
Expand Down Expand Up @@ -535,10 +546,10 @@ def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[Intersta
# Also ignore assignments in increment edge (handled in for stmt)
body.assignments_to_ignore.append(increment_edge)

return ForScope(dispatch_state, itvar, guard, init, condition, update, body, init_edges)
return ForScope(dispatch_state, parent_block, itvar, guard, init, condition, update, body, init_edges)

# Otherwise, it is a while loop
return WhileScope(dispatch_state, guard, condition, body)
return WhileScope(dispatch_state, parent_block, guard, condition, body)


def _cases_from_branches(
Expand Down Expand Up @@ -617,6 +628,31 @@ def _child_of(node: SDFGState, parent: SDFGState, ptree: Dict[SDFGState, SDFGSta
return False


def _find_next_block(block: ControlFlow) -> Optional[ControlFlow]:
"""
Returns the immediate successor control flow block.
"""
# Find block in parent
parent = block.parent
if parent is None:
return None
ind = next(i for i, b in enumerate(parent.children) if b is block)
if ind == len(parent.children) - 1 or isinstance(parent, (IfScope, IfElseChain, SwitchCaseScope)):
# If last block, or other children are not reachable from current node (branches),
# recursively continue upwards
return _find_next_block(parent)
return parent.children[ind + 1]


def _reset_block_parents(block: ControlFlow):
"""
Fixes block parents after processing.
"""
for child in block.children:
child.parent = block
_reset_block_parents(child)


def _structured_control_flow_traversal(sdfg: SDFG,
start: SDFGState,
ptree: Dict[SDFGState, SDFGState],
Expand Down Expand Up @@ -645,7 +681,7 @@ def _structured_control_flow_traversal(sdfg: SDFG,
"""

def make_empty_block():
return GeneralBlock(dispatch_state, [], [], [], [], [], True)
return GeneralBlock(dispatch_state, parent_block, [], [], [], [], [], True)

# Traverse states in custom order
visited = set() if visited is None else visited
Expand All @@ -657,7 +693,7 @@ def make_empty_block():
if node in visited or node is stop:
continue
visited.add(node)
stateblock = SingleState(dispatch_state, node)
stateblock = SingleState(dispatch_state, parent_block, node)

oe = sdfg.out_edges(node)
if len(oe) == 0: # End state
Expand Down Expand Up @@ -708,23 +744,25 @@ def make_empty_block():
if (len(oe) == 2 and oe[0].data.condition_sympy() == sp.Not(oe[1].data.condition_sympy())):
# If without else
if oe[0].dst is mergestate:
branch_block = IfScope(dispatch_state, sdfg, node, oe[1].data.condition, cblocks[oe[1]])
branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[1].data.condition,
cblocks[oe[1]])
elif oe[1].dst is mergestate:
branch_block = IfScope(dispatch_state, sdfg, node, oe[0].data.condition, cblocks[oe[0]])
branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[0].data.condition,
cblocks[oe[0]])
else:
branch_block = IfScope(dispatch_state, sdfg, node, oe[0].data.condition, cblocks[oe[0]],
cblocks[oe[1]])
branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[0].data.condition,
cblocks[oe[0]], cblocks[oe[1]])
else:
# If there are 2 or more edges (one is not the negation of the
# other):
switch = _cases_from_branches(oe, cblocks)
if switch:
# If all edges are of form "x == y" for a single x and
# integer y, it is a switch/case
branch_block = SwitchCaseScope(dispatch_state, sdfg, node, switch[0], switch[1])
branch_block = SwitchCaseScope(dispatch_state, parent_block, sdfg, node, switch[0], switch[1])
else:
# Otherwise, create if/else if/.../else goto exit chain
branch_block = IfElseChain(dispatch_state, sdfg, node,
branch_block = IfElseChain(dispatch_state, parent_block, sdfg, node,
[(e.data.condition, cblocks[e] if e in cblocks else make_empty_block())
for e in oe])
# End of branch classification
Expand All @@ -739,11 +777,11 @@ def make_empty_block():
loop_exit = None
scope = None
if ptree[oe[0].dst] == node and ptree[oe[1].dst] != node:
scope = _loop_from_structure(sdfg, node, oe[0], oe[1], back_edges, dispatch_state)
scope = _loop_from_structure(sdfg, node, oe[0], oe[1], back_edges, dispatch_state, parent_block)
body_start = oe[0].dst
loop_exit = oe[1].dst
elif ptree[oe[1].dst] == node and ptree[oe[0].dst] != node:
scope = _loop_from_structure(sdfg, node, oe[1], oe[0], back_edges, dispatch_state)
scope = _loop_from_structure(sdfg, node, oe[1], oe[0], back_edges, dispatch_state, parent_block)
body_start = oe[1].dst
loop_exit = oe[0].dst

Expand Down Expand Up @@ -836,7 +874,8 @@ def structured_control_flow_tree(sdfg: SDFG, dispatch_state: Callable[[SDFGState
if len(common_frontier) == 1:
branch_merges[state] = next(iter(common_frontier))

root_block = GeneralBlock(dispatch_state, [], [], [], [], [], True)
root_block = GeneralBlock(dispatch_state, None, [], [], [], [], [], True)
_structured_control_flow_traversal(sdfg, sdfg.start_state, ptree, branch_merges, back_edges, dispatch_state,
root_block)
_reset_block_parents(root_block)
return root_block
2 changes: 1 addition & 1 deletion dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def dispatch_state(state: SDFGState) -> str:
# If disabled, generate entire graph as general control flow block
states_topological = list(sdfg.topological_sort(sdfg.start_state))
last = states_topological[-1]
cft = cflow.GeneralBlock(dispatch_state,
cft = cflow.GeneralBlock(dispatch_state, None,
[cflow.SingleState(dispatch_state, s, s is last) for s in states_topological], [],
[], [], [], False)

Expand Down
29 changes: 29 additions & 0 deletions tests/codegen/control_flow_detection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,33 @@ def test_single_outedge_branch():
assert np.allclose(res, 2)


def test_extraneous_goto():

@dace.program
def tester(a: dace.float64[20]):
if a[0] < 0:
a[1] = 1
a[2] = 1

sdfg = tester.to_sdfg(simplify=True)
assert 'goto' not in sdfg.generate_code()[0].code


def test_extraneous_goto_nested():

@dace.program
def tester(a: dace.float64[20]):
if a[0] < 0:
if a[0] < 1:
a[1] = 1
else:
a[1] = 2
a[2] = 1

sdfg = tester.to_sdfg(simplify=True)
assert 'goto' not in sdfg.generate_code()[0].code


if __name__ == '__main__':
test_for_loop_detection()
test_invalid_for_loop_detection()
Expand All @@ -128,3 +155,5 @@ def test_single_outedge_branch():
test_edge_sympy_function('TrueFalse')
test_edge_sympy_function('SwitchCase')
test_single_outedge_branch()
test_extraneous_goto()
test_extraneous_goto_nested()

0 comments on commit c5ca99a

Please sign in to comment.