From dc9510e4f316cdfed1aa55fc9cd080bee60fdf08 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 29 Aug 2023 18:10:48 +0200 Subject: [PATCH] Refactor --- dace/codegen/instrumentation/papi.py | 10 +- dace/frontend/python/newast.py | 3 +- dace/sdfg/__init__.py | 2 +- dace/sdfg/sdfg.py | 3 +- dace/sdfg/sdfg_control_flow.py | 286 ---------------- dace/sdfg/state.py | 304 +++++++++++++++++- dace/sdfg/utils.py | 2 +- .../transformation/interstate/state_fusion.py | 5 +- dace/transformation/transformation.py | 9 +- 9 files changed, 309 insertions(+), 315 deletions(-) delete mode 100644 dace/sdfg/sdfg_control_flow.py diff --git a/dace/codegen/instrumentation/papi.py b/dace/codegen/instrumentation/papi.py index bc7163ea9b..bab44e6779 100644 --- a/dace/codegen/instrumentation/papi.py +++ b/dace/codegen/instrumentation/papi.py @@ -12,7 +12,7 @@ from dace.sdfg.graph import SubgraphView from dace.memlet import Memlet from dace.sdfg import scope_contains_scope -from dace.sdfg.state import StateGraphView +from dace.sdfg.state import DataflowGraphView import sympy as sp import os @@ -392,7 +392,7 @@ def should_instrument_entry(map_entry: EntryNode) -> bool: return cond @staticmethod - def has_surrounding_perfcounters(node, dfg: StateGraphView): + def has_surrounding_perfcounters(node, dfg: DataflowGraphView): """ Returns true if there is a possibility that this node is part of a section that is profiled. """ parent = dfg.entry_node(node) @@ -605,7 +605,7 @@ def get_memlet_byte_size(sdfg: dace.SDFG, memlet: Memlet): return memlet.volume * memdata.dtype.bytes @staticmethod - def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: StateGraphView): + def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: DataflowGraphView): scope_dict = sdfg.node(state_id).scope_dict() out_costs = 0 @@ -636,7 +636,7 @@ def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: return out_costs @staticmethod - def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, dfg: StateGraphView, sdfg: dace.SDFG, state_id: int) -> str: + def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, dfg: DataflowGraphView, sdfg: dace.SDFG, state_id: int) -> str: """ Get the amount of bytes processed by `tasklet`. The formula is sum(inedges * size) + sum(outedges * size) """ in_accum = [] @@ -693,7 +693,7 @@ def get_memory_input_size(node, sdfg, state_id) -> str: return sym2cpp(input_size) @staticmethod - def accumulate_byte_movement(outermost_node, node, dfg: StateGraphView, sdfg, state_id): + def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, state_id): itvars = dict() # initialize an empty dict diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 72357fcb79..801c742979 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -28,11 +28,10 @@ from dace.frontend.python.memlet_parser import (DaceSyntaxError, parse_memlet, pyexpr_to_symbolic, ParseMemlet, inner_eval_ast, MemletExpr) from dace.sdfg import nodes, utils as sdutil -from dace.sdfg.sdfg_control_flow import ControlFlowGraph, LoopScopeBlock, ControlFlowBlock from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states from dace.memlet import Memlet from dace.properties import LambdaProperty, CodeBlock -from dace.sdfg import SDFG, SDFGState +from dace.sdfg import SDFG, SDFGState, ControlFlowGraph, ControlFlowBlock, LoopScopeBlock from dace.sdfg.replace import replace_datadesc_names from dace.symbolic import pystr_to_symbolic, inequal_symbols diff --git a/dace/sdfg/__init__.py b/dace/sdfg/__init__.py index ec650f24fa..307da66e3e 100644 --- a/dace/sdfg/__init__.py +++ b/dace/sdfg/__init__.py @@ -1,7 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. from dace.sdfg.sdfg import SDFG, InterstateEdge, LogicalGroup -from dace.sdfg.state import SDFGState, StateGraphView +from dace.sdfg.state import SDFGState, ControlFlowBlock, ControlFlowGraph, ScopeBlock, LoopScopeBlock, BranchScopeBlock from dace.sdfg.scope import (scope_contains_scope, is_devicelevel_gpu, devicelevel_block_size, ScopeSubgraphView) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 64cb72310f..06dc0d438a 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -42,8 +42,7 @@ from dace.sdfg.propagation import propagate_memlets_sdfg from dace.sdfg.replace import replace, replace_properties, replace_properties_dict from dace.sdfg.scope import ScopeTree -from dace.sdfg.sdfg_control_flow import ControlFlowGraph, LoopScopeBlock -from dace.sdfg.state import SDFGState +from dace.sdfg.state import SDFGState, ControlFlowGraph, LoopScopeBlock from dace.sdfg.validation import InvalidSDFGError, validate_sdfg # NOTE: In shapes, we try to convert strings to integers. In ranks, a string should be interpreted as data (scalar). diff --git a/dace/sdfg/sdfg_control_flow.py b/dace/sdfg/sdfg_control_flow.py deleted file mode 100644 index f79f381cce..0000000000 --- a/dace/sdfg/sdfg_control_flow.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. - -from typing import Dict, List, Optional, Set, Generator - -import dace -from dace import symbolic, data as dt -from dace.memlet import Memlet -from dace.properties import CodeBlock, CodeProperty, Property, make_properties -from dace.sdfg import nodes as nd -from dace.sdfg.graph import OrderedDiGraph, OrderedMultiDiConnectorGraph - - -@make_properties -class ControlFlowBlock(object): - - is_collapsed = Property(dtype=bool, desc='Show this block as collapsed', default=False) - - _parent_cfg: Optional['ControlFlowGraph'] = None - _label: str - - def __init__(self, label: str='', parent: Optional['ControlFlowGraph']=None): - super(ControlFlowBlock, self).__init__() - self._label = label - self._parent_cfg = parent - self._default_lineinfo = None - self.is_collapsed = False - - def set_default_lineinfo(self, lineinfo: dace.dtypes.DebugInfo): - """ - Sets the default source line information to be lineinfo, or None to - revert to default mode. - """ - self._default_lineinfo = lineinfo - - def data_nodes(self) -> List[nd.AccessNode]: - return [] - - def replace_dict(self, - repl: Dict[str, str], - symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): - """ Finds and replaces all occurrences of a set of symbols or arrays in this state. - - :param repl: Mapping from names to replacements. - :param symrepl: Optional symbolic version of ``repl``. - """ - from dace.sdfg.replace import replace_dict - replace_dict(self, repl, symrepl) - - def to_json(self, parent=None): - tmp = {} - tmp['id'] = parent.node_id(self) if parent is not None else None - tmp['label'] = self._label - tmp['collapsed'] = self.is_collapsed - return tmp - - def __str__(self): - return self._label - - def __repr__(self) -> str: - return f'ControlFlowBlock ({self.label})' - - @property - def label(self) -> str: - return self._label - - @label.setter - def label(self, label: str): - self._label = label - - @property - def name(self) -> str: - return self._label - - @property - def parent_cfg(self): - """ Returns the parent graph of this block. """ - return self._parent_cfg - - @parent_cfg.setter - def parent_cfg(self, value): - self._parent_cfg = value - - -@make_properties -class BasicBlock(OrderedMultiDiConnectorGraph[nd.Node, Memlet], ControlFlowBlock): - - def __init__(self, label: str='', parent: Optional['ControlFlowGraph']=None): - OrderedMultiDiConnectorGraph.__init__(self) - ControlFlowBlock.__init__(self, label, parent) - - def __repr__(self) -> str: - return f'BasicBlock ({self.label})' - - -@make_properties -class ControlFlowGraph(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge']): - - def __init__(self): - super(ControlFlowGraph, self).__init__() - - self._labels: Set[str] = set() - self._start_block: Optional[int] = None - self._cached_start_block: Optional[ControlFlowBlock] = None - - def add_edge(self, src: 'ControlFlowBlock', dst: 'ControlFlowBlock', data: 'dace.sdfg.InterstateEdge'): - """ Adds a new edge to the graph. Must be an InterstateEdge or a subclass thereof. - - :param u: Source node. - :param v: Destination node. - :param edge: The edge to add. - """ - if not isinstance(src, ControlFlowBlock): - raise TypeError('Expected ControlFlowBlock, got ' + str(type(src))) - if not isinstance(dst, ControlFlowBlock): - raise TypeError('Expected ControlFlowBlock, got ' + str(type(dst))) - if not isinstance(data, dace.sdfg.InterstateEdge): - raise TypeError('Expected InterstateEdge, got ' + str(type(data))) - if dst is self._cached_start_block: - self._cached_start_block = None - return super(ControlFlowGraph, self).add_edge(src, dst, data) - - def add_node(self, node, is_start_block=False): - if not isinstance(node, ControlFlowBlock): - raise TypeError('Expected ControlFlowBlock, got ' + str(type(node))) - super().add_node(node) - node.parent_cfg = self - self._cached_start_block = None - if is_start_block is True: - self.start_block = len(self.nodes()) - 1 - self._cached_start_block = node - - def add_state(self, label=None, is_start_block=False, parent_sdfg=None) -> 'dace.SDFGState': - if self._labels is None or len(self._labels) != self.number_of_nodes(): - self._labels = set(s.label for s in self.nodes()) - label = label or 'state' - existing_labels = self._labels - label = dt.find_new_name(label, existing_labels) - state = dace.SDFGState(label, parent_sdfg) - self._labels.add(label) - self.add_node(state, is_start_block=is_start_block) - return state - - def all_cfgs_recursive(self, recurse_into_sdfgs=True) -> Generator['ControlFlowGraph', None, None]: - """ Iterate over this and all nested CFGs. """ - yield self - for block in self.nodes(): - if isinstance(block, BasicBlock) and recurse_into_sdfgs: - for node in block.nodes(): - if isinstance(node, nd.NestedSDFG): - yield from node.sdfg.all_cfgs_recursive() - elif isinstance(block, ControlFlowGraph): - yield from block.all_cfgs_recursive() - - def all_sdfgs_recursive(self) -> Generator['dace.SDFG', None, None]: - """ Iterate over this and all nested SDFGs. """ - for cfg in self.all_cfgs_recursive(recurse_into_sdfgs=True): - if isinstance(cfg, dace.SDFG): - yield cfg - - def all_states_recursive(self) -> Generator['dace.SDFGState', None, None]: - """ Iterate over all states in this control flow graph. """ - for block in self.nodes(): - if isinstance(block, dace.SDFGState): - yield block - elif isinstance(block, ControlFlowGraph): - yield from block.all_states_recursive() - - @property - def start_block(self): - """ Returns the starting block of this ControlFlowGraph. """ - if self._cached_start_block is not None: - return self._cached_start_block - - source_nodes = self.source_nodes() - if len(source_nodes) == 1: - self._cached_start_block = source_nodes[0] - return source_nodes[0] - # If the starting block is ambiguous allow manual override. - if self._start_block is not None: - self._cached_start_block = self.node(self._start_block) - return self._cached_start_block - raise ValueError('Ambiguous or undefined starting block for ControlFlowGraph, ' - 'please use "is_start_block=True" when adding the ' - 'starting block with "add_state" or "add_node"') - - @start_block.setter - def start_block(self, block_id): - """ Manually sets the starting block of this ControlFlowGraph. - - :param block_id: The node ID (use `node_id(block)`) of the block to set. - """ - if block_id < 0 or block_id >= self.number_of_nodes(): - raise ValueError('Invalid state ID') - self._start_block = block_id - self._cached_start_block = self.node(block_id) - - -@make_properties -class ScopeBlock(ControlFlowGraph, ControlFlowBlock): - - def __init__(self, label: str='', parent: Optional[ControlFlowGraph]=None): - ControlFlowGraph.__init__(self) - ControlFlowBlock.__init__(self, label, parent) - - def data_nodes(self) -> List[nd.AccessNode]: - """ Returns all data_nodes (arrays) present in this state. """ - data_nodes = [] - for n in self.nodes(): - data_nodes.append(n.data_nodes()) - - return [n for n in self.nodes() if isinstance(n, nd.AccessNode)] - - def replace_dict(self, - repl: Dict[str, str], - symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): - """ Finds and replaces all occurrences of a set of symbols or arrays in this state. - - :param repl: Mapping from names to replacements. - :param symrepl: Optional symbolic version of ``repl``. - """ - for n in self.nodes(): - n.replace_dict(repl, symrepl) - - def to_json(self, parent=None): - graph_json = ControlFlowGraph.to_json(self) - block_json = ControlFlowBlock.to_json(self, parent) - graph_json.update(block_json) - return graph_json - - def all_nodes_recursive(self): - for node in self.nodes(): - yield node, self - if isinstance(node, (ScopeBlock, dace.sdfg.StateGraphView)): - yield from node.all_nodes_recursive() - - def __str__(self): - return ControlFlowBlock.__str__(self) - - def __repr__(self) -> str: - return f'{self.__class__.__name__} ({self.label})' - - -@make_properties -class LoopScopeBlock(ScopeBlock): - - update_statement = CodeProperty(optional=True, allow_none=True, default=None) - init_statement = CodeProperty(optional=True, allow_none=True, default=None) - scope_condition = CodeProperty(allow_none=True, default=None) - inverted = Property(dtype=bool, default=False) - - def __init__(self, - loop_var: str, - initialize_expr: str, - condition_expr: str, - update_expr: str, - label: str = '', - parent: Optional[ControlFlowGraph] = None, - inverted: bool = False): - super(LoopScopeBlock, self).__init__(label, parent) - - if initialize_expr is not None: - self.init_statement = CodeBlock('%s = %s' % (loop_var, initialize_expr)) - else: - self.init_statement = None - - if condition_expr: - self.scope_condition = CodeBlock(condition_expr) - else: - self.scope_condition = CodeBlock('True') - - if update_expr is not None: - self.update_statement = CodeBlock('%s = %s' % (loop_var, update_expr)) - else: - self.update_statement = None - - self.inverted = inverted - - def to_json(self, parent=None): - return super().to_json(parent) - - -@make_properties -class BranchScopeBlock(ScopeBlock): - - def __init__(self): - super(BranchScopeBlock, self).__init__() diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 540f5a03e9..6e9e516fe8 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -19,8 +19,7 @@ from dace.properties import (CodeBlock, DictProperty, EnumProperty, Property, SubsetProperty, SymbolicProperty, CodeProperty, make_properties) from dace.sdfg import nodes as nd -from dace.sdfg.sdfg_control_flow import BasicBlock -from dace.sdfg.graph import MultiConnectorEdge, SubgraphView +from dace.sdfg.graph import MultiConnectorEdge, SubgraphView, OrderedDiGraph, OrderedMultiDiConnectorGraph from dace.sdfg.propagation import propagate_memlet from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset @@ -64,11 +63,10 @@ def _make_iterators(ndrange): return params, map_range -class StateGraphView(object): +class BlockGraphView(object): """ - Read-only view interface of an SDFG state, containing methods for memlet - tracking, traversal, subgraph creation, queries, and replacements. - ``SDFGState`` and ``StateSubgraphView`` inherit from this class to share + Read-only view interface of an SDFG control flow block, containing methods for memlet tracking, traversal, subgraph + creation, queries, and replacements. ``ControlFlowBlock`` and ``StateSubgraphView`` inherit from this class to share methods. """ @@ -94,6 +92,8 @@ def all_nodes_recursive(self): yield node, self if isinstance(node, nd.NestedSDFG): yield from node.sdfg.all_nodes_recursive() + elif isinstance(node, (ScopeBlock, SDFGState)): + yield from node.all_nodes_recursive() def all_edges_recursive(self): for e in self.edges(): @@ -746,8 +746,85 @@ def replace_dict(self, replace_dict(self, repl, symrepl) +SomeNodeT = Union[nd.Node, 'ControlFlowBlock'] +SomeGraphT = Union['ControlFlowGraph', 'SDFGState'] + + +@make_properties +class ControlFlowBlock(BlockGraphView): + + is_collapsed = Property(dtype=bool, desc='Show this block as collapsed', default=False) + + _parent_cfg: Optional['ControlFlowGraph'] = None + _label: str + + def __init__(self, label: str='', parent: Optional['ControlFlowGraph']=None): + super(ControlFlowBlock, self).__init__() + self._label = label + self._parent_cfg = parent + self._default_lineinfo = None + self.is_collapsed = False + + def set_default_lineinfo(self, lineinfo: dace.dtypes.DebugInfo): + """ + Sets the default source line information to be lineinfo, or None to + revert to default mode. + """ + self._default_lineinfo = lineinfo + + def data_nodes(self) -> List[nd.AccessNode]: + return [] + + def replace_dict(self, + repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): + """ Finds and replaces all occurrences of a set of symbols or arrays in this state. + + :param repl: Mapping from names to replacements. + :param symrepl: Optional symbolic version of ``repl``. + """ + from dace.sdfg.replace import replace_dict + replace_dict(self, repl, symrepl) + + def to_json(self, parent=None): + tmp = { + 'type': self.__class__.__name__, + 'collapsed': self.is_collapsed, + 'label': self._label, + 'id': parent.node_id(self) if parent is not None else None, + } + return tmp + + def __str__(self): + return self._label + + def __repr__(self) -> str: + return f'ControlFlowBlock ({self.label})' + + @property + def label(self) -> str: + return self._label + + @label.setter + def label(self, label: str): + self._label = label + + @property + def name(self) -> str: + return self._label + + @property + def parent_cfg(self): + """ Returns the parent graph of this block. """ + return self._parent_cfg + + @parent_cfg.setter + def parent_cfg(self, value): + self._parent_cfg = value + + @make_properties -class SDFGState(BasicBlock, StateGraphView): +class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], ControlFlowBlock): """ An acyclic dataflow multigraph in an SDFG, corresponding to a single state in the SDFG state machine. """ @@ -789,8 +866,9 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): :param debuginfo: Source code locator for debugging. """ from dace.sdfg.sdfg import SDFG # Avoid import loop - BasicBlock.__init__(self, label, sdfg) - StateGraphView.__init__(self) + OrderedMultiDiConnectorGraph.__init__(self) + #StateGraphView.__init__(self) + ControlFlowBlock.__init__(self, label, sdfg) self._parent: Optional[SDFG] = sdfg self._graph = self # Allowing MemletTrackingView mixin to work self._clear_scopedict_cache() @@ -1944,8 +2022,214 @@ def fill_scope_connectors(self): node.add_in_connector(edge.dst_conn) -class StateSubgraphView(SubgraphView, StateGraphView): +class StateSubgraphView(SubgraphView, BlockGraphView): """ A read-only subgraph view of an SDFG state. """ def __init__(self, graph, subgraph_nodes): super().__init__(graph, subgraph_nodes) + + +@make_properties +class ControlFlowGraph(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge']): + + def __init__(self): + super(ControlFlowGraph, self).__init__() + + self._labels: Set[str] = set() + self._start_block: Optional[int] = None + self._cached_start_block: Optional[ControlFlowBlock] = None + + def add_edge(self, src: ControlFlowBlock, dst: ControlFlowBlock, data: 'dace.sdfg.InterstateEdge'): + """ Adds a new edge to the graph. Must be an InterstateEdge or a subclass thereof. + + :param u: Source node. + :param v: Destination node. + :param edge: The edge to add. + """ + if not isinstance(src, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(src))) + if not isinstance(dst, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(dst))) + if not isinstance(data, dace.sdfg.InterstateEdge): + raise TypeError('Expected InterstateEdge, got ' + str(type(data))) + if dst is self._cached_start_block: + self._cached_start_block = None + return super(ControlFlowGraph, self).add_edge(src, dst, data) + + def add_node(self, node, is_start_block=False): + if not isinstance(node, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(node))) + super().add_node(node) + node.parent_cfg = self + self._cached_start_block = None + if is_start_block is True: + self.start_block = len(self.nodes()) - 1 + self._cached_start_block = node + + def add_state(self, label=None, is_start_block=False, parent_sdfg=None) -> SDFGState: + if self._labels is None or len(self._labels) != self.number_of_nodes(): + self._labels = set(s.label for s in self.nodes()) + label = label or 'state' + existing_labels = self._labels + label = dt.find_new_name(label, existing_labels) + state = SDFGState(label, parent_sdfg) + self._labels.add(label) + self.add_node(state, is_start_block=is_start_block) + return state + + # ============================================================================= + # = Graph traversal methods =================================================== + # ============================================================================= + + #def all_nodes_recursive(self) -> Iterator[Tuple[SomeNodeT, SomeGraphT]]: + # for node in self.nodes(): + # yield node, self + # if isinstance(node, (ScopeBlock, SDFGState)): + # yield from node.all_nodes_recursive() + + def all_cfgs_recursive(self, recurse_into_sdfgs=True) -> Iterator['ControlFlowGraph']: + """ Iterate over this and all nested CFGs. """ + yield self + for block in self.nodes(): + if isinstance(block, SDFGState) and recurse_into_sdfgs: + for node in block.nodes(): + if isinstance(node, nd.NestedSDFG): + yield from node.sdfg.all_cfgs_recursive() + elif isinstance(block, ControlFlowGraph): + yield from block.all_cfgs_recursive() + + def all_sdfgs_recursive(self) -> Iterator['dace.SDFG']: + """ Iterate over this and all nested SDFGs. """ + for cfg in self.all_cfgs_recursive(recurse_into_sdfgs=True): + if isinstance(cfg, dace.SDFG): + yield cfg + + def all_states_recursive(self) -> Iterator[SDFGState]: + """ Iterate over all states in this control flow graph. """ + for block in self.nodes(): + if isinstance(block, SDFGState): + yield block + elif isinstance(block, ControlFlowGraph): + yield from block.all_states_recursive() + + @property + def start_block(self): + """ Returns the starting block of this ControlFlowGraph. """ + if self._cached_start_block is not None: + return self._cached_start_block + + source_nodes = self.source_nodes() + if len(source_nodes) == 1: + self._cached_start_block = source_nodes[0] + return source_nodes[0] + # If the starting block is ambiguous allow manual override. + if self._start_block is not None: + self._cached_start_block = self.node(self._start_block) + return self._cached_start_block + raise ValueError('Ambiguous or undefined starting block for ControlFlowGraph, ' + 'please use "is_start_block=True" when adding the ' + 'starting block with "add_state" or "add_node"') + + @start_block.setter + def start_block(self, block_id): + """ Manually sets the starting block of this ControlFlowGraph. + + :param block_id: The node ID (use `node_id(block)`) of the block to set. + """ + if block_id < 0 or block_id >= self.number_of_nodes(): + raise ValueError('Invalid state ID') + self._start_block = block_id + self._cached_start_block = self.node(block_id) + + +@make_properties +class ScopeBlock(ControlFlowGraph, ControlFlowBlock): + + # TODO: instrumentation + + def __init__(self, label: str='', parent: Optional[ControlFlowGraph]=None): + ControlFlowGraph.__init__(self) + ControlFlowBlock.__init__(self, label, parent) + + def data_nodes(self) -> List[nd.AccessNode]: + """ Returns all data_nodes (arrays) present in this state. """ + data_nodes = [] + for n in self.nodes(): + data_nodes.append(n.data_nodes()) + + return [n for n in self.nodes() if isinstance(n, nd.AccessNode)] + + def replace_dict(self, + repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): + """ Finds and replaces all occurrences of a set of symbols or arrays in this state. + + :param repl: Mapping from names to replacements. + :param symrepl: Optional symbolic version of ``repl``. + """ + for n in self.nodes(): + n.replace_dict(repl, symrepl) + + def to_json(self, parent=None): + graph_json = ControlFlowGraph.to_json(self) + block_json = ControlFlowBlock.to_json(self, parent) + graph_json.update(block_json) + return graph_json + + #def all_nodes_recursive(self): + # for node in self.nodes(): + # yield node, self + # if isinstance(node, (ScopeBlock, BlockGraphView)): + # yield from node.all_nodes_recursive() + + def __str__(self): + return ControlFlowBlock.__str__(self) + + def __repr__(self) -> str: + return f'{self.__class__.__name__} ({self.label})' + + +@make_properties +class LoopScopeBlock(ScopeBlock): + + update_statement = CodeProperty(optional=True, allow_none=True, default=None) + init_statement = CodeProperty(optional=True, allow_none=True, default=None) + scope_condition = CodeProperty(allow_none=True, default=None) + inverted = Property(dtype=bool, default=False) + + def __init__(self, + loop_var: str, + initialize_expr: str, + condition_expr: str, + update_expr: str, + label: str = '', + parent: Optional[ControlFlowGraph] = None, + inverted: bool = False): + super(LoopScopeBlock, self).__init__(label, parent) + + if initialize_expr is not None: + self.init_statement = CodeBlock('%s = %s' % (loop_var, initialize_expr)) + else: + self.init_statement = None + + if condition_expr: + self.scope_condition = CodeBlock(condition_expr) + else: + self.scope_condition = CodeBlock('True') + + if update_expr is not None: + self.update_statement = CodeBlock('%s = %s' % (loop_var, update_expr)) + else: + self.update_statement = None + + self.inverted = inverted + + def to_json(self, parent=None): + return super().to_json(parent) + + +@make_properties +class BranchScopeBlock(ScopeBlock): + + def __init__(self): + super(BranchScopeBlock, self).__init__() diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 605a13d87b..1a01bce92c 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1231,7 +1231,7 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> candidate = {StateFusion.first_state: u, StateFusion.second_state: v} sf = StateFusion() sf.setup_match(cfg, id, -1, candidate, 0, override=True) - if sf.can_be_applied(cfg, 0, cfg, permissive=permissive): + if sf.can_be_applied(cfg, 0, sd, permissive=permissive): sf.apply(cfg, sd) applied += 1 counter += 1 diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index d006fa3287..dc00b8b4d5 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -10,8 +10,7 @@ from dace.config import Config from dace.sdfg import SDFG, nodes from dace.sdfg import utils as sdutil -from dace.sdfg.state import SDFGState -from dace.sdfg.sdfg_control_flow import ControlFlowGraph, BasicBlock +from dace.sdfg.state import SDFGState, ControlFlowGraph from dace.transformation import transformation @@ -456,7 +455,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, graph: Union[ControlFlowGraph, BasicBlock], sdfg: SDFG): + def apply(self, graph: Union[ControlFlowGraph, SDFGState], sdfg: SDFG): first_state: SDFGState = self.first_state second_state: SDFGState = self.second_state diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index 86e877424c..e5e3dc925a 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -22,8 +22,7 @@ import copy from dace import dtypes, serialize from dace.dtypes import ScheduleType -from dace.sdfg import SDFG, SDFGState -from dace.sdfg.sdfg_control_flow import ControlFlowGraph, BasicBlock +from dace.sdfg import SDFG, SDFGState, ControlFlowGraph from dace.sdfg import nodes as nd, graph as gr, utils as sdutil, propagation, infer_types, state as st from dace.properties import make_properties, Property, DictProperty, SetProperty from dace.transformation import pass_pipeline as ppl @@ -109,7 +108,7 @@ def expressions(cls) -> List[gr.SubgraphView]: raise NotImplementedError def can_be_applied(self, - graph: Union[SDFG, SDFGState], + graph: Union[ControlFlowGraph, SDFGState], expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: @@ -127,7 +126,7 @@ def can_be_applied(self, """ raise NotImplementedError - def apply(self, graph: Union[ControlFlowGraph, BasicBlock], sdfg: SDFG) -> Union[Any, None]: + def apply(self, graph: Union[ControlFlowGraph, SDFGState], sdfg: SDFG) -> Union[Any, None]: """ Applies this transformation instance on the matched pattern graph. @@ -501,7 +500,7 @@ def expressions(cls) -> List[gr.SubgraphView]: pass @abc.abstractmethod - def can_be_applied(self, graph: SDFG, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + def can_be_applied(self, graph: ControlFlowGraph, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: """ Returns True if this transformation can be applied on the candidate matched subgraph. :param graph: SDFG object in which the match was found.