diff --git a/dace/codegen/instrumentation/papi.py b/dace/codegen/instrumentation/papi.py index c0d3b657a1..4885611408 100644 --- a/dace/codegen/instrumentation/papi.py +++ b/dace/codegen/instrumentation/papi.py @@ -12,7 +12,7 @@ from dace.sdfg.graph import SubgraphView from dace.memlet import Memlet from dace.sdfg import scope_contains_scope -from dace.sdfg.state import StateGraphView +from dace.sdfg.state import DataflowGraphView import sympy as sp import os @@ -392,7 +392,7 @@ def should_instrument_entry(map_entry: EntryNode) -> bool: return cond @staticmethod - def has_surrounding_perfcounters(node, dfg: StateGraphView): + def has_surrounding_perfcounters(node, dfg: DataflowGraphView): """ Returns true if there is a possibility that this node is part of a section that is profiled. """ parent = dfg.entry_node(node) @@ -605,7 +605,7 @@ def get_memlet_byte_size(sdfg: dace.SDFG, memlet: Memlet): return memlet.volume * memdata.dtype.bytes @staticmethod - def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: StateGraphView): + def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: DataflowGraphView): scope_dict = sdfg.node(state_id).scope_dict() out_costs = 0 @@ -636,7 +636,10 @@ def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: return out_costs @staticmethod - def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, dfg: StateGraphView, sdfg: dace.SDFG, state_id: int) -> str: + def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, + dfg: DataflowGraphView, + sdfg: dace.SDFG, + state_id: int) -> str: """ Get the amount of bytes processed by `tasklet`. 
The formula is sum(inedges * size) + sum(outedges * size) """ in_accum = [] @@ -693,7 +696,7 @@ def get_memory_input_size(node, sdfg, state_id) -> str: return sym2cpp(input_size) @staticmethod - def accumulate_byte_movement(outermost_node, node, dfg: StateGraphView, sdfg, state_id): + def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, state_id): itvars = dict() # initialize an empty dict diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index a465d2bbc0..fb8ae90187 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -445,7 +445,7 @@ def node_dispatch_predicate(self, sdfg, state, node): if hasattr(node, 'schedule'): # NOTE: Works on nodes and scopes if node.schedule in dtypes.GPU_SCHEDULES: return True - if isinstance(node, nodes.NestedSDFG) and CUDACodeGen._in_device_code: + if CUDACodeGen._in_device_code: return True return False @@ -1324,11 +1324,11 @@ def generate_devicelevel_state(self, sdfg, state, function_stream, callsite_stre if write_scope == 'grid': callsite_stream.write("if (blockIdx.x == 0 " - "&& threadIdx.x == 0) " - "{ // sub-graph begin", sdfg, state.node_id) + "&& threadIdx.x == 0) " + "{ // sub-graph begin", sdfg, state.node_id) elif write_scope == 'block': callsite_stream.write("if (threadIdx.x == 0) " - "{ // sub-graph begin", sdfg, state.node_id) + "{ // sub-graph begin", sdfg, state.node_id) else: callsite_stream.write("{ // subgraph begin", sdfg, state.node_id) else: @@ -2519,15 +2519,17 @@ def generate_devicelevel_scope(self, sdfg, dfg_scope, state_id, function_stream, def generate_node(self, sdfg, dfg, state_id, node, function_stream, callsite_stream): if self.node_dispatch_predicate(sdfg, dfg, node): # Dynamically obtain node generator according to class name - gen = getattr(self, '_generate_' + type(node).__name__) - gen(sdfg, dfg, state_id, node, function_stream, callsite_stream) - return + gen = getattr(self, '_generate_' + type(node).__name__, False) + if 
gen is not False: # Not every node type has a code generator here + gen(sdfg, dfg, state_id, node, function_stream, callsite_stream) + return if not CUDACodeGen._in_device_code: self._cpu_codegen.generate_node(sdfg, dfg, state_id, node, function_stream, callsite_stream) return - self._locals.clear_scope(self._code_state.indentation + 1) + if isinstance(node, nodes.ExitNode): + self._locals.clear_scope(self._code_state.indentation + 1) if CUDACodeGen._in_device_code and isinstance(node, nodes.MapExit): return # skip @@ -2591,6 +2593,78 @@ def _generate_MapExit(self, sdfg, dfg, state_id, node, function_stream, callsite self._cpu_codegen._generate_MapExit(sdfg, dfg, state_id, node, function_stream, callsite_stream) + def _get_thread_id(self) -> str: + result = 'threadIdx.x' + if self._block_dims[1] != 1: + result += f' + ({sym2cpp(self._block_dims[0])}) * threadIdx.y' + if self._block_dims[2] != 1: + result += f' + ({sym2cpp(self._block_dims[0] * self._block_dims[1])}) * threadIdx.z' + return result + + def _get_warp_id(self) -> str: + return f'(({self._get_thread_id()}) / warpSize)' + + def _get_block_id(self) -> str: + result = 'blockIdx.x' + if self._block_dims[1] != 1: + result += f' + gridDim.x * blockIdx.y' + if self._block_dims[2] != 1: + result += f' + gridDim.x * gridDim.y * blockIdx.z' + return result + + def _generate_condition_from_location(self, name: str, index_expr: str, node: nodes.Tasklet, + callsite_stream: CodeIOStream) -> str: + if name not in node.location: + return 0 + + location: Union[int, str, subsets.Range] = node.location[name] + if isinstance(location, str) and ':' in location: + location = subsets.Range.from_string(location) + elif symbolic.issymbolic(location): + location = sym2cpp(location) + + if isinstance(location, subsets.Range): + # Range of indices + if len(location) != 1: + raise ValueError(f'Only one-dimensional ranges are allowed for {name} specialization, {location} given') + begin, end, stride = location[0] + rb, re, rs = 
sym2cpp(begin), sym2cpp(end), sym2cpp(stride) + cond = '' + cond += f'(({index_expr}) >= {rb}) && (({index_expr}) <= {re})' + if stride != 1: + cond += f' && ((({index_expr}) - {rb}) % {rs} == 0)' + + callsite_stream.write(f'if ({cond}) {{') + else: + # Single-element + callsite_stream.write(f'if (({index_expr}) == {location}) {{') + + return 1 + + def _generate_Tasklet(self, sdfg: SDFG, dfg, state_id: int, node: nodes.Tasklet, function_stream: CodeIOStream, + callsite_stream: CodeIOStream): + generated_preamble_scopes = 0 + if self._in_device_code: + # If location dictionary prescribes that the code should run on a certain group of threads/blocks, + # add condition + generated_preamble_scopes += self._generate_condition_from_location('gpu_thread', self._get_thread_id(), + node, callsite_stream) + generated_preamble_scopes += self._generate_condition_from_location('gpu_warp', self._get_warp_id(), node, + callsite_stream) + generated_preamble_scopes += self._generate_condition_from_location('gpu_block', self._get_block_id(), node, + callsite_stream) + + # Call standard tasklet generation + old_codegen = self._cpu_codegen.calling_codegen + self._cpu_codegen.calling_codegen = self + self._cpu_codegen._generate_Tasklet(sdfg, dfg, state_id, node, function_stream, callsite_stream) + self._cpu_codegen.calling_codegen = old_codegen + + if generated_preamble_scopes > 0: + # Generate appropriate postamble + for i in range(generated_preamble_scopes): + callsite_stream.write('}', sdfg, state_id, node) + def make_ptr_vector_cast(self, *args, **kwargs): return cpp.make_ptr_vector_cast(*args, **kwargs) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 917f748cb8..084d46f47d 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -275,7 +275,7 @@ def remove_name_collisions(sdfg: SDFG): # Rename duplicate states for state in nsdfg.nodes(): if 
state.label in state_names_seen: - state.set_label(data.find_new_name(state.label, state_names_seen)) + state.label = data.find_new_name(state.label, state_names_seen) state_names_seen.add(state.label) replacements: Dict[str, str] = {} diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 32369a19a3..a28e9fce38 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -262,9 +262,8 @@ def label(self): def __label__(self, sdfg, state): return self.data - def desc(self, sdfg): - from dace.sdfg import SDFGState, ScopeSubgraphView - if isinstance(sdfg, (SDFGState, ScopeSubgraphView)): + def desc(self, sdfg: Union['dace.sdfg.SDFG', 'dace.sdfg.SDFGState', 'dace.sdfg.ScopeSubgraphView']): + if isinstance(sdfg, (dace.sdfg.SDFGState, dace.sdfg.ScopeSubgraphView)): sdfg = sdfg.parent return sdfg.arrays[self.data] diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 4b36fad4fe..a2c7b9a43c 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -175,17 +175,18 @@ def replace_datadesc_names(sdfg, repl: Dict[str, str]): sdfg.constants_prop[repl[aname]] = sdfg.constants_prop[aname] del sdfg.constants_prop[aname] - # Replace in interstate edges - for e in sdfg.edges(): - e.data.replace_dict(repl, replace_keys=False) - - for state in sdfg.nodes(): - # Replace in access nodes - for node in state.data_nodes(): - if node.data in repl: - node.data = repl[node.data] - - # Replace in memlets - for edge in state.edges(): - if edge.data.data in repl: - edge.data.data = repl[edge.data.data] + for cf in sdfg.all_control_flow_regions(): + # Replace in interstate edges + for e in cf.edges(): + e.data.replace_dict(repl, replace_keys=False) + + for block in cf.nodes(): + if isinstance(block, dace.SDFGState): + # Replace in access nodes + for node in block.data_nodes(): + if node.data in repl: + node.data = repl[node.data] + # Replace in memlets + for edge in block.edges(): + if edge.data.data in repl: + edge.data.data = repl[edge.data.data] diff --git a/dace/sdfg/sdfg.py 
b/dace/sdfg/sdfg.py index a85e773337..fdf8835c7e 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -30,7 +30,7 @@ from dace.frontend.python import astutils, wrappers from dace.sdfg import nodes as nd from dace.sdfg.graph import OrderedDiGraph, Edge, SubgraphView -from dace.sdfg.state import SDFGState +from dace.sdfg.state import SDFGState, ControlFlowRegion from dace.sdfg.propagation import propagate_memlets_sdfg from dace.distr_types import ProcessGrid, SubArray, RedistrArray from dace.dtypes import validate_name @@ -402,7 +402,7 @@ def label(self): @make_properties -class SDFG(OrderedDiGraph[SDFGState, InterstateEdge]): +class SDFG(ControlFlowRegion): """ The main intermediate representation of code in DaCe. A Stateful DataFlow multiGraph (SDFG) is a directed graph of directed @@ -499,8 +499,6 @@ def __init__(self, self._parent_sdfg = None self._parent_nsdfg_node = None self._sdfg_list = [self] - self._start_state: Optional[int] = None - self._cached_start_state: Optional[SDFGState] = None self._arrays = NestedDict() # type: Dict[str, dt.Array] self._labels: Set[str] = set() self.global_code = {'frame': CodeBlock("", dtypes.Language.CPP)} @@ -531,14 +529,14 @@ def __deepcopy__(self, memo): memo[id(self)] = result for k, v in self.__dict__.items(): # Skip derivative attributes - if k in ('_cached_start_state', '_edges', '_nodes', '_parent', '_parent_sdfg', '_parent_nsdfg_node', + if k in ('_cached_start_block', '_edges', '_nodes', '_parent', '_parent_sdfg', '_parent_nsdfg_node', '_sdfg_list', '_transformation_hist'): continue setattr(result, k, copy.deepcopy(v, memo)) # Copy edges and nodes result._edges = copy.deepcopy(self._edges, memo) result._nodes = copy.deepcopy(self._nodes, memo) - result._cached_start_state = copy.deepcopy(self._cached_start_state, memo) + result._cached_start_block = copy.deepcopy(self._cached_start_block, memo) # Copy parent attributes for k in ('_parent', '_parent_sdfg', '_parent_nsdfg_node'): if id(getattr(self, k)) in memo: @@ 
-583,7 +581,7 @@ def to_json(self, hash=False): tmp['attributes']['constants_prop'] = json.loads(dace.serialize.dumps(tmp['attributes']['constants_prop'])) tmp['sdfg_list_id'] = int(self.sdfg_id) - tmp['start_state'] = self._start_state + tmp['start_state'] = self._start_block tmp['attributes']['name'] = self.name if hash: @@ -627,7 +625,7 @@ def from_json(cls, json_obj, context_info=None): ret.add_edge(nodelist[int(e.src)], nodelist[int(e.dst)], e.data) if 'start_state' in json_obj: - ret._start_state = json_obj['start_state'] + ret._start_block = json_obj['start_state'] return ret @@ -753,14 +751,7 @@ def replace_dict(self, for array in self.arrays.values(): replace_properties_dict(array, repldict, symrepl) - if replace_in_graph: - # Replace in inter-state edges - for edge in self.edges(): - edge.data.replace_dict(repldict, replace_keys=replace_keys) - - # Replace in states - for state in self.nodes(): - state.replace_dict(repldict, symrepl) + super().replace_dict(repldict, symrepl, replace_in_graph, replace_keys) def add_symbol(self, name, stype): """ Adds a symbol to the SDFG. @@ -787,34 +778,11 @@ def remove_symbol(self, name): @property def start_state(self): - """ Returns the starting state of this SDFG. 
""" - if self._cached_start_state is not None: - return self._cached_start_state - - source_nodes = self.source_nodes() - if len(source_nodes) == 1: - self._cached_start_state = source_nodes[0] - return source_nodes[0] - # If starting state is ambiguous (i.e., loop to initial state or more - # than one possible start state), allow manually overriding start state - if self._start_state is not None: - self._cached_start_state = self.node(self._start_state) - return self._cached_start_state - raise ValueError('Ambiguous or undefined starting state for SDFG, ' - 'please use "is_start_state=True" when adding the ' - 'starting state with "add_state"') + return self.start_block @start_state.setter def start_state(self, state_id): - """ Manually sets the starting state of this SDFG. - - :param state_id: The node ID (use `node_id(state)`) of the - state to set. - """ - if state_id < 0 or state_id >= self.number_of_nodes(): - raise ValueError("Invalid state ID") - self._start_state = state_id - self._cached_start_state = self.node(state_id) + self.start_block = state_id def set_global_code(self, cpp_code: str, location: str = 'frame'): """ @@ -1127,7 +1095,7 @@ def remove_data(self, name, validate=True): # Verify that there are no access nodes that use this data if validate: - for state in self.nodes(): + for state in self.states(): for node in state.nodes(): if isinstance(node, nd.AccessNode) and node.data == name: raise ValueError(f"Cannot remove data descriptor " @@ -1243,75 +1211,14 @@ def parent_sdfg(self, value): def parent_nsdfg_node(self, value): self._parent_nsdfg_node = value - def add_node(self, node, is_start_state=False): - """ Adds a new node to the SDFG. Must be an SDFGState or a subclass - thereof. - - :param node: The node to add. - :param is_start_state: If True, sets this node as the starting - state. 
- """ - if not isinstance(node, SDFGState): - raise TypeError("Expected SDFGState, got " + str(type(node))) - super(SDFG, self).add_node(node) - self._cached_start_state = None - if is_start_state is True: - self.start_state = len(self.nodes()) - 1 - self._cached_start_state = node - def remove_node(self, node: SDFGState): - if node is self._cached_start_state: - self._cached_start_state = None + if node is self._cached_start_block: + self._cached_start_block = None return super().remove_node(node) - def add_edge(self, u, v, edge): - """ Adds a new edge to the SDFG. Must be an InterstateEdge or a - subclass thereof. - - :param u: Source node. - :param v: Destination node. - :param edge: The edge to add. - """ - if not isinstance(u, SDFGState): - raise TypeError("Expected SDFGState, got: {}".format(type(u).__name__)) - if not isinstance(v, SDFGState): - raise TypeError("Expected SDFGState, got: {}".format(type(v).__name__)) - if not isinstance(edge, InterstateEdge): - raise TypeError("Expected InterstateEdge, got: {}".format(type(edge).__name__)) - if v is self._cached_start_state: - self._cached_start_state = None - return super(SDFG, self).add_edge(u, v, edge) - def states(self): - """ Alias that returns the nodes (states) in this SDFG. """ - return self.nodes() - - def all_nodes_recursive(self) -> Iterator[Tuple[nd.Node, Union['SDFG', 'SDFGState']]]: - """ Iterate over all nodes in this SDFG, including states, nodes in - states, and recursive states and nodes within nested SDFGs, - returning tuples on the form (node, parent), where the parent is - either the SDFG (for states) or a DFG (nodes). """ - for node in self.nodes(): - yield node, self - yield from node.all_nodes_recursive() - - def all_sdfgs_recursive(self): - """ Iterate over this and all nested SDFGs. 
""" - yield self - for state in self.nodes(): - for node in state.nodes(): - if isinstance(node, nd.NestedSDFG): - yield from node.sdfg.all_sdfgs_recursive() - - def all_edges_recursive(self): - """ Iterate over all edges in this SDFG, including state edges, - inter-state edges, and recursively edges within nested SDFGs, - returning tuples on the form (edge, parent), where the parent is - either the SDFG (for states) or a DFG (nodes). """ - for e in self.edges(): - yield e, self - for node in self.nodes(): - yield from node.all_edges_recursive() + """ Returns the states in this SDFG, recursing into state scope blocks. """ + return list(self.all_states()) def arrays_recursive(self): """ Iterate over all arrays in this SDFG, including arrays within @@ -1323,19 +1230,15 @@ def arrays_recursive(self): if isinstance(node, nd.NestedSDFG): yield from node.sdfg.arrays_recursive() - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: - """ - Returns a set of symbol names that are used by the SDFG, but not - defined within it. This property is used to determine the symbolic - parameters of the SDFG. - - :param all_symbols: If False, only returns the set of symbols that will be used - in the generated code and are needed as arguments. - :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping - will be removed from the set of defined symbols. 
- """ - defined_syms = set() - free_syms = set() + def _used_symbols_internal(self, + all_symbols: bool, + defined_syms: Optional[Set]=None, + free_syms: Optional[Set]=None, + used_before_assignment: Optional[Set]=None, + keep_defined_in_mapping: bool=False) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms = set() if defined_syms is None else defined_syms + free_syms = set() if free_syms is None else free_syms + used_before_assignment = set() if used_before_assignment is None else used_before_assignment # Exclude data descriptor names and constants for name in self.arrays.keys(): @@ -1349,54 +1252,10 @@ def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) - for code in self.exit_code.values(): free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) - # Add free state symbols - used_before_assignment = set() - - try: - ordered_states = self.topological_sort(self.start_state) - except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) - ordered_states = self.nodes() - - for state in ordered_states: - state_fsyms = state.used_symbols(all_symbols) - free_syms |= state_fsyms - - # Add free inter-state symbols - for e in self.out_edges(state): - # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols by - # subracting the (true) free symbols from the edge's assignment keys. This way we can correctly - # compute the symbols that are used before being assigned. 
- efsyms = e.data.used_symbols(all_symbols) - defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_fsyms) - used_before_assignment.update(efsyms - defined_syms) - free_syms |= efsyms - - # Remove symbols that were used before they were assigned - defined_syms -= used_before_assignment - - # Remove from defined symbols those that are in the symbol mapping - if self.parent_nsdfg_node is not None and keep_defined_in_mapping: - defined_syms -= set(self.parent_nsdfg_node.symbol_mapping.keys()) - - # Add the set of SDFG symbol parameters - # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets - if all_symbols: - free_syms |= set(self.symbols.keys()) - - # Subtract symbols defined in inter-state edges and constants - return free_syms - defined_syms - - @property - def free_symbols(self) -> Set[str]: - """ - Returns a set of symbol names that are used by the SDFG, but not - defined within it. This property is used to determine the symbolic - parameters of the SDFG and verify that ``SDFG.symbols`` is complete. - - :note: Assumes that the graph is valid (i.e., without undefined or - overlapping symbols). 
- """ - return self.used_symbols(all_symbols=True) + return super()._used_symbols_internal( + all_symbols=all_symbols, keep_defined_in_mapping=keep_defined_in_mapping, + defined_syms=defined_syms, free_syms=free_syms, used_before_assignment=used_before_assignment + ) def get_all_toplevel_symbols(self) -> Set[str]: """ @@ -1608,16 +1467,16 @@ def shared_transients(self, check_toplevel=True) -> List[str]: shared = [] # If a transient is present in an inter-state edge, it is shared - for interstate_edge in self.edges(): + for interstate_edge in self.all_interstate_edges(): for sym in interstate_edge.data.free_symbols: if sym in self.arrays and self.arrays[sym].transient: seen[sym] = interstate_edge shared.append(sym) # If transient is accessed in more than one state, it is shared - for state in self.nodes(): - for node in state.nodes(): - if isinstance(node, nd.AccessNode) and node.desc(self).transient: + for state in self.states(): + for node in state.data_nodes(): + if node.desc(self).transient: if (check_toplevel and node.desc(self).toplevel) or (node.data in seen and seen[node.data] != state): shared.append(node.data) @@ -1706,62 +1565,6 @@ def from_file(filename: str) -> 'SDFG': # Dynamic SDFG creation API ############################## - def add_state(self, label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state to this graph and returns it. - - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. 
- """ - if self._labels is None or len(self._labels) != self.number_of_nodes(): - self._labels = set(s.label for s in self.nodes()) - label = label or 'state' - existing_labels = self._labels - label = dt.find_new_name(label, existing_labels) - state = SDFGState(label, self) - self._labels.add(label) - - self.add_node(state, is_start_state=is_start_state) - return state - - def add_state_before(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state before an existing state, reconnecting - predecessors to it instead. - - :param state: The state to prepend the new state before. - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - new_state = self.add_state(label, is_start_state) - # Reconnect - for e in self.in_edges(state): - self.remove_edge(e) - self.add_edge(e.src, new_state, e.data) - # Add unconditional connection between the new state and the current - self.add_edge(new_state, state, InterstateEdge()) - return new_state - - def add_state_after(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state after an existing state, reconnecting - it to the successors instead. - - :param state: The state to append the new state after. - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - new_state = self.add_state(label, is_start_state) - # Reconnect - for e in self.out_edges(state): - self.remove_edge(e) - self.add_edge(new_state, e.dst, e.data) - # Add unconditional connection between the current and the new state - self.add_edge(state, new_state, InterstateEdge()) - return new_state def _find_new_name(self, name: str): """ Tries to find a new name by adding an underscore and a number. 
""" @@ -2482,7 +2285,7 @@ def __call__(self, *args, **kwargs): def fill_scope_connectors(self): """ Fills missing scope connectors (i.e., "IN_#"/"OUT_#" on entry/exit nodes) according to data on the memlets. """ - for state in self.nodes(): + for state in self.states(): state.fill_scope_connectors() def predecessor_state_transitions(self, state): diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 1ff8fe4cf1..097365fbc3 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2,6 +2,7 @@ """ Contains classes of a single SDFG state and dataflow subgraphs. """ import ast +import abc import collections import copy import inspect @@ -19,7 +20,7 @@ from dace.properties import (CodeBlock, DictProperty, EnumProperty, Property, SubsetProperty, SymbolicProperty, CodeProperty, make_properties) from dace.sdfg import nodes as nd -from dace.sdfg.graph import MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView +from dace.sdfg.graph import MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge from dace.sdfg.propagation import propagate_memlet from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset @@ -28,6 +29,11 @@ import dace.sdfg.scope +NodeT = Union[nd.Node, 'ControlFlowBlock'] +EdgeT = Union[MultiConnectorEdge[mm.Memlet], Edge['dace.sdfg.InterstateEdge']] +GraphT = Union['ControlFlowRegion', 'SDFGState'] + + def _getdebuginfo(old_dinfo=None) -> dtypes.DebugInfo: """ Returns a DebugInfo object for the position that called this function. @@ -66,13 +72,248 @@ def _make_iterators(ndrange): return params, map_range -class StateGraphView(object): +class BlockGraphView(object): """ - Read-only view interface of an SDFG state, containing methods for memlet - tracking, traversal, subgraph creation, queries, and replacements. 
- ``SDFGState`` and ``StateSubgraphView`` inherit from this class to share + Read-only view interface of an SDFG control flow block, containing methods for memlet tracking, traversal, subgraph + creation, queries, and replacements. ``ControlFlowBlock`` and ``StateSubgraphView`` inherit from this class to share methods. """ + + + ################################################################### + # Typing overrides + + @overload + def nodes(self) -> List[NodeT]: + ... + + @overload + def edges(self) -> List[EdgeT]: + ... + + @overload + def in_degree(self, node: NodeT) -> int: + ... + + @overload + def out_degree(self, node: NodeT) -> int: + ... + + ################################################################### + # Traversal methods + + @abc.abstractmethod + def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]: + """ + Iterate over all nodes in this graph or subgraph. + This includes control flow blocks, nodes in those blocks, and recursive control flow blocks and nodes within + nested SDFGs. It returns tuples of the form (node, parent), where the node is either a dataflow node, in which + case the parent is an SDFG state, or a control flow block, in which case the parent is a control flow graph + (i.e., an SDFG or a scope block). + """ + raise NotImplementedError() + + @abc.abstractmethod + def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: + """ + Iterate over all edges in this graph or subgraph. + This includes dataflow edges, inter-state edges, and recursive edges within nested SDFGs. It returns tuples of + the form (edge, parent), where the edge is either a dataflow edge, in which case the parent is an SDFG state, or + an inter-state edge, in which case the parent is a control flow graph (i.e., an SDFG or a scope block). + """ + raise NotImplementedError() + + @abc.abstractmethod + def data_nodes(self) -> List[nd.AccessNode]: + """ + Returns all data nodes (i.e., AccessNodes, arrays) present in this graph or subgraph. 
+ Note: This does not recurse into nested SDFGs. + """ + raise NotImplementedError() + + @abc.abstractmethod + def entry_node(self, node: nd.Node) -> nd.EntryNode: + """ Returns the entry node that wraps the current node, or None if it is top-level in a state. """ + raise NotImplementedError() + + @abc.abstractmethod + def exit_node(self, entry_node: nd.EntryNode) -> nd.ExitNode: + """ Returns the exit node leaving the context opened by the given entry node. """ + raise NotImplementedError() + + ################################################################### + # Memlet-tracking methods + + @abc.abstractmethod + def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnectorEdge[mm.Memlet]]: + """ + Given one edge, returns a list of edges representing a path between its source and sink nodes. + Used for memlet tracking. + + :note: Behavior is undefined when there is more than one path involving this edge. + :param edge: An edge within a state (memlet). + :return: A list of edges from a source node to a destination node. + """ + raise NotImplementedError() + + @abc.abstractmethod + def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: + """ + Given one edge, returns a tree of edges between its node source(s) and sink(s). + Used for memlet tracking. + + :param edge: An edge within a state (memlet). + :return: A tree of edges whose root is the source/sink node (depending on direction) and associated children + edges. + """ + raise NotImplementedError() + + @abc.abstractmethod + def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges entering the given connector of the given node. + + :param node: Destination node of edges. + :param connector: Destination connector of edges. 
+ """ + raise NotImplementedError() + + @abc.abstractmethod + def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges exiting the given connector of the given node. + + :param node: Source node of edges. + :param connector: Source connector of edges. + """ + raise NotImplementedError() + + @abc.abstractmethod + def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges entering or exiting the given connector of the given node. + + :param node: Source/destination node of edges. + :param connector: Source/destination connector of edges. + """ + raise NotImplementedError() + + ################################################################### + # Query, subgraph, and replacement methods + + @abc.abstractmethod + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: + """ + Returns a set of symbol names that are used in the graph. + + :param all_symbols: If False, only returns symbols that are needed as arguments (only used in generated code). + :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping + will be removed from the set of defined symbols. + """ + raise NotImplementedError() + + @property + def free_symbols(self) -> Set[str]: + """ + Returns a set of symbol names that are used, but not defined, in this graph view. + In the case of an SDFG, this property is used to determine the symbolic parameters of the SDFG and + verify that ``SDFG.symbols`` is complete. + + :note: Assumes that the graph is valid (i.e., without undefined or overlapping symbols). + """ + return self.used_symbols(all_symbols=True) + + @abc.abstractmethod + def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: + """ + Determines what data is read and written in this graph. 
+ Does not include reads to subsets of containers that have previously been written within the same state. + + :return: A two-tuple of sets of things denoting ({data read}, {data written}). + """ + raise NotImplementedError() + + @abc.abstractmethod + def unordered_arglist(self, + defined_syms=None, + shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]: + raise NotImplementedError() + + def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Data]: + """ + Returns an ordered dictionary of arguments (names and types) required to invoke this subgraph. + + The arguments differ from SDFG.arglist, but follow the same order, + namely: sorted data arguments, followed by sorted scalar arguments. + + Data arguments contain: + * All used non-transient data containers in the subgraph + * All used transient data containers that were allocated outside. + This includes data from memlets, transients shared across multiple states, and transients that could not + be allocated within the subgraph (due to their ``AllocationLifetime`` or according to the + ``dtypes.can_allocate`` function). + + Scalar arguments contain: + * Free symbols in this state/subgraph. + * All transient and non-transient scalar data containers used in this subgraph. + + This structure will create a sorted list of pointers followed by a sorted list of PoDs and structs. + + :return: An ordered dictionary of (name, data descriptor type) of all the arguments, sorted as defined here. + """ + data_args, scalar_args = self.unordered_arglist(defined_syms, shared_transients) + + # Fill up ordered dictionary + result = collections.OrderedDict() + for k, v in itertools.chain(sorted(data_args.items()), sorted(scalar_args.items())): + result[k] = v + + return result + + def signature_arglist(self, with_types=True, for_call=False): + """ Returns a list of arguments necessary to call this state or subgraph, formatted as a list of C definitions. + + :param with_types: If True, includes argument types in the result. 
+ :param for_call: If True, returns arguments that can be used when calling the SDFG. + :return: A list of strings. For example: `['float *A', 'int b']`. + """ + return [v.as_arg(name=k, with_types=with_types, for_call=for_call) for k, v in self.arglist().items()] + + @abc.abstractmethod + def top_level_transients(self) -> Set[str]: + """Iterate over top-level transients of this graph.""" + raise NotImplementedError() + + @abc.abstractmethod + def all_transients(self) -> List[str]: + """Iterate over all transients in this graph.""" + raise NotImplementedError() + + @abc.abstractmethod + def replace(self, name: str, new_name: str): + """ + Finds and replaces all occurrences of a symbol or array in this graph. + + :param name: Name to find. + :param new_name: Name to replace. + """ + raise NotImplementedError() + + @abc.abstractmethod + def replace_dict(self, + repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): + """ + Finds and replaces all occurrences of a set of symbols or arrays in this graph. + + :param repl: Mapping from names to replacements. + :param symrepl: Optional symbolic version of ``repl``. 
+ """ + raise NotImplementedError() + + +@make_properties +class DataflowGraphView(BlockGraphView, abc.ABC): def __init__(self, *args, **kwargs): self._clear_scopedict_cache() @@ -91,29 +332,29 @@ def edges(self) -> List[MultiConnectorEdge[mm.Memlet]]: ################################################################### # Traversal methods - def all_nodes_recursive(self): + def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]: for node in self.nodes(): yield node, self if isinstance(node, nd.NestedSDFG): yield from node.sdfg.all_nodes_recursive() - def all_edges_recursive(self): + def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: for e in self.edges(): yield e, self for node in self.nodes(): if isinstance(node, nd.NestedSDFG): yield from node.sdfg.all_edges_recursive() - def data_nodes(self): + def data_nodes(self) -> List[nd.AccessNode]: """ Returns all data_nodes (arrays) present in this state. """ return [n for n in self.nodes() if isinstance(n, nd.AccessNode)] - def entry_node(self, node: nd.Node) -> nd.EntryNode: + def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]: """ Returns the entry node that wraps the current node, or None if it is top-level in a state. """ return self.scope_dict()[node] - def exit_node(self, entry_node: nd.EntryNode) -> nd.ExitNode: + def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]: """ Returns the exit node leaving the context opened by the given entry node. 
""" node_to_children = self.scope_children() @@ -152,7 +393,7 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto result.insert(0, next_edge) curedge = next_edge - # Prepend outgoing edges until reaching the sink node + # Append outgoing edges until reaching the sink node curedge = edge while not isinstance(curedge.dst, (nd.CodeNode, nd.AccessNode)): # Trace through scope entry using IN_# -> OUT_# @@ -168,13 +409,6 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto return result def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: - """ Given one edge, returns a tree of edges between its node source(s) - and sink(s). Used for memlet tracking. - - :param edge: An edge within this state. - :return: A tree of edges whose root is the source/sink node - (depending on direction) and associated children edges. - """ propagate_forward = False propagate_backward = False if ((isinstance(edge.src, nd.EntryNode) and edge.src_conn is not None) or @@ -246,30 +480,12 @@ def traverse(node): return traverse(tree_root) def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges entering the given connector of the - given node. - - :param node: Destination node of edges. - :param connector: Destination connector of edges. - """ return (e for e in self.in_edges(node) if e.dst_conn == connector) def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges exiting the given connector of the - given node. - - :param node: Source node of edges. - :param connector: Source connector of edges. 
- """ return (e for e in self.out_edges(node) if e.src_conn == connector) def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges entering or exiting the given - connector of the given node. - - :param node: Source/destination node of edges. - :param connector: Source/destination connector of edges. - """ return itertools.chain(self.in_edges_by_connector(node, connector), self.out_edges_by_connector(node, connector)) @@ -297,8 +513,6 @@ def scope_tree(self) -> 'dace.sdfg.scope.ScopeTree': result = {} - sdfg_symbols = self.parent.symbols.keys() - # Get scopes for node, scopenodes in sdc.items(): if node is None: @@ -325,15 +539,7 @@ def scope_leaves(self) -> List['dace.sdfg.scope.ScopeTree']: self._scope_leaves_cached = [scope for scope in st.values() if len(scope.children) == 0] return copy.copy(self._scope_leaves_cached) - def scope_dict(self, return_ids: bool = False, validate: bool = True) -> Dict[nd.Node, Optional[nd.Node]]: - """ Returns a dictionary that maps each SDFG node to its parent entry - node, or to None if the node is not in any scope. - - :param return_ids: Return node ID numbers instead of node objects. - :param validate: Ensure that the graph is not malformed when - computing dictionary. - :return: The mapping from a node to its parent scope entry node. - """ + def scope_dict(self, return_ids: bool = False, validate: bool = True) -> Dict[nd.Node, Union['SDFGState', nd.Node]]: from dace.sdfg.scope import _scope_dict_inner, _scope_dict_to_ids result = None result = copy.copy(self._scope_dict_toparent_cached) @@ -367,16 +573,7 @@ def scope_dict(self, return_ids: bool = False, validate: bool = True) -> Dict[nd def scope_children(self, return_ids: bool = False, - validate: bool = True) -> Dict[Optional[nd.EntryNode], List[nd.Node]]: - """ Returns a dictionary that maps each SDFG entry node to its children, - not including the children of children entry nodes. 
The key `None` - contains a list of top-level nodes (i.e., not in any scope). - - :param return_ids: Return node ID numbers instead of node objects. - :param validate: Ensure that the graph is not malformed when - computing dictionary. - :return: The mapping from a node to a list of children nodes. - """ + validate: bool = True) -> Dict[Union[nd.Node, 'SDFGState'], List[nd.Node]]: from dace.sdfg.scope import _scope_dict_inner, _scope_dict_to_ids result = None if self._scope_dict_tochildren_cached is not None: @@ -419,13 +616,7 @@ def is_leaf_memlet(self, e): return False return True - def used_symbols(self, all_symbols: bool) -> Set[str]: - """ - Returns a set of symbol names that are used in the state. - - :param all_symbols: If False, only returns the set of symbols that will be used - in the generated code and are needed as arguments. - """ + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: state = self.graph if isinstance(self, SubgraphView) else self sdfg = state.parent new_symbols = set() @@ -579,33 +770,9 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: read_set, write_set = self._read_and_write_sets() return set(read_set.keys()), set(write_set.keys()) - def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Data]: - """ - Returns an ordered dictionary of arguments (names and types) required - to invoke this SDFG state or subgraph thereof. - - The arguments differ from SDFG.arglist, but follow the same order, - namely: , . - - Data arguments contain: - * All used non-transient data containers in the subgraph - * All used transient data containers that were allocated outside. - This includes data from memlets, transients shared across multiple - states, and transients that could not be allocated within the - subgraph (due to their ``AllocationLifetime`` or according to the - ``dtypes.can_allocate`` function). 
- - Scalar arguments contain: - * Free symbols in this state/subgraph. - * All transient and non-transient scalar data containers used in - this subgraph. - - This structure will create a sorted list of pointers followed by a - sorted list of PoDs and structs. - - :return: An ordered dictionary of (name, data descriptor type) of all - the arguments, sorted as defined here. - """ + def unordered_arglist(self, + defined_syms=None, + shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]: sdfg: 'dace.sdfg.SDFG' = self.parent shared_transients = shared_transients or sdfg.shared_transients() sdict = self.scope_dict() @@ -699,12 +866,7 @@ def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Dat if not str(k).startswith('__dace') and str(k) not in sdfg.constants }) - # Fill up ordered dictionary - result = collections.OrderedDict() - for k, v in itertools.chain(sorted(data_args.items()), sorted(scalar_args.items())): - result[k] = v - - return result + return data_args, scalar_args def signature_arglist(self, with_types=True, for_call=False): """ Returns a list of arguments necessary to call this state or @@ -749,22 +911,212 @@ def replace(self, name: str, new_name: str): def replace_dict(self, repl: Dict[str, str], symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): - """ Finds and replaces all occurrences of a set of symbols or arrays in this state. - - :param repl: Mapping from names to replacements. - :param symrepl: Optional symbolic version of ``repl``. - """ from dace.sdfg.replace import replace_dict replace_dict(self, repl, symrepl) @make_properties -class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], StateGraphView): +class ControlGraphView(BlockGraphView, abc.ABC): + + ################################################################### + # Typing overrides + + @overload + def nodes(self) -> List['ControlFlowBlock']: + ... 
+ + @overload + def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: + ... + + ################################################################### + # Traversal methods + + def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]: + for node in self.nodes(): + yield node, self + yield from node.all_nodes_recursive() + + def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: + for e in self.edges(): + yield e, self + for node in self.nodes(): + yield from node.all_edges_recursive() + + def data_nodes(self) -> List[nd.AccessNode]: + data_nodes = [] + for node in self.nodes(): + data_nodes.extend(node.data_nodes()) + return data_nodes + + def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]: + for block in self.nodes(): + if node in block.nodes(): + return block.entry_node(node) + return None + + def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]: + for block in self.nodes(): + if entry_node in block.nodes(): + return block.exit_node(entry_node) + return None + + ################################################################### + # Memlet-tracking methods + + def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if edge in block.edges(): + return block.memlet_path(edge) + return [] + + def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: + for block in self.nodes(): + if edge in block.edges(): + return block.memlet_tree(edge) + return mm.MemletTree(edge) + + def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if node in block.nodes(): + return block.in_edges_by_connector(node, connector) + return [] + + def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if node in block.nodes(): + return block.out_edges_by_connector(node, connector) + return [] + + def 
edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if node in block.nodes(): + return block.edges_by_connector(node, connector) + + ################################################################### + # Query, subgraph, and replacement methods + + @abc.abstractmethod + def _used_symbols_internal(self, + all_symbols: bool, + defined_syms: Optional[Set] = None, + free_syms: Optional[Set] = None, + used_before_assignment: Optional[Set] = None, + keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + raise NotImplementedError() + + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: + return self._used_symbols_internal(all_symbols, keep_defined_in_mapping=keep_defined_in_mapping)[0] + + def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: + read_set = set() + write_set = set() + for block in self.nodes(): + for edge in self.in_edges(block): + read_set |= edge.data.free_symbols & self.sdfg.arrays.keys() + rs, ws = block.read_and_write_sets() + read_set.update(rs) + write_set.update(ws) + return read_set, write_set + + def unordered_arglist(self, + defined_syms=None, + shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]: + data_args = {} + scalar_args = {} + for block in self.nodes(): + n_data_args, n_scalar_args = block.unordered_arglist(defined_syms, shared_transients) + data_args.update(n_data_args) + scalar_args.update(n_scalar_args) + return data_args, scalar_args + + def top_level_transients(self) -> Set[str]: + res = set() + for block in self.nodes(): + res.update(block.top_level_transients()) + return res + + def all_transients(self) -> List[str]: + res = [] + for block in self.nodes(): + res.extend(block.all_transients()) + return dtypes.deduplicate(res) + + def replace(self, name: str, new_name: str): + for n in self.nodes(): + n.replace(name, new_name) + + def replace_dict(self, 
+ repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, + replace_in_graph: bool = True, replace_keys: bool = False): + symrepl = symrepl or { + symbolic.symbol(k): symbolic.pystr_to_symbolic(v) if isinstance(k, str) else v + for k, v in repl.items() + } + + if replace_in_graph: + # Replace in inter-state edges + for edge in self.edges(): + edge.data.replace_dict(repl, replace_keys=replace_keys) + + # Replace in states + for state in self.nodes(): + state.replace_dict(repl, symrepl) + +@make_properties +class ControlFlowBlock(BlockGraphView, abc.ABC): + + is_collapsed = Property(dtype=bool, desc='Show this block as collapsed', default=False) + + _label: str + + def __init__(self, label: str=''): + super(ControlFlowBlock, self).__init__() + self._label = label + self._default_lineinfo = None + self.is_collapsed = False + + def set_default_lineinfo(self, lineinfo: dace.dtypes.DebugInfo): + """ + Sets the default source line information to be lineinfo, or None to + revert to default mode. + """ + self._default_lineinfo = lineinfo + + def to_json(self, parent=None): + tmp = { + 'type': self.__class__.__name__, + 'collapsed': self.is_collapsed, + 'label': self._label, + 'id': parent.node_id(self) if parent is not None else None, + } + return tmp + + def __str__(self): + return self._label + + def __repr__(self) -> str: + return f'ControlFlowBlock ({self.label})' + + @property + def label(self) -> str: + return self._label + + @label.setter + def label(self, label: str): + self._label = label + + @property + def name(self) -> str: + return self._label + + +@make_properties +class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], ControlFlowBlock, DataflowGraphView): """ An acyclic dataflow multigraph in an SDFG, corresponding to a single state in the SDFG state machine. 
""" - is_collapsed = Property(dtype=bool, desc="Show this node/scope/state as collapsed", default=False) - nosync = Property(dtype=bool, default=False, desc="Do not synchronize at the end of the state") instrument = EnumProperty(dtype=dtypes.InstrumentationType, @@ -803,13 +1155,14 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): :param debuginfo: Source code locator for debugging. """ from dace.sdfg.sdfg import SDFG # Avoid import loop + OrderedMultiDiConnectorGraph.__init__(self) + ControlFlowBlock.__init__(self, label) super(SDFGState, self).__init__() self._label = label self._parent: SDFG = sdfg self._graph = self # Allowing MemletTrackingView mixin to work self._clear_scopedict_cache() self._debuginfo = debuginfo - self.is_collapsed = False self.nosync = False self.location = location if location is not None else {} self._default_lineinfo = None @@ -839,33 +1192,12 @@ def parent(self): def parent(self, value): self._parent = value - def __str__(self): - return self._label - - @property - def label(self): - return self._label - - @property - def name(self): - return self._label - - def set_label(self, label): - self._label = label - def is_empty(self): return self.number_of_nodes() == 0 def validate(self) -> None: validate_state(self) - def set_default_lineinfo(self, lineinfo: dtypes.DebugInfo): - """ - Sets the default source line information to be lineinfo, or None to - revert to default mode. - """ - self._default_lineinfo = lineinfo - def nodes(self) -> List[nd.Node]: # Added for type hints return super().nodes() @@ -1981,8 +2313,244 @@ def fill_scope_connectors(self): node.add_in_connector(edge.dst_conn) -class StateSubgraphView(SubgraphView, StateGraphView): +class StateSubgraphView(SubgraphView, DataflowGraphView): """ A read-only subgraph view of an SDFG state. 
""" def __init__(self, graph, subgraph_nodes): super().__init__(graph, subgraph_nodes) + + +@make_properties +class ControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, + ControlFlowBlock): + + def __init__(self, + label: str=''): + OrderedDiGraph.__init__(self) + ControlGraphView.__init__(self) + ControlFlowBlock.__init__(self, label) + + self._labels: Set[str] = set() + self._start_block: Optional[int] = None + self._cached_start_block: Optional[ControlFlowBlock] = None + + def add_edge(self, src: ControlFlowBlock, dst: ControlFlowBlock, data: 'dace.sdfg.InterstateEdge'): + """ Adds a new edge to the graph. Must be an InterstateEdge or a subclass thereof. + + :param u: Source node. + :param v: Destination node. + :param edge: The edge to add. + """ + if not isinstance(src, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(src))) + if not isinstance(dst, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(dst))) + if not isinstance(data, dace.sdfg.InterstateEdge): + raise TypeError('Expected InterstateEdge, got ' + str(type(data))) + if dst is self._cached_start_block: + self._cached_start_block = None + return super().add_edge(src, dst, data) + + def add_node(self, node, is_start_block=False, *, is_start_state: bool=None): + if not isinstance(node, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(node))) + super().add_node(node) + self._cached_start_block = None + start_block = is_start_block + if is_start_state is not None: + warnings.warn('is_start_state is deprecated, use is_start_block instead', DeprecationWarning) + start_block = is_start_state + + if start_block: + self.start_block = len(self.nodes()) - 1 + self._cached_start_block = node + + def add_state(self, label=None, is_start_block=False, *, is_start_state: bool=None) -> SDFGState: + if self._labels is None or len(self._labels) != self.number_of_nodes(): + 
self._labels = set(s.label for s in self.nodes()) + label = label or 'state' + existing_labels = self._labels + label = dt.find_new_name(label, existing_labels) + state = SDFGState(label) + state.parent = self + self._labels.add(label) + start_block = is_start_block + if is_start_state is not None: + warnings.warn('is_start_state is deprecated, use is_start_block instead', DeprecationWarning) + start_block = is_start_state + self.add_node(state, is_start_block=start_block) + return state + + def add_state_before(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + """ Adds a new SDFG state before an existing state, reconnecting predecessors to it instead. + + :param state: The state to prepend the new state before. + :param label: State label. + :param is_start_state: If True, resets scope block starting state to this state. + :return: A new SDFGState object. + """ + new_state = self.add_state(label, is_start_state) + # Reconnect + for e in self.in_edges(state): + self.remove_edge(e) + self.add_edge(e.src, new_state, e.data) + # Add unconditional connection between the new state and the current + self.add_edge(new_state, state, dace.sdfg.InterstateEdge()) + return new_state + + def add_state_after(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + """ Adds a new SDFG state after an existing state, reconnecting it to the successors instead. + + :param state: The state to append the new state after. + :param label: State label. + :param is_start_state: If True, resets SDFG starting state to this state. + :return: A new SDFGState object. 
+ """ + new_state = self.add_state(label, is_start_state) + # Reconnect + for e in self.out_edges(state): + self.remove_edge(e) + self.add_edge(new_state, e.dst, e.data) + # Add unconditional connection between the current and the new state + self.add_edge(state, new_state, dace.sdfg.InterstateEdge()) + return new_state + + @abc.abstractmethod + def _used_symbols_internal(self, + all_symbols: bool, + defined_syms: Optional[Set] = None, + free_syms: Optional[Set] = None, + used_before_assignment: Optional[Set] = None, + keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms = set() if defined_syms is None else defined_syms + free_syms = set() if free_syms is None else free_syms + used_before_assignment = set() if used_before_assignment is None else used_before_assignment + + try: + ordered_blocks = self.topological_sort(self.start_block) + except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) + ordered_blocks = self.nodes() + + for block in ordered_blocks: + state_symbols = set() + if isinstance(block, ControlFlowRegion): + b_free_syms, b_defined_syms, b_used_before_syms = block._used_symbols_internal(all_symbols) + free_syms |= b_free_syms + defined_syms |= b_defined_syms + used_before_assignment |= b_used_before_syms + state_symbols = b_free_syms + else: + state_symbols = block.used_symbols(all_symbols) + free_syms |= state_symbols + + # Add free inter-state symbols + for e in self.out_edges(block): + # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols by + # subtracting the (true) free symbols from the edge's assignment keys. This way we can correctly + # compute the symbols that are used before being assigned. + efsyms = e.data.used_symbols(all_symbols) + defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_symbols) + used_before_assignment.update(efsyms - defined_syms) + free_syms |= efsyms + + # Remove symbols that were used before they were assigned. 
+ defined_syms -= used_before_assignment + + if isinstance(self, dace.SDFG): + # Remove from defined symbols those that are in the symbol mapping + if self.parent_nsdfg_node is not None and keep_defined_in_mapping: + defined_syms -= set(self.parent_nsdfg_node.symbol_mapping.keys()) + + # Add the set of SDFG symbol parameters + # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets + if all_symbols: + free_syms |= set(self.symbols.keys()) + + # Subtract symbols defined in inter-state edges and constants from the list of free symbols. + free_syms -= defined_syms + + return free_syms, defined_syms, used_before_assignment + + def to_json(self, parent=None): + graph_json = OrderedDiGraph.to_json(self) + block_json = ControlFlowBlock.to_json(self, parent) + graph_json.update(block_json) + return graph_json + + ################################################################### + # Traversal methods + + def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegion']: + """ Iterate over this and all nested control flow regions. """ + yield self + for block in self.nodes(): + if isinstance(block, SDFGState) and recursive: + for node in block.nodes(): + if isinstance(node, nd.NestedSDFG): + yield from node.sdfg.all_control_flow_regions(recursive=recursive) + elif isinstance(block, ControlFlowRegion): + yield from block.all_control_flow_regions(recursive=recursive) + + def all_sdfgs_recursive(self) -> Iterator['dace.SDFG']: + """ Iterate over this and all nested SDFGs. """ + for cfg in self.all_control_flow_regions(recursive=True): + if isinstance(cfg, dace.SDFG): + yield cfg + + def all_states(self) -> Iterator[SDFGState]: + """ Iterate over all states in this control flow graph. 
""" + for block in self.nodes(): + if isinstance(block, SDFGState): + yield block + elif isinstance(block, ControlFlowRegion): + yield from block.all_states() + + def all_control_flow_blocks(self, recursive=False) -> Iterator[ControlFlowBlock]: + """ Iterate over all control flow blocks in this control flow graph. """ + for cfg in self.all_control_flow_regions(recursive=recursive): + for block in cfg.nodes(): + yield block + + def all_interstate_edges(self, recursive=False) -> Iterator[Edge['dace.sdfg.InterstateEdge']]: + """ Iterate over all interstate edges in this control flow graph. """ + for cfg in self.all_control_flow_regions(recursive=recursive): + for edge in cfg.edges(): + yield edge + + ################################################################### + # Getters & setters, overrides + + def __str__(self): + return ControlFlowBlock.__str__(self) + + def __repr__(self) -> str: + return f'{self.__class__.__name__} ({self.label})' + + @property + def start_block(self): + """ Returns the starting block of this ControlFlowGraph. """ + if self._cached_start_block is not None: + return self._cached_start_block + + source_nodes = self.source_nodes() + if len(source_nodes) == 1: + self._cached_start_block = source_nodes[0] + return source_nodes[0] + # If the starting block is ambiguous allow manual override. + if self._start_block is not None: + self._cached_start_block = self.node(self._start_block) + return self._cached_start_block + raise ValueError('Ambiguous or undefined starting block for ControlFlowGraph, ' + 'please use "is_start_block=True" when adding the ' + 'starting block with "add_state" or "add_node"') + + @start_block.setter + def start_block(self, block_id): + """ Manually sets the starting block of this ControlFlowGraph. + + :param block_id: The node ID (use `node_id(block)`) of the block to set. 
+ """ + if block_id < 0 or block_id >= self.number_of_nodes(): + raise ValueError('Invalid state ID') + self._start_block = block_id + self._cached_start_block = self.node(block_id) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 1078414161..621f8a9e16 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -668,7 +668,7 @@ def consolidate_edges(sdfg: SDFG, starting_scope=None) -> int: from dace.sdfg.propagation import propagate_memlets_scope total_consolidated = 0 - for state in sdfg.nodes(): + for state in sdfg.states(): # Start bottom-up if starting_scope and starting_scope.entry not in state.nodes(): continue @@ -1206,8 +1206,8 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> counter = 0 if progress is True or progress is None: fusible_states = 0 - for sd in sdfg.all_sdfgs_recursive(): - fusible_states += sd.number_of_edges() + for cfg in sdfg.all_control_flow_regions(): + fusible_states += cfg.number_of_edges() if progress is True: pbar = tqdm(total=fusible_states, desc='Fusing states') @@ -1217,30 +1217,32 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> for sd in sdfg.all_sdfgs_recursive(): id = sd.sdfg_id - while True: - edges = list(sd.nx.edges) - applied = 0 - skip_nodes = set() - for u, v in edges: - if (progress is None and tqdm is not None and (time.time() - start) > 5): - progress = True - pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter) - - if u in skip_nodes or v in skip_nodes: - continue - candidate = {StateFusion.first_state: u, StateFusion.second_state: v} - sf = StateFusion() - sf.setup_match(sd, id, -1, candidate, 0, override=True) - if sf.can_be_applied(sd, 0, sd, permissive=permissive): - sf.apply(sd, sd) - applied += 1 - counter += 1 - if progress: - pbar.update(1) - skip_nodes.add(u) - skip_nodes.add(v) - if applied == 0: - break + for cfg in sd.all_control_flow_regions(): + while True: + edges = list(cfg.nx.edges) + applied = 0 + 
skip_nodes = set() + for u, v in edges: + if (progress is None and tqdm is not None and (time.time() - start) > 5): + progress = True + pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter) + + if (u in skip_nodes or v in skip_nodes or not isinstance(v, SDFGState) or + not isinstance(u, SDFGState)): + continue + candidate = {StateFusion.first_state: u, StateFusion.second_state: v} + sf = StateFusion() + sf.setup_match(cfg, id, -1, candidate, 0, override=True) + if sf.can_be_applied(cfg, 0, sd, permissive=permissive): + sf.apply(cfg, sd) + applied += 1 + counter += 1 + if progress: + pbar.update(1) + skip_nodes.add(u) + skip_nodes.add(v) + if applied == 0: + break if progress: pbar.close() return counter diff --git a/dace/transformation/dataflow/double_buffering.py b/dace/transformation/dataflow/double_buffering.py index 8ff70a6355..6efe6543ca 100644 --- a/dace/transformation/dataflow/double_buffering.py +++ b/dace/transformation/dataflow/double_buffering.py @@ -128,7 +128,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): ############################## # Add initial reads to initial nested state initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state - initial_state.set_label('%s_init' % map_entry.map.label) + initial_state.label = '%s_init' % map_entry.map.label for edge in edges_to_replace: initial_state.add_node(edge.src) rnode = edge.src @@ -152,7 +152,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): # Add the main state's contents to the last state, modifying # memlets appropriately. 
final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0] - final_state.set_label('%s_final_computation' % map_entry.map.label) + final_state.label = '%s_final_computation' % map_entry.map.label dup_nstate = copy.deepcopy(nstate) final_state.add_nodes_from(dup_nstate.nodes()) for e in dup_nstate.edges(): @@ -183,7 +183,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): nstate.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn, new_memlet) - nstate.set_label('%s_double_buffered' % map_entry.map.label) + nstate.label = '%s_double_buffered' % map_entry.map.label # Divide by loop stride new_expr = symbolic.pystr_to_symbolic('((%s / %s) + 1) %% 2' % (map_param, map_rstride)) sd.replace(nstate, '__dace_db_param', new_expr) diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index e95674adc1..7f4fbc654d 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -2,10 +2,14 @@ """ Transformations to convert subgraphs to write-conflict resolutions. 
""" import ast import re -from dace import registry, nodes, dtypes +import copy +from dace import registry, nodes, dtypes, Memlet from dace.transformation import transformation, helpers as xfh from dace.sdfg import graph as gr, utils as sdutil from dace import SDFG, SDFGState +from dace.sdfg.state import StateSubgraphView +from dace.transformation import helpers +from dace.sdfg.propagation import propagate_memlets_state class AugAssignToWCR(transformation.SingleStateTransformation): @@ -20,6 +24,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation): map_exit = transformation.PatternNode(nodes.MapExit) _EXPRESSIONS = ['+', '-', '*', '^', '%'] #, '/'] + _FUNCTIONS = ['min', 'max'] _EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')} _PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'} @@ -27,6 +32,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation): def expressions(cls): return [ sdutil.node_path_graph(cls.input, cls.tasklet, cls.output), + sdutil.node_path_graph(cls.input, cls.map_entry, cls.tasklet, cls.map_exit, cls.output) ] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): @@ -38,7 +44,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Free tasklet if expr_index == 0: - # Only free tasklets supported for now if graph.entry_node(tasklet) is not None: return False @@ -49,8 +54,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Make sure augmented assignment can be fissioned as necessary if any(not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(tasklet)): return False - if graph.in_degree(inarr) > 0 and graph.out_degree(outarr) > 0: - return False outedge = graph.edges_between(tasklet, outarr)[0] else: # Free map @@ -65,12 +68,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if len(graph.edges_between(tasklet, mx)) > 1: return False - # Currently no 
fission is supported + # Make sure augmented assignment can be fissioned as necessary if any(e.src is not me and not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(me) + graph.in_edges(tasklet)): return False - if graph.in_degree(inarr) > 0: - return False outedge = graph.edges_between(tasklet, mx)[0] @@ -78,6 +79,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) if tasklet.language is dtypes.Language.Python: # Match a single assignment with a binary operation as RHS @@ -108,18 +110,33 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Try to match a single C assignment that can be converted to WCR inconn = edge.dst_conn lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn), re.escape(inconn), ops) - rhs = r'^\s*%s\s*=\s*.*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) - if re.match(lhs, cstr) is None: - continue + # rhs: a = (...) 
op b + rhs = r'^\s*%s\s*=\s*\(.*\)\s*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) + func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,.*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) + func_rhs = r'^\s*%s\s*=\s*(%s)\(.*,\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) + if re.match(lhs, cstr) is None and re.match(rhs, cstr) is None: + if re.match(func_lhs, cstr) is None and re.match(func_rhs, cstr) is None: + inconns = list(self.tasklet.in_connectors) + if len(inconns) != 2: + continue + + # Special case: a = op b + other_inconn = inconns[0] if inconns[0] != inconn else inconns[1] + rhs2 = r'^\s*%s\s*=\s*%s\s*%s\s*%s;$' % (re.escape(outconn), re.escape(other_inconn), ops, + re.escape(inconn)) + if re.match(rhs2, cstr) is None: + continue + # Same memlet if edge.data.subset != outedge.data.subset: continue # If in map, only match if the subset is independent of any # map indices (otherwise no conflict) - if (expr_index == 1 and len(outedge.data.subset.free_symbols - & set(me.map.params)) == len(me.map.params)): - continue + if expr_index == 1: + if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len( + me.map.params): + continue return True else: @@ -132,50 +149,22 @@ def apply(self, state: SDFGState, sdfg: SDFG): input: nodes.AccessNode = self.input tasklet: nodes.Tasklet = self.tasklet output: nodes.AccessNode = self.output + if self.expr_index == 1: + me = self.map_entry + mx = self.map_exit # If state fission is necessary to keep semantics, do it first - if (self.expr_index == 0 and state.in_degree(input) > 0 and state.out_degree(output) == 0): - newstate = sdfg.add_state_after(state) - newstate.add_node(tasklet) - new_input, new_output = None, None - - # Keep old edges for after we remove tasklet from the original state - in_edges = list(state.in_edges(tasklet)) - out_edges = list(state.out_edges(tasklet)) - - for e in in_edges: - r = newstate.add_read(e.src.data) - newstate.add_edge(r, 
e.src_conn, e.dst, e.dst_conn, e.data) - if e.src is input: - new_input = r - for e in out_edges: - w = newstate.add_write(e.dst.data) - newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data) - if e.dst is output: - new_output = w - - # Remove tasklet and resulting isolated nodes - state.remove_node(tasklet) - for e in in_edges: - if state.degree(e.src) == 0: - state.remove_node(e.src) - for e in out_edges: - if state.degree(e.dst) == 0: - state.remove_node(e.dst) - - # Reset state and nodes for rest of transformation - input = new_input - output = new_output - state = newstate - # End of state fission + if state.in_degree(input) > 0: + subgraph_nodes = set([e.src for e in state.bfs_edges(input, reverse=True)]) + subgraph_nodes.add(input) + + subgraph = StateSubgraphView(state, subgraph_nodes) + helpers.state_fission(sdfg, subgraph) if self.expr_index == 0: inedges = state.edges_between(input, tasklet) outedge = state.edges_between(tasklet, output)[0] else: - me = self.map_entry - mx = self.map_exit - inedges = state.edges_between(me, tasklet) outedge = state.edges_between(tasklet, mx)[0] @@ -183,6 +172,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) # Change tasklet code if tasklet.language is dtypes.Language.Python: @@ -206,13 +196,40 @@ def apply(self, state: SDFGState, sdfg: SDFG): inconn = edge.dst_conn match = re.match(r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' % (re.escape(outconn), re.escape(inconn), ops), cstr) if match is None: - # match = re.match( - # r'^\s*%s\s*=\s*(.*)\s*(%s)\s*%s;$' % - # (re.escape(outconn), ops, re.escape(inconn)), cstr) - # if match is None: - continue - # op = match.group(2) - # expr = match.group(1) + match = re.match( + r'^\s*%s\s*=\s*\((.*)\)\s*(%s)\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)), cstr) + if match is None: + func_rhs = 
r'^\s*%s\s*=\s*(%s)\((.*),\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, + re.escape(inconn)) + match = re.match(func_rhs, cstr) + if match is None: + func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,(.*)\)\s*;$' % (re.escape(outconn), funcs, + re.escape(inconn)) + match = re.match(func_lhs, cstr) + if match is None: + inconns = list(self.tasklet.in_connectors) + if len(inconns) != 2: + continue + + # Special case: a = op b + other_inconn = inconns[0] if inconns[0] != inconn else inconns[1] + rhs2 = r'^\s*%s\s*=\s*(%s)\s*(%s)\s*%s;$' % ( + re.escape(outconn), re.escape(other_inconn), ops, re.escape(inconn)) + match = re.match(rhs2, cstr) + if match is None: + continue + else: + op = match.group(2) + expr = match.group(1) + else: + op = match.group(1) + expr = match.group(2) + else: + op = match.group(1) + expr = match.group(2) + else: + op = match.group(2) + expr = match.group(1) else: op = match.group(1) expr = match.group(2) @@ -232,16 +249,14 @@ def apply(self, state: SDFGState, sdfg: SDFG): raise NotImplementedError # Change output edge - outedge.data.wcr = f'lambda a,b: a {op} b' - - if self.expr_index == 0: - # Remove input node and connector - state.remove_edge_and_connectors(inedge) - if state.degree(input) == 0: - state.remove_node(input) + if op in AugAssignToWCR._FUNCTIONS: + outedge.data.wcr = f'lambda a,b: {op}(a, b)' else: - # Remove input edge and dst connector, but not necessarily src - state.remove_memlet_path(inedge) + outedge.data.wcr = f'lambda a,b: a {op} b' + + # Remove input node and connector + state.remove_memlet_path(inedge) + propagate_memlets_state(sdfg, state) # If outedge leads to non-transient, and this is a nested SDFG, # propagate outwards @@ -252,6 +267,9 @@ def apply(self, state: SDFGState, sdfg: SDFG): sd = sd.parent_sdfg outedge = next(iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data))) for outedge in nstate.memlet_path(outedge): - outedge.data.wcr = f'lambda a,b: a {op} b' + if op in AugAssignToWCR._FUNCTIONS: + 
outedge.data.wcr = f'lambda a,b: {op}(a, b)' + else: + outedge.data.wcr = f'lambda a,b: a {op} b' # At this point we are leading to an access node again and can # traverse further up diff --git a/dace/transformation/interstate/loop_unroll.py b/dace/transformation/interstate/loop_unroll.py index 47d438a2fc..b1dbfdd5c9 100644 --- a/dace/transformation/interstate/loop_unroll.py +++ b/dace/transformation/interstate/loop_unroll.py @@ -116,8 +116,7 @@ def instantiate_loop( # Replace iterate with value in each state for state in new_states: - state.set_label(state.label + '_' + itervar + '_' + - (state_suffix if state_suffix is not None else str(value))) + state.label = state.label + '_' + itervar + '_' + (state_suffix if state_suffix is not None else str(value)) state.replace(itervar, value) # Add subgraph to original SDFG diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 74dd51a483..4d560ab70a 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -334,7 +334,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): if nstate.label in statenames: newname = data.find_new_name(nstate.label, statenames) statenames.add(newname) - nstate.set_label(newname) + nstate.label = newname ####################################################### # Add nested SDFG states into top-level SDFG diff --git a/doc/sdfg/images/elements.svg b/doc/sdfg/images/elements.svg index 80d35e39f0..6402de8e1d 100644 --- a/doc/sdfg/images/elements.svg +++ b/doc/sdfg/images/elements.svg @@ -1,90 +1,506 @@ - + - - - -Access Nodes - -T -ransient -Global - -Stream - -V -iew - -Reference - -T -asklet - - - - - - - - - -Nested SDFG - -Consume - - -Map - -... - - - -Library Node - - -... 
- -A[0] -CR: Sum -V -olume: 1 - -B[i, j] -V -olume: 1 -Memlet -W -rite-Conflict -Resolution - -State -State -T -ransition - + + + + +Access Nodes + +T +ransient +Global + +Stream + +V +iew + +Reference + +T +asklet + + + + + + + + + +Nested SDFG + +Consume + + +Map + +... + + + +Library Node + + +... + +A[0] +CR: Sum +V +olume: 1 + +B[i, j] +V +olume: 1 +Memlet +W +rite-Conflict +Resolution + +State +State +T +ransition +Control FlowRegion diff --git a/doc/sdfg/ir.rst b/doc/sdfg/ir.rst index 3c651fab19..f7bbb0ff79 100644 --- a/doc/sdfg/ir.rst +++ b/doc/sdfg/ir.rst @@ -29,7 +29,7 @@ Some of the main differences between SDFGs and other representations are: The Language ------------ -In a nutshell, an SDFG is a state machine of acyclic dataflow multigraphs. Here is an example graph: +In a nutshell, an SDFG is a hierarchical state machine of acyclic dataflow multigraphs. Here is an example graph: .. raw:: html @@ -43,7 +43,7 @@ In a nutshell, an SDFG is a state machine of acyclic dataflow multigraphs. Here The cyan rectangles are called **states** and together they form a state machine, executing the code from the starting state and following the blue edge that matches the conditions. In each state, an acyclic multigraph controls execution -through dataflow. There are four elements in the above state: +through dataflow. There are four elements in the above states: * **Access nodes** (ovals) that give access to data containers * **Memlets** (edges/dotted arrows) that represent units of data movement @@ -58,7 +58,14 @@ The state machine shown in the example is a for-loop (``for _ in range(5)``). Th the guard state controls the loop, and at the end the result is copied to the special ``__return`` data container, which designates the return value of the function. -There are other kinds of elements in an SDFG, as detailed below. +The state machine is analogous to a control flow graph, where states represent basic blocks. 
Multiple such basic blocks, +such as with the described loop, can be put together to form a **control flow region**. This allows them to be +represented with a single graph node in the SDFG's state machine, which is useful for optimization and analysis. +The SDFG itself can be thought of as one big control flow region. This means that control flow regions are directed +graphs, where nodes are states or other control flow regions, and edges are state transitions. + +In addition to the elements seen in the example above, there are other kinds of elements in an SDFG, which are detailed +below. .. _sdfg-lang: @@ -142,6 +149,12 @@ new value, and specifies how the update is performed. In the summation example, end of each state there is an implicit synchronization point, so it will not finish executing until all the last nodes have been reached (this assumption can be removed in extreme cases, see :class:`~dace.sdfg.state.SDFGState.nosync`). +**Control Flow Region**: Forms a directed graph of states and other control flow regions, where edges are state +transitions. This allows representing complex control flow in a single graph node, which is useful for analysis and +optimization. The SDFG itself is a control flow region, which means that control flow regions are recursive / +hierarchical. Similar to the SDFG, each control flow region has a unique starting state, which is the entry point to +the region and is executed first. + **State Transition**: Transitions, internally referred to as *inter-state edges*, specify how execution proceeds after the end of a State. Inter-state edges optionally contain a symbolic *condition* that is checked at the end of the preceding state. If any of the conditions are true, execution will continue to the destination of this edge (the @@ -783,5 +796,7 @@ file uses the :func:`~dace.sdfg.sdfg.SDFG.from_file` static method. For example, The ``compress`` argument can be used to save a smaller (``gzip`` compressed) file. 
It can keep the same extension, but it is customary to use ``.sdfg.gz`` or ``.sdfgz`` to let others know it is compressed. +It is recommended to use this option for large SDFGs, as it not only saves space, but also speeds up loading and +editing of the SDFG in visualization tools and the VSCode extension. diff --git a/requirements.txt b/requirements.txt index 5f804e1b4c..266b3368c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,13 +14,13 @@ Jinja2==3.1.2 MarkupSafe==2.1.3 mpmath==1.3.0 networkx==3.1 -numpy==1.24.3 +numpy==1.26.1 ply==3.11 -PyYAML==6.0 +PyYAML==6.0.1 requests==2.31.0 six==1.16.0 sympy==1.9 urllib3==2.0.7 websockets==11.0.3 -Werkzeug==2.3.5 +Werkzeug==3.0.1 zipp==3.15.0 diff --git a/tests/cuda_block_test.py b/tests/cuda_block_test.py index f77e80673f..676785e0e5 100644 --- a/tests/cuda_block_test.py +++ b/tests/cuda_block_test.py @@ -10,8 +10,10 @@ @dace.program(dace.float64[N], dace.float64[N]) def cudahello(V, Vout): + @dace.mapscope(_[0:N:32]) def multiplication(i): + @dace.map(_[0:32]) def mult_block(bi): in_V << V[i + bi] @@ -55,6 +57,7 @@ def test_gpu(): @pytest.mark.gpu def test_different_block_sizes_nesting(): + @dace.program def nested(V: dace.float64[34], v1: dace.float64[1]): with dace.tasklet: @@ -105,6 +108,7 @@ def diffblocks(V: dace.float64[130], v1: dace.float64[4], v2: dace.float64[128]) @pytest.mark.gpu def test_custom_block_size_onemap(): + @dace.program def tester(A: dace.float64[400, 300]): for i, j in dace.map[0:400, 0:300]: @@ -132,6 +136,7 @@ def tester(A: dace.float64[400, 300]): @pytest.mark.gpu def test_custom_block_size_twomaps(): + @dace.program def tester(A: dace.float64[400, 300, 2, 32]): for i, j in dace.map[0:400, 0:300]: @@ -154,9 +159,42 @@ def tester(A: dace.float64[400, 300, 2, 32]): sdfg.compile() +@pytest.mark.gpu +def test_block_thread_specialization(): + + @dace.program + def tester(A: dace.float64[200]): + for i in dace.map[0:200:32]: + for bi in dace.map[0:32]: + with dace.tasklet: + a >> A[i + 
bi] + a = 1 + with dace.tasklet: # Tasklet to be specialized + a >> A[i + bi] + a = 2 + + sdfg = tester.to_sdfg() + sdfg.apply_gpu_transformations(sequential_innermaps=False) + tasklet = next(n for n, _ in sdfg.all_nodes_recursive() + if isinstance(n, dace.nodes.Tasklet) and '2' in n.code.as_string) + tasklet.location['gpu_thread'] = dace.subsets.Range.from_string('2:9:3') + tasklet.location['gpu_block'] = 1 + + code = sdfg.generate_code()[1].clean_code # Get GPU code (second file) + assert '>= 2' in code and '<= 8' in code + assert ' == 1' in code + + a = np.random.rand(200) + ref = np.ones_like(a) + ref[32:64][2:9:3] = 2 + sdfg(a) + assert np.allclose(a, ref) + + if __name__ == "__main__": test_cpu() test_gpu() test_different_block_sizes_nesting() test_custom_block_size_onemap() test_custom_block_size_twomaps() + test_block_thread_specialization() diff --git a/tests/sdfg/nested_control_flow_regions_test.py b/tests/sdfg/nested_control_flow_regions_test.py new file mode 100644 index 0000000000..f29c093dad --- /dev/null +++ b/tests/sdfg/nested_control_flow_regions_test.py @@ -0,0 +1,18 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+import pytest + +import dace + + +def test_is_start_state_deprecation(): + sdfg = dace.SDFG('deprecation_test') + with pytest.deprecated_call(): + sdfg.add_state('state1', is_start_state=True) + sdfg2 = dace.SDFG('deprecation_test2') + state = dace.SDFGState('state2') + with pytest.deprecated_call(): + sdfg2.add_node(state, is_start_state=True) + + +if __name__ == '__main__': + test_is_start_state_deprecation() diff --git a/tests/sdfg_validate_names_test.py b/tests/sdfg_validate_names_test.py index dad79c8950..1650a4e4b1 100644 --- a/tests/sdfg_validate_names_test.py +++ b/tests/sdfg_validate_names_test.py @@ -28,7 +28,7 @@ def test_state_duplication(self): sdfg = dace.SDFG('ok') s1 = sdfg.add_state('also_ok') s2 = sdfg.add_state('also_ok') - s2.set_label('also_ok') + s2.label = 'also_ok' sdfg.add_edge(s1, s2, dace.InterstateEdge()) sdfg.validate() self.fail('Failed to detect duplicate state') diff --git a/tests/transformations/wcr_conversion_test.py b/tests/transformations/wcr_conversion_test.py new file mode 100644 index 0000000000..091b2a9db8 --- /dev/null +++ b/tests/transformations/wcr_conversion_test.py @@ -0,0 +1,247 @@ +import dace + +from dace.transformation.dataflow import AugAssignToWCR + + +def test_aug_assign_tasklet_lhs(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = a + k + + sdfg = sdfg_aug_assign_tasklet_lhs.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_brackets(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_brackets(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = a + (k + 1) + + sdfg = sdfg_aug_assign_tasklet_lhs_brackets.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert 
applied == 1 + + +def test_aug_assign_tasklet_rhs(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = k + a + + sdfg = sdfg_aug_assign_tasklet_rhs.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs_brackets(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs_brackets(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = (k + 1) + a + + sdfg = sdfg_aug_assign_tasklet_rhs_brackets.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = a + k; + """ + + sdfg = sdfg_aug_assign_tasklet_lhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_brackets_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = a + (k + 1); + """ + + sdfg = sdfg_aug_assign_tasklet_lhs_brackets_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs_brackets_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = (k + 1) + 
a; + """ + + sdfg = sdfg_aug_assign_tasklet_rhs_brackets_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_func_lhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_lhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(a, c); + """ + + sdfg = sdfg_aug_assign_tasklet_func_lhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_func_rhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_rhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(c, a); + """ + + sdfg = sdfg_aug_assign_tasklet_func_rhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_free_map(): + + @dace.program + def sdfg_aug_assign_free_map(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[0] + k << B[i] + b >> A[0] + """ + b = k * a; + """ + + sdfg = sdfg_aug_assign_free_map.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_state_fission_map(): + + @dace.program + def sdfg_aug_assign_state_fission(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet: + a << B[i] + b >> A[i] + b = a + + for i in dace.map[0:32]: + with dace.tasklet: + a << A[0] + b >> A[0] + b = a * 2 + + for i in dace.map[0:32]: + with dace.tasklet: + a << A[0] + b >> A[0] + b = a * 2 + + sdfg = sdfg_aug_assign_state_fission.to_sdfg() + sdfg.simplify() + + applied = 
sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 2 + + +def test_free_map_permissive(): + + @dace.program + def sdfg_free_map_permissive(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = k * a; + """ + + sdfg = sdfg_free_map_permissive.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=False) + assert applied == 0 + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True) + assert applied == 1