Conditional Blocks (#1666)
This is a continuation of #1617
(superseded and closed by this PR), with a lot of the work being done by
@luca-patrignani.

# Conditional Blocks
This PR implements Conditional Blocks, which are a native way of
semantically expressing conditional branching in an SDFG. This replaces
the traditional "state machine only" way of expressing conditional
branching, with two main goals:
1. **Simplify SDFG analysis and optimization by clearly exposing
conditional branching.** Previously, detecting and treating conditional
branches required expensive analysis of the control flow graph
structure, which had to be performed repeatedly and was error prone. By
contrast, Conditional Blocks can be generated by a frontend using
semantic information from the source language, entirely circumventing
this step.
2. **Address code generation issues.** Code generation relies on a
series of control flow detections to generate appropriate code that is
not full of `goto` statements for each state transition. However, just
as in the above issue, this process is error prone and often leads to
invalid code being generated for complex control flow constructs (e.g.,
conditionals inside of loops with conditional break, continue, return,
etc.). By exposing _all_ regular control flow (i.e., loops and
conditional branching) with native SDFG constructs, this step can be
skipped in code generation.

### Anatomy of Conditional Blocks
`ConditionalBlock`s are a type of `ControlFlowBlock` which contains a
series of **branches**. Each branch is represented by a full
`ControlFlowRegion` and has a condition in the form of a `CodeBlock`
attached to it. When a `ConditionalBlock` is executed, the conditions
are checked in the insertion order of the branches, and the first branch
whose condition matches (and only that branch) is executed. When that
branch finishes executing, control passes to the `ConditionalBlock`'s
successor. If no condition matches, no branch is executed.

At most one branch may have a condition of `None`, which represents a
wildcard or `else` case that is executed if no other condition matches.
The code generator in this PR requires this branch to be the last one.
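
The execution rule described above can be modeled in a few lines of plain Python. This is only a sketch of the semantics, not DaCe code: `run_conditional`, the predicate-based conditions, and the callable branch bodies are hypothetical stand-ins for `ConditionalBlock`, `CodeBlock`, and `ControlFlowRegion`, and the `None` branch is assumed to come last.

```python
from typing import Callable, List, Optional, Tuple

# Hypothetical stand-ins: a condition is a predicate over a symbol table and a
# branch body is any callable. Neither corresponds to DaCe's actual classes.
Condition = Optional[Callable[[dict], bool]]
Body = Callable[[dict], None]


def run_conditional(branches: List[Tuple[Condition, Body]], symbols: dict) -> Optional[int]:
    """Execute at most one branch and return its index (None if no branch ran)."""
    for index, (condition, body) in enumerate(branches):
        # A condition of None acts as the wildcard / else case.
        if condition is None or condition(symbols):
            body(symbols)
            return index
    return None


trace = []
branches = [
    (lambda s: s['x'] > 10, lambda s: trace.append('large')),
    (lambda s: s['x'] > 0, lambda s: trace.append('positive')),
    (None, lambda s: trace.append('other')),
]
# Branches are checked in insertion order, so x = 42 takes the first branch
# even though it also satisfies the second condition.
taken = [run_conditional(branches, {'x': x}) for x in (42, 5, -1)]
```

Dropping the final `None` branch gives the "no condition matches" case, in which no body runs and execution simply moves on to the successor.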

### Code Generation Changes
Code generation (when using this feature) is drastically simplified with
respect to control flow: no more control flow detection is performed.
Instead, regular control flow constructs are only generated from the new
native SDFG constructs
([`LoopRegion`s](#1475) and
`ConditionalBlock`s), and any other state transition is either only used
for sequential ordering (unconditional transitions to a single, direct
successor), or leads to a `goto`. This makes code generation
significantly less error prone and simpler to work with.
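
To illustrate why no detection step is needed, the following standalone sketch mirrors the structure of `GeneralConditionalScope.as_cpp` from this commit's diff, with plain strings in place of `CodeBlock` conditions and nested control flow bodies. The function name `emit_conditional` and the string-based inputs are assumptions made for illustration only.

```python
from typing import List, Optional, Tuple


def emit_conditional(branches: List[Tuple[Optional[str], str]]) -> str:
    """Emit a C++ if / else if / else chain from (condition, body) pairs."""
    code = ''
    last = len(branches) - 1
    for i, (cond, body) in enumerate(branches):
        if cond is not None:
            code += f'if ({cond}) {{\n' if i == 0 else f'}} else if ({cond}) {{\n'
        else:
            # Only the final branch of a multi-branch chain may omit its condition.
            if i < last or i == 0:
                raise RuntimeError('Missing branch condition for non-final conditional branch')
            code += '} else {\n'
        code += body
        if i == last:
            code += '}\n'
    return code


generated = emit_conditional([
    ('x > 10', 'y = 1;\n'),
    ('x > 0', 'y = 2;\n'),
    (None, 'y = 3;\n'),
])
```

Because the branch list already carries the structure, the generator only walks it in order; there is no merge state to discover and no `goto` to emit.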

### Compatibility
This feature is implemented in a minimally invasive way and is fully
backwards compatible for now.
Just as with [`LoopRegion`s](#1475),
this feature is only used if an explicit `use_experimental_cfg_blocks`
flag is set to `True` in compatible frontends (currently only the Python
frontend; Fortran frontend integration is coming soon).

If an SDFG makes use of these experimental blocks, some passes and
transformations will no longer be applied automatically in pipelines.
Transformations that handle these blocks correctly can be explicitly
marked with `@transformation.experimental_cfg_block_compatible` to apply
them on SDFGs with experimental blocks.

### Inlining
Conditional blocks can be inlined into traditional SDFG state machines
through a utility function. Compatible frontends do this automatically
if the experimental CFG blocks feature is turned off.
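
The details of the utility function are not reproduced here. As a rough sketch of what such an inlining has to preserve, the snippet below turns a list of branches into state machine edges between a guard state and a merge state. All names (`inline_conditional`, `guard`, `merge`) and the string-based conditions are hypothetical. The key point is that the first-match ordering must be encoded explicitly: each edge condition is combined with the negation of all earlier conditions, which may be the kind of expression the new `and_expr` helper in `astutils.py` is meant to build alongside `negate_expr`.

```python
from typing import List, Optional, Tuple


def inline_conditional(branches: List[Tuple[Optional[str], str]],
                       guard: str = 'guard',
                       merge: str = 'merge') -> List[Tuple[str, str, str]]:
    """Return (source, destination, condition) edges of an equivalent state machine."""
    edges = []
    earlier: List[str] = []  # Conditions of the branches before the current one.
    has_else = False
    for cond, body in branches:
        not_earlier = ' and '.join(f'not ({c})' for c in earlier)
        if cond is None:
            # else case: taken only if none of the explicit conditions hold.
            edge_cond = not_earlier or '1'
            has_else = True
        else:
            edge_cond = f'({cond}) and {not_earlier}' if earlier else f'({cond})'
            earlier.append(cond)
        edges.append((guard, body, edge_cond))
        edges.append((body, merge, '1'))  # Unconditional transition to the merge state.
    if not has_else:
        # Without an else branch, go straight to the merge state when nothing matches.
        edges.append((guard, merge, ' and '.join(f'not ({c})' for c in earlier) or '1'))
    return edges


edges = inline_conditional([('x > 10', 'then_body'), ('x > 0', 'elif_body')])
```

The resulting edges leave the guard state with mutually exclusive conditions, so the flattened state machine takes the same branch the conditional block would have taken.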

### Visualization Components
The visualization components are being worked on separately in
spcl/dace-webclient#173. This PR does not depend
on the visualization components being merged.

---------

Co-authored-by: Luca Patrignani <[email protected]>
Co-authored-by: luca-patrignani <[email protected]>
3 people authored Sep 27, 2024
1 parent 9945f48 commit 1dc9bc5
Showing 15 changed files with 617 additions and 161 deletions.
137 changes: 65 additions & 72 deletions dace/codegen/control_flow.py
@@ -62,7 +62,7 @@
import sympy as sp
from dace import dtypes
from dace.sdfg.analysis import cfg as cfg_analysis
from dace.sdfg.state import (BreakBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion,
from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion,
ReturnBlock, SDFGState)
from dace.sdfg.sdfg import SDFG, InterstateEdge
from dace.sdfg.graph import Edge
@@ -236,14 +236,18 @@ def first_block(self) -> ReturnBlock:


@dataclass
class GeneralBlock(ControlFlow):
"""
General (or unrecognized) control flow block with gotos between blocks.
"""
class RegionBlock(ControlFlow):

# The control flow region that this block corresponds to (may be the SDFG in the absence of hierarchical regions).
region: Optional[ControlFlowRegion]


@dataclass
class GeneralBlock(RegionBlock):
"""
General (or unrecognized) control flow block with gotos between blocks.
"""

# List of children control flow blocks
elements: List[ControlFlow]

@@ -270,7 +274,7 @@ def as_cpp(self, codegen, symbols) -> str:
for i, elem in enumerate(self.elements):
expr += elem.as_cpp(codegen, symbols)
# In a general block, emit transitions and assignments after each individual block or region.
if isinstance(elem, BasicCFBlock) or (isinstance(elem, GeneralBlock) and elem.region):
if isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region):
cfg = elem.state.parent_graph if isinstance(elem, BasicCFBlock) else elem.region.parent_graph
sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg
out_edges = cfg.out_edges(elem.state) if isinstance(elem, BasicCFBlock) else cfg.out_edges(elem.region)
@@ -514,10 +518,9 @@ def children(self) -> List[ControlFlow]:


@dataclass
class GeneralLoopScope(ControlFlow):
class GeneralLoopScope(RegionBlock):
""" General loop block based on a loop control flow region. """

loop: LoopRegion
body: ControlFlow

def as_cpp(self, codegen, symbols) -> str:
@@ -565,6 +568,10 @@ def as_cpp(self, codegen, symbols) -> str:

return expr

@property
def loop(self) -> LoopRegion:
return self.region

@property
def first_block(self) -> ControlFlowBlock:
return self.loop.start_block
@@ -601,6 +608,46 @@ def children(self) -> List[ControlFlow]:
return list(self.cases.values())


@dataclass
class GeneralConditionalScope(RegionBlock):
""" General conditional block based on a conditional control flow region. """

branch_bodies: List[Tuple[Optional[CodeBlock], ControlFlow]]

def as_cpp(self, codegen, symbols) -> str:
sdfg = self.conditional.sdfg
expr = ''
for i in range(len(self.branch_bodies)):
branch = self.branch_bodies[i]
if branch[0] is not None:
cond = unparse_interstate_edge(branch[0].code, sdfg, codegen=codegen, symbols=symbols)
cond = cond.strip(';')
if i == 0:
expr += f'if ({cond}) {{\n'
else:
expr += f'}} else if ({cond}) {{\n'
else:
if i < len(self.branch_bodies) - 1 or i == 0:
raise RuntimeError('Missing branch condition for non-final conditional branch')
expr += '} else {\n'
expr += branch[1].as_cpp(codegen, symbols)
if i == len(self.branch_bodies) - 1:
expr += '}\n'
return expr

@property
def conditional(self) -> ConditionalBlock:
return self.region

@property
def first_block(self) -> ControlFlowBlock:
return self.conditional

@property
def children(self) -> List[ControlFlow]:
return [b for _, b in self.branch_bodies]


def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge],
leave_edge: Edge[InterstateEdge], back_edges: List[Edge[InterstateEdge]],
dispatch_state: Callable[[SDFGState],
@@ -973,7 +1020,6 @@ def _structured_control_flow_traversal_with_regions(cfg: ControlFlowRegion,
if branch_merges is None:
branch_merges = cfg_analysis.branch_merges(cfg)


if ptree is None:
ptree = cfg_analysis.block_parent_tree(cfg, with_loops=False)

@@ -1004,6 +1050,14 @@ def make_empty_block():
cfg_block = ContinueCFBlock(dispatch_state, parent_block, True, node)
elif isinstance(node, ReturnBlock):
cfg_block = ReturnCFBlock(dispatch_state, parent_block, True, node)
elif isinstance(node, ConditionalBlock):
cfg_block = GeneralConditionalScope(dispatch_state, parent_block, False, node, [])
for cond, branch in node.branches:
if branch is not None:
body = make_empty_block()
body.parent = cfg_block
_structured_control_flow_traversal_with_regions(branch, dispatch_state, body)
cfg_block.branch_bodies.append((cond, body))
elif isinstance(node, ControlFlowRegion):
if isinstance(node, LoopRegion):
body = make_empty_block()
@@ -1027,69 +1081,8 @@ def make_empty_block():
stack.append(oe[0].dst)
parent_block.elements.append(cfg_block)
continue

# Potential branch or loop
if node in branch_merges:
mergeblock = branch_merges[node]

# Add branching node and ignore outgoing edges
parent_block.elements.append(cfg_block)
parent_block.gotos_to_ignore.extend(oe) # TODO: why?
parent_block.assignments_to_ignore.extend(oe) # TODO: why?
cfg_block.last_block = True

# Parse all outgoing edges recursively first
cblocks: Dict[Edge[InterstateEdge], GeneralBlock] = {}
for branch in oe:
if branch.dst is mergeblock:
# If we hit the merge state (if without else), defer to end of branch traversal
continue
cblocks[branch] = make_empty_block()
_structured_control_flow_traversal_with_regions(cfg=cfg,
dispatch_state=dispatch_state,
parent_block=cblocks[branch],
start=branch.dst,
stop=mergeblock,
generate_children_of=node,
branch_merges=branch_merges,
ptree=ptree,
visited=visited)

# Classify branch type:
branch_block = None
# If there are 2 out edges, one negation of the other:
# * if/else in case both branches are not merge state
# * if without else in case one branch is merge state
if (len(oe) == 2 and oe[0].data.condition_sympy() == sp.Not(oe[1].data.condition_sympy())):
if oe[0].dst is mergeblock:
# If without else
branch_block = IfScope(dispatch_state, parent_block, False, node, oe[1].data.condition,
cblocks[oe[1]])
elif oe[1].dst is mergeblock:
branch_block = IfScope(dispatch_state, parent_block, False, node, oe[0].data.condition,
cblocks[oe[0]])
else:
branch_block = IfScope(dispatch_state, parent_block, False, node, oe[0].data.condition,
cblocks[oe[0]], cblocks[oe[1]])
else:
# If there are 2 or more edges (one is not the negation of the
# other):
switch = _cases_from_branches(oe, cblocks)
if switch:
# If all edges are of form "x == y" for a single x and
# integer y, it is a switch/case
branch_block = SwitchCaseScope(dispatch_state, parent_block, False, node, switch[0], switch[1])
else:
# Otherwise, create if/else if/.../else goto exit chain
branch_block = IfElseChain(dispatch_state, parent_block, False, node,
[(e.data.condition, cblocks[e] if e in cblocks else make_empty_block())
for e in oe])
# End of branch classification
parent_block.elements.append(branch_block)
if mergeblock != stop:
stack.append(mergeblock)

else: # No merge state: Unstructured control flow
else:
# Unstructured control flow.
parent_block.sequential = False
parent_block.elements.append(cfg_block)
stack.extend([e.dst for e in oe])
2 changes: 1 addition & 1 deletion dace/codegen/targets/framecode.py
@@ -483,7 +483,7 @@ def dispatch_state(state: SDFGState) -> str:
states_generated.add(state) # For sanity check
return stream.getvalue()

if sdfg.root_sdfg.using_experimental_blocks:
if sdfg.root_sdfg.recheck_using_experimental_blocks():
# Use control flow blocks embedded in the SDFG to generate control flow.
cft = cflow.structured_control_flow_tree_with_regions(sdfg, dispatch_state)
elif config.Config.get_bool('optimizer', 'detect_control_flow'):
19 changes: 17 additions & 2 deletions dace/frontend/common/einsum.py
@@ -3,7 +3,9 @@
from functools import reduce
from itertools import chain
from string import ascii_letters
from typing import Dict, Optional
from typing import Dict, List, Optional

import numpy as np

import dace
from dace import dtypes, subsets, symbolic
@@ -180,6 +182,19 @@ def create_einsum_sdfg(pv: 'dace.frontend.python.newast.ProgramVisitor',
beta=beta)[0]


def _build_einsum_views(tensors: str, dimension_dict: dict) -> List[np.ndarray]:
"""
Function taken and adjusted from opt_einsum package version 3.3.0 following unexpected removal in version 3.4.0.
Reference: https://github.com/dgasmith/opt_einsum/blob/v3.3.0/opt_einsum/helpers.py#L18
"""
views = []
terms = tensors.split('->')[0].split(',')
for term in terms:
dims = [dimension_dict[x] for x in term]
views.append(np.random.rand(*dims))
return views


def _create_einsum_internal(sdfg: SDFG,
state: SDFGState,
einsum_string: str,
@@ -231,7 +246,7 @@ def _create_einsum_internal(sdfg: SDFG,

# Create optimal contraction path
# noinspection PyTypeChecker
_, path_info = oe.contract_path(einsum_string, *oe.helpers.build_views(einsum_string, chardict))
_, path_info = oe.contract_path(einsum_string, *_build_einsum_views(einsum_string, chardict))

input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}
result_node = None
42 changes: 42 additions & 0 deletions dace/frontend/python/astutils.py
@@ -384,6 +384,48 @@ def negate_expr(node):
return ast.fix_missing_locations(newexpr)


def and_expr(node_a, node_b):
""" Generates the logical AND of two AST expressions.
"""
if type(node_a) is not type(node_b):
raise ValueError('Node types do not match')

# Support for SymPy expressions
if isinstance(node_a, sympy.Basic):
return sympy.And(node_a, node_b)
# Support for numerical constants
if isinstance(node_a, (numbers.Number, numpy.bool_)):
return str(node_a and node_b)
# Support for strings (most likely dace.Data.Scalar names)
if isinstance(node_a, str):
return f'({node_a}) and ({node_b})'

from dace.properties import CodeBlock # Avoid import loop
if isinstance(node_a, CodeBlock):
node_a = node_a.code
node_b = node_b.code

if hasattr(node_a, "__len__"):
if len(node_a) > 1:
raise ValueError("and_expr only expects single expressions, got: {}".format(node_a))
if len(node_b) > 1:
raise ValueError("and_expr only expects single expressions, got: {}".format(node_b))
expr_a = node_a[0]
expr_b = node_b[0]
else:
expr_a = node_a
expr_b = node_b

if isinstance(expr_a, ast.Expr):
expr_a = expr_a.value
if isinstance(expr_b, ast.Expr):
expr_b = expr_b.value

newexpr = ast.Expr(value=ast.BoolOp(op=ast.And(), values=[copy_tree(expr_a), copy_tree(expr_b)]))
newexpr = ast.copy_location(newexpr, expr_a)
return ast.fix_missing_locations(newexpr)


def copy_tree(node: ast.AST) -> ast.AST:
"""
Copies an entire AST without copying the non-AST parts (e.g., constant values).
6 changes: 5 additions & 1 deletion dace/frontend/python/interface.py
@@ -44,6 +44,7 @@ def program(f: F,
recompile: bool = True,
distributed_compilation: bool = False,
constant_functions=False,
use_experimental_cfg_blocks=False,
**kwargs) -> Callable[..., parser.DaceProgram]:
"""
Entry point to a data-centric program. For methods and ``classmethod``s, use
@@ -68,6 +69,8 @@ def program(f: F,
not depend on internal variables are constant.
This will hardcode their return values into the
resulting program.
:param use_experimental_cfg_blocks: If True, makes use of experimental CFG blocks such as loop and conditional
regions.
:note: If arguments are defined with type hints, the program can be compiled
ahead-of-time with ``.compile()``.
"""
@@ -83,7 +86,8 @@ def program(f: F,
recreate_sdfg=recreate_sdfg,
regenerate_code=regenerate_code,
recompile=recompile,
distributed_compilation=distributed_compilation)
distributed_compilation=distributed_compilation,
use_experimental_cfg_blocks=use_experimental_cfg_blocks)


function = program