diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py index c502a47376..5dbc28645c 100644 --- a/dace/codegen/codegen.py +++ b/dace/codegen/codegen.py @@ -16,7 +16,7 @@ from dace.codegen.targets import cpp, cpu from dace.codegen.instrumentation import InstrumentationProvider -from dace.sdfg.state import SDFGState +from dace.sdfg.state import SDFGState, ScopeBlock def generate_headers(sdfg: SDFG, frame: framecode.DaCeCodeGenerator) -> str: @@ -100,13 +100,13 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator): frame.targets.add(disp.get_scope_dispatcher(node.schedule)) elif isinstance(node, dace.nodes.Node): state: SDFGState = parent - nsdfg = state.parent + nsdfg = state.sdfg frame.targets.add(disp.get_node_dispatcher(nsdfg, state, node)) # Array allocation if isinstance(node, dace.nodes.AccessNode): state: SDFGState = parent - nsdfg = state.parent + nsdfg = state.sdfg desc = node.desc(nsdfg) frame.targets.add(disp.get_array_dispatcher(desc.storage)) @@ -124,13 +124,13 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator): dst_node = leaf_e.dst if leaf_e.data.is_empty(): continue - tgt = disp.get_copy_dispatcher(node, dst_node, leaf_e, state.parent, state) + tgt = disp.get_copy_dispatcher(node, dst_node, leaf_e, state.sdfg, state) if tgt is not None: frame.targets.add(tgt) else: # Rooted at dst_node dst_node = mtree.root().edge.dst - tgt = disp.get_copy_dispatcher(node, dst_node, e, state.parent, state) + tgt = disp.get_copy_dispatcher(node, dst_node, e, state.sdfg, state) if tgt is not None: frame.targets.add(tgt) @@ -149,7 +149,7 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator): disp.instrumentation[sdfg.instrument] = provider_mapping[sdfg.instrument] -def generate_code(sdfg, validate=True) -> List[CodeObject]: +def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]: """ Generates code as a list of code objects for a given SDFG. diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index 28bf38f14d..297017c429 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -124,7 +124,7 @@ class SingleState(ControlFlow): last_state: bool = False def as_cpp(self, codegen, symbols) -> str: - sdfg = self.state.parent + sdfg = self.state.sdfg expr = '__state_{}_{}:;\n'.format(sdfg.sdfg_id, self.state.label) if self.state.number_of_nodes() > 0: @@ -221,7 +221,7 @@ def as_cpp(self, codegen, symbols) -> str: # In a general block, emit transitions and assignments after each # individual state if isinstance(elem, SingleState): - sdfg = elem.state.parent + sdfg = elem.state.sdfg out_edges = sdfg.out_edges(elem.state) for j, e in enumerate(out_edges): if e not in self.gotos_to_ignore: @@ -361,7 +361,7 @@ class ForScope(ControlFlow): def as_cpp(self, codegen, symbols) -> str: - sdfg = self.guard.parent + sdfg = self.guard.sdfg # Initialize to either "int i = 0" or "i = 0" depending on whether # the type has been defined @@ -415,7 +415,7 @@ class WhileScope(ControlFlow): def as_cpp(self, codegen, symbols) -> str: if self.test is not None: - sdfg = self.guard.parent + sdfg = self.guard.sdfg test = unparse_interstate_edge(self.test.code[0], sdfg, codegen=codegen) else: test = 'true' diff --git a/dace/codegen/instrumentation/papi.py b/dace/codegen/instrumentation/papi.py index bc7163ea9b..bab44e6779 100644 --- a/dace/codegen/instrumentation/papi.py +++ b/dace/codegen/instrumentation/papi.py @@ -12,7 +12,7 @@ from dace.sdfg.graph import SubgraphView from dace.memlet import Memlet from dace.sdfg import scope_contains_scope -from dace.sdfg.state import StateGraphView +from dace.sdfg.state import DataflowGraphView import sympy as sp import os @@ -392,7 +392,7 @@ def should_instrument_entry(map_entry: EntryNode) -> bool: return cond @staticmethod - def has_surrounding_perfcounters(node, dfg: StateGraphView): + def has_surrounding_perfcounters(node, dfg: DataflowGraphView): """ Returns true if there is a possibility that this node is part of a section that is profiled. """ parent = dfg.entry_node(node) @@ -605,7 +605,7 @@ def get_memlet_byte_size(sdfg: dace.SDFG, memlet: Memlet): return memlet.volume * memdata.dtype.bytes @staticmethod - def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: StateGraphView): + def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: DataflowGraphView): scope_dict = sdfg.node(state_id).scope_dict() out_costs = 0 @@ -636,7 +636,7 @@ def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: return out_costs @staticmethod - def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, dfg: StateGraphView, sdfg: dace.SDFG, state_id: int) -> str: + def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, dfg: DataflowGraphView, sdfg: dace.SDFG, state_id: int) -> str: """ Get the amount of bytes processed by `tasklet`. The formula is sum(inedges * size) + sum(outedges * size) """ in_accum = [] @@ -693,7 +693,7 @@ def get_memory_input_size(node, sdfg, state_id) -> str: return sym2cpp(input_size) @staticmethod - def accumulate_byte_movement(outermost_node, node, dfg: StateGraphView, sdfg, state_id): + def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, state_id): itvars = dict() # initialize an empty dict diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index a465d2bbc0..fe74d840ed 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -204,9 +204,9 @@ def preprocess(self, sdfg: SDFG) -> None: for state, node, defined_syms in sdutil.traverse_sdfg_with_defined_symbols(sdfg, recursive=True): if (isinstance(node, nodes.MapEntry) and node.map.schedule in (dtypes.ScheduleType.GPU_Device, dtypes.ScheduleType.GPU_Persistent)): - if state.parent not in shared_transients: - shared_transients[state.parent] = state.parent.shared_transients() - self._arglists[node] = state.scope_subgraph(node).arglist(defined_syms, shared_transients[state.parent]) + if state.sdfg not in shared_transients: + shared_transients[state.sdfg] = state.sdfg.shared_transients() + self._arglists[node] = state.scope_subgraph(node).arglist(defined_syms, shared_transients[state.sdfg]) def _compute_pool_release(self, top_sdfg: SDFG): """ @@ -831,7 +831,7 @@ def increment(streams): # Remove CUDA streams from paths of non-gpu copies and CPU tasklets for node, graph in sdfg.all_nodes_recursive(): if isinstance(graph, SDFGState): - cur_sdfg = graph.parent + cur_sdfg = graph.sdfg if (isinstance(node, (nodes.EntryNode, nodes.ExitNode)) and node.schedule in dtypes.GPU_SCHEDULES): # Node must have GPU stream, remove childpath and continue @@ -1421,7 +1421,7 @@ def generate_scope(self, sdfg, dfg_scope, state_id, function_stream, callsite_st visited = set() for node, parent in dfg_scope.all_nodes_recursive(): if isinstance(node, nodes.AccessNode): - nsdfg: SDFG = parent.parent + nsdfg: SDFG = parent.sdfg desc = node.desc(nsdfg) if (nsdfg, node.data) in visited: continue diff --git a/dace/codegen/targets/fpga.py b/dace/codegen/targets/fpga.py index 413cb751d6..e1753dc323 100644 --- a/dace/codegen/targets/fpga.py +++ b/dace/codegen/targets/fpga.py @@ -1332,7 +1332,7 @@ def partition_kernels(self, state: dace.SDFGState, default_kernel: int = 0): """ concurrent_kernels = 0 # Max number of kernels - sdfg = state.parent + sdfg = state.sdfg def increment(kernel_id): if concurrent_kernels > 0: diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index b1eb42fe60..9b18f969fa 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -83,7 +83,8 @@ def free_symbols(self, obj: Any): if k in self.fsyms: return self.fsyms[k] if hasattr(obj, 'used_symbols'): - result = obj.used_symbols(all_symbols=False) + intermediate = obj.used_symbols(all_symbols=False) + result = intermediate[0] if type(intermediate) is tuple else intermediate else: result = obj.free_symbols self.fsyms[k] = result @@ -395,9 +396,14 @@ def generate_external_memory_management(self, sdfg: SDFG, callsite_stream: CodeI # Footer callsite_stream.write('}', sdfg) - def generate_state(self, sdfg, state, global_stream, callsite_stream, generate_state_footer=True): + def generate_state(self, + sdfg: SDFG, + state: SDFGState, + global_stream: CodeIOStream, + callsite_stream: CodeIOStream, + generate_state_footer=True) -> None: - sid = sdfg.node_id(state) + sid = state.parent.node_id(state) # Emit internal transient array allocation self.allocate_arrays_in_scope(sdfg, state, global_stream, callsite_stream) @@ -444,7 +450,7 @@ def generate_state(self, sdfg, state, global_stream, callsite_stream, generate_s if instr is not None: instr.on_state_end(sdfg, state, callsite_stream, global_stream) - def generate_states(self, sdfg, global_stream, callsite_stream): + def generate_states(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stream: CodeIOStream): states_generated = set() opbar = progress.OptionalProgressBar(sdfg.number_of_nodes(), title=f'Generating code (SDFG {sdfg.sdfg_id})') @@ -491,7 +497,7 @@ def _get_schedule(self, scope: Union[nodes.EntryNode, SDFGState, SDFG]) -> dtype elif isinstance(scope, nodes.EntryNode): return scope.schedule elif isinstance(scope, (SDFGState, SDFG)): - sdfg: SDFG = (scope if isinstance(scope, SDFG) else scope.parent) + sdfg: SDFG = (scope if isinstance(scope, SDFG) else scope.sdfg) if sdfg.parent_nsdfg_node is None: return TOP_SCHEDULE @@ -526,8 +532,7 @@ def _can_allocate(self, sdfg: SDFG, state: SDFGState, desc: data.Data, scope: Un def determine_allocation_lifetime(self, top_sdfg: SDFG): """ - Determines where (at which scope/state/SDFG) each data descriptor - will be allocated/deallocated. + Determines where (at which scope/state/SDFG) each data descriptor will be allocated/deallocated. :param top_sdfg: The top-level SDFG to determine for. """ @@ -543,8 +548,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): ############################################# # Look for all states in which a scope-allocated array is used in instances: Dict[str, List[Tuple[SDFGState, nodes.AccessNode]]] = collections.defaultdict(list) - array_names = sdfg.arrays.keys( - ) #set(k for k, v in sdfg.arrays.items() if v.lifetime == dtypes.AllocationLifetime.Scope) + array_names = sdfg.arrays.keys() # Iterate topologically to get state-order for state in sdfg.topological_sort(): for node in state.data_nodes(): @@ -721,7 +725,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): if curscope is None: curscope = curstate elif isinstance(curscope, (SDFGState, SDFG)): - cursdfg: SDFG = (curscope if isinstance(curscope, SDFG) else curscope.parent) + cursdfg: SDFG = (curscope if isinstance(curscope, SDFG) else curscope.sdfg) # Go one SDFG up if cursdfg.parent_nsdfg_node is None: curscope = None diff --git a/dace/data.py b/dace/data.py index 0a9858458b..7c6abbcbb8 100644 --- a/dace/data.py +++ b/dace/data.py @@ -152,27 +152,6 @@ def _prod(sequence): return functools.reduce(lambda a, b: a * b, sequence, 1) -def find_new_name(name: str, existing_names: Sequence[str]) -> str: - """ - Returns a name that matches the given ``name`` as a prefix, but does not - already exist in the given existing name set. The behavior is typically - to append an underscore followed by a unique (increasing) number. If the - name does not already exist in the set, it is returned as-is. - - :param name: The given name to find. - :param existing_names: The set of existing names. - :return: A new name that is not in existing_names. - """ - if name not in existing_names: - return name - cur_offset = 0 - new_name = name + '_' + str(cur_offset) - while new_name in existing_names: - cur_offset += 1 - new_name = name + '_' + str(cur_offset) - return new_name - - @make_properties class Data: """ Data type descriptors that can be used as references to memory. diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 0329e31641..0b556d6a9d 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -31,7 +31,7 @@ from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states from dace.memlet import Memlet from dace.properties import LambdaProperty, CodeBlock -from dace.sdfg import SDFG, SDFGState +from dace.sdfg import SDFG, SDFGState, ControlFlowBlock, LoopScopeBlock, ScopeBlock from dace.sdfg.replace import replace_datadesc_names from dace.symbolic import pystr_to_symbolic, inequal_symbols @@ -1044,6 +1044,12 @@ class ProgramVisitor(ExtNodeVisitor): progress_bar = None start_time: float = 0 + sdfg: SDFG + last_block: ControlFlowBlock + cfg_target: ScopeBlock + last_cfg_target: ScopeBlock + current_state: SDFGState + def __init__(self, name: str, filename: str, @@ -1119,7 +1125,10 @@ def __init__(self, if sym.name not in self.sdfg.symbols: self.sdfg.add_symbol(sym.name, sym.dtype) self.sdfg._temp_transients = tmp_idx - self.last_state = self.sdfg.add_state('init', is_start_state=True) + self.cfg_target = self.sdfg + self.current_state = self.cfg_target.add_state('init', is_start_block=True) + self.last_block = self.current_state + self.last_cfg_target = self.sdfg self.inputs: DependencyType = {} self.outputs: DependencyType = {} @@ -1201,7 +1210,7 @@ def parse_program(self, program: ast.FunctionDef, is_tasklet: bool = False): for stmt in program.body: self.visit_TopLevel(stmt) if len(self.sdfg.nodes()) == 0: - self.sdfg.add_state("EmptyState") + self.sdfg.add_state('EmptyState') # Handle return values # Assignments to return values become __return* arrays @@ -1268,7 +1277,7 @@ def _views_to_data(state: SDFGState, nodes: List[dace.nodes.AccessNode]) -> List return new_nodes # Map view access nodes to their respective data - for state in self.sdfg.nodes(): + for state in self.sdfg.states(): # NOTE: We need to support views of views nodes = list(state.data_nodes()) while nodes: @@ -1315,13 +1324,38 @@ def defined(self): return result - def _add_state(self, label=None): - state = self.sdfg.add_state(label) - if self.last_state is not None: - self.sdfg.add_edge(self.last_state, state, dace.InterstateEdge()) - self.last_state = state + def _add_block(self, block: ControlFlowBlock): + if self.last_block is not None and self.last_cfg_target == self.cfg_target: + self.cfg_target.add_edge(self.last_block, block, dace.InterstateEdge()) + self.last_block = block + self.last_cfg_target = self.cfg_target + if not isinstance(block, SDFGState): + self.current_state = None + else: + self.current_state = block + + def _add_state(self, label=None, is_start=False) -> SDFGState: + state = self.cfg_target.add_state(label, is_start_block=is_start) + self._add_block(state) return state + def _add_loop_scope_block(self, + condition_expr: str, + label: str = 'loop', + loop_var: Optional[str] = None, + initialize_expr: Optional[str] = None, + update_expr: Optional[str] = None, + inverted: bool = False) -> LoopScopeBlock: + loop_scope_block = LoopScopeBlock(loop_var=loop_var, + initialize_expr=initialize_expr, + update_expr=update_expr, + condition_expr=condition_expr, + inverted=inverted, + label=label) + self.cfg_target.add_node(loop_scope_block) + self._add_block(loop_scope_block) + return loop_scope_block + def _parse_arg(self, arg: Any, as_list=True): """ Parse possible values to slices or objects that can be used in the SDFG API. """ @@ -2079,7 +2113,7 @@ def _add_dependencies(self, else: name = memlet.data vname = "{c}_out_of_{s}{n}".format(c=conn, - s=self.sdfg.nodes().index(state), + s=list(self.sdfg.states()).index(state), n=('_%s' % state.node_id(exit_node) if exit_node else '')) self.accesses[(name, scope_memlet.subset, 'w')] = (vname, orng) orig_shape = orng.size() @@ -2136,15 +2170,20 @@ def _recursive_visit(self, body: List[ast.AST], name: str, lineno: int, - last_state=True, + parent: ScopeBlock, + unconnected_last_block=True, extra_symbols=None) -> Tuple[SDFGState, SDFGState, SDFGState, bool]: """ Visits a subtree of the AST, creating special states before and after the visit. Returns the previous state, and the first and last internal states of the recursive visit. Also returns a boolean value indicating whether a return statement was met or not. This value can be used by other visitor methods, e.g., visit_If, to generate correct control flow. """ - before_state = self.last_state - self.last_state = None - first_internal_state = self._add_state('%s_%d' % (name, lineno)) + previous_last_cfg_target = self.last_cfg_target + previous_block = self.last_block + previous_target = self.cfg_target + self.last_block = None + self.cfg_target = parent + + first_innner_block = self._add_state('%s_%d' % (name, lineno)) # Add iteration variables to recursive visit if extra_symbols: @@ -2153,23 +2192,28 @@ def _recursive_visit(self, self.globals.update(extra_symbols) # Recursive loop processing - return_stmt = False + has_return_statement = False for stmt in body: self.visit_TopLevel(stmt) if isinstance(stmt, ast.Return): - return_stmt = True + has_return_statement = True # Create the next state - last_internal_state = self.last_state - if last_state: - self.last_state = None + last_inner_block = self.last_block + if unconnected_last_block: + self.last_block = None self._add_state('end%s_%d' % (name, lineno)) # Revert new symbols if extra_symbols: self.globals = old_globals + # Restore previous target + self.cfg_target = previous_target + self.last_cfg_target = previous_last_cfg_target + if not unconnected_last_block: + self.last_block = previous_block - return before_state, first_internal_state, last_internal_state, return_stmt + return previous_block, first_innner_block, last_inner_block, has_return_statement def _replace_with_global_symbols(self, expr: sympy.Expr) -> sympy.Expr: repldict = dict() @@ -2285,24 +2329,26 @@ def visit_For(self, node: ast.For): if (astr not in self.sdfg.symbols and not (astr in self.variables or astr in self.sdfg.arrays)): self.sdfg.add_symbol(astr, atom.dtype) - # Add an initial loop state with a None last_state (so as to not + # Add an initial loop state with a None last_block (so as to not # create an interstate edge) self.loop_idx += 1 self.continue_states.append([]) self.break_states.append([]) - laststate, first_loop_state, last_loop_state, _ = self._recursive_visit(node.body, - 'for', - node.lineno, - extra_symbols=extra_syms) - end_loop_state = self.last_state # Add loop to SDFG loop_cond = '>' if ((pystr_to_symbolic(ranges[0][2]) < 0) == True) else '<' + loop_cond_expr = '%s %s %s' % (indices[0], loop_cond, astutils.unparse(ast_ranges[0][1])) incr = {indices[0]: '%s + %s' % (indices[0], astutils.unparse(ast_ranges[0][2]))} - _, loop_guard, loop_end = self.sdfg.add_loop( - laststate, first_loop_state, end_loop_state, indices[0], astutils.unparse(ast_ranges[0][0]), - '%s %s %s' % (indices[0], loop_cond, astutils.unparse(ast_ranges[0][1])), incr[indices[0]], - last_loop_state) + loop_scope = self._add_loop_scope_block(loop_cond_expr, + label=f'for_{node.lineno}', + loop_var=indices[0], + initialize_expr=astutils.unparse(ast_ranges[0][0]), + update_expr=incr[indices[0]], + inverted=False) + _, first_subblock, _, _ = self._recursive_visit(node.body, f'for_{node.lineno}', + node.lineno, extra_symbols=extra_syms, + parent=loop_scope, unconnected_last_block=False) + loop_scope.start_block = loop_scope.node_id(first_subblock) # Handle else clause if node.orelse: @@ -2311,32 +2357,14 @@ def visit_For(self, node: ast.For): self.visit(stmt) # The state that all "break" edges go to - loop_end = self._add_state(f'postloop_{node.lineno}') - - body_states = list( - sdutil.dfs_conditional(self.sdfg, - sources=[first_loop_state], - condition=lambda p, c: c is not loop_guard)) - - continue_states = self.continue_states.pop() - while continue_states: - next_state = continue_states.pop() - out_edges = self.sdfg.out_edges(next_state) - for e in out_edges: - self.sdfg.remove_edge(e) - self.sdfg.add_edge(next_state, loop_guard, dace.InterstateEdge(assignments=incr)) - break_states = self.break_states.pop() - while break_states: - next_state = break_states.pop() - out_edges = self.sdfg.out_edges(next_state) - for e in out_edges: - self.sdfg.remove_edge(e) - self.sdfg.add_edge(next_state, loop_end, dace.InterstateEdge()) - self.loop_idx -= 1 + state = self.cfg_target.add_state(f'postloop_{node.lineno}') + if self.last_block is not None: + self.cfg_target.add_edge(self.last_block, state, dace.InterstateEdge()) + self.last_block = state + return state - for state in body_states: - if not nx.has_path(self.sdfg.nx, loop_guard, state): - self.sdfg.remove_node(state) + self.last_block = loop_scope + self.loop_idx -= 1 else: raise DaceSyntaxError(self, node, 'Unsupported for-loop iterator "%s"' % iterator) @@ -2379,19 +2407,16 @@ def _visit_test(self, node: ast.Expr): def visit_While(self, node: ast.While): # Get loop condition expression - begin_guard = self._add_state("while_guard") loop_cond, _ = self._visit_test(node.test) - end_guard = self.last_state + + loop_scope = self._add_loop_scope_block(loop_cond, + label=f'while_{node.lineno}', + inverted=False) # Parse body self.loop_idx += 1 - self.continue_states.append([]) - self.break_states.append([]) - laststate, first_loop_state, last_loop_state, _ = \ - self._recursive_visit(node.body, 'while', node.lineno) - end_loop_state = self.last_state - - assert (laststate == end_guard) + self._recursive_visit(node.body, f'while_{node.lineno}', node.lineno, parent=loop_scope, + unconnected_last_block=False) # Add symbols from test as necessary symcond = pystr_to_symbolic(loop_cond) @@ -2406,24 +2431,6 @@ def visit_While(self, node: ast.While): if (astr not in self.sdfg.symbols and astr not in self.variables): self.sdfg.add_symbol(astr, atom.dtype) - # Add loop to SDFG - _, loop_guard, loop_end = self.sdfg.add_loop(laststate, first_loop_state, end_loop_state, None, None, loop_cond, - None, last_loop_state) - - # Connect the correct while-guard state - # Current state: - # begin_guard -> ... -> end_guard/laststate -> loop_guard -> first_loop - # Desired state: - # begin_guard -> ... -> end_guard/laststate -> first_loop - for e in list(self.sdfg.in_edges(loop_guard)): - if e.src != laststate: - self.sdfg.add_edge(e.src, begin_guard, e.data) - self.sdfg.remove_edge(e) - for e in list(self.sdfg.out_edges(loop_guard)): - self.sdfg.add_edge(end_guard, e.dst, e.data) - self.sdfg.remove_edge(e) - self.sdfg.remove_node(loop_guard) - # Handle else clause if node.orelse: # Continue visiting body @@ -2431,30 +2438,11 @@ def visit_While(self, node: ast.While): self.visit(stmt) # The state that all "break" edges go to - loop_end = self._add_state(f'postwhile_{node.lineno}') - - body_states = list( - sdutil.dfs_conditional(self.sdfg, sources=[first_loop_state], condition=lambda p, c: c is not loop_guard)) - - continue_states = self.continue_states.pop() - while continue_states: - next_state = continue_states.pop() - out_edges = self.sdfg.out_edges(next_state) - for e in out_edges: - self.sdfg.remove_edge(e) - self.sdfg.add_edge(next_state, begin_guard, dace.InterstateEdge()) - break_states = self.break_states.pop() - while break_states: - next_state = break_states.pop() - out_edges = self.sdfg.out_edges(next_state) - for e in out_edges: - self.sdfg.remove_edge(e) - self.sdfg.add_edge(next_state, loop_end, dace.InterstateEdge()) - self.loop_idx -= 1 + self._add_state(f'postwhile_{node.lineno}') + + self.last_block = loop_scope - for state in body_states: - if not nx.has_path(self.sdfg.nx, end_guard, state): - self.sdfg.remove_node(state) + self.loop_idx -= 1 def visit_Break(self, node: ast.Break): if self.loop_idx < 0: @@ -2464,7 +2452,7 @@ def visit_Break(self, node: ast.Break): " used in nested DaCe program calls to break out " " of loops of outer scopes)") raise DaceSyntaxError(self, node, error_msg) - self.break_states[self.loop_idx].append(self.last_state) + self.break_states[self.loop_idx].append(self.last_block) def visit_Continue(self, node: ast.Continue): if self.loop_idx < 0: @@ -2474,37 +2462,37 @@ def visit_Continue(self, node: ast.Continue): " be used in nested DaCe program calls to " " continue loops of outer scopes)") raise DaceSyntaxError(self, node, error_msg) - self.continue_states[self.loop_idx].append(self.last_state) + self.continue_states[self.loop_idx].append(self.last_block) def visit_If(self, node: ast.If): # Add a guard state self._add_state('if_guard') - self.last_state.debuginfo = self.current_lineinfo + self.last_block.debuginfo = self.current_lineinfo # Generate conditions cond, cond_else = self._visit_test(node.test) # Visit recursively laststate, first_if_state, last_if_state, return_stmt = \ - self._recursive_visit(node.body, 'if', node.lineno) - end_if_state = self.last_state + self._recursive_visit(node.body, 'if', node.lineno, self.cfg_target, True) + end_if_state = self.last_block # Connect the states - self.sdfg.add_edge(laststate, first_if_state, dace.InterstateEdge(cond)) - self.sdfg.add_edge(last_if_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) + self.cfg_target.add_edge(laststate, first_if_state, dace.InterstateEdge(cond)) + self.cfg_target.add_edge(last_if_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) # Process 'else'/'elif' statements if len(node.orelse) > 0: # Visit recursively _, first_else_state, last_else_state, return_stmt = \ - self._recursive_visit(node.orelse, 'else', node.lineno, False) + self._recursive_visit(node.orelse, 'else', node.lineno, self.cfg_target, False) # Connect the states - self.sdfg.add_edge(laststate, first_else_state, dace.InterstateEdge(cond_else)) - self.sdfg.add_edge(last_else_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) - self.last_state = end_if_state + self.cfg_target.add_edge(laststate, first_else_state, dace.InterstateEdge(cond_else)) + self.cfg_target.add_edge(last_else_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) else: - self.sdfg.add_edge(laststate, end_if_state, dace.InterstateEdge(cond_else)) + self.cfg_target.add_edge(laststate, end_if_state, dace.InterstateEdge(cond_else)) + self.last_block = end_if_state def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): @@ -3037,7 +3025,7 @@ def _add_access( inner_indices = set(non_squeezed) - state = self.last_state + state = self.current_state new_memlet = None if has_indirection: @@ -3336,9 +3324,9 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): view = self.sdfg.arrays[result] cname, carr = self.sdfg.add_transient(result, view.shape, view.dtype, find_new_name=True) self._add_state(f'copy_from_view_{node.lineno}') - rnode = self.last_state.add_read(result, debuginfo=self.current_lineinfo) - wnode = self.last_state.add_read(cname, debuginfo=self.current_lineinfo) - self.last_state.add_nedge(rnode, wnode, Memlet.from_array(cname, carr)) + rnode = self.current_state.add_read(result, debuginfo=self.current_lineinfo) + wnode = self.current_state.add_read(cname, debuginfo=self.current_lineinfo) + self.current_state.add_nedge(rnode, wnode, Memlet.from_array(cname, carr)) result = cname # Strict independent access check for augmented assignments @@ -3359,7 +3347,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): # Handle output indirection output_indirection = None if _subset_has_indirection(rng, self): - output_indirection = self.sdfg.add_state('wslice_%s_%d' % (new_name, node.lineno)) + output_indirection = self.cfg_target.add_state('wslice_%s_%d' % (new_name, node.lineno)) wnode = output_indirection.add_write(new_name, debuginfo=self.current_lineinfo) memlet = Memlet.simple(new_name, str(rng)) # Dependent augmented assignments need WCR in the @@ -3389,10 +3377,10 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): if op and independent: if _subset_has_indirection(rng, self): self._add_state('rslice_%s_%d' % (new_name, node.lineno)) - rnode = self.last_state.add_read(new_name, debuginfo=self.current_lineinfo) + rnode = self.current_state.add_read(new_name, debuginfo=self.current_lineinfo) memlet = Memlet.simple(new_name, str(rng)) tmp = self.sdfg.temp_data_name() - ind_name = add_indirection_subgraph(self.sdfg, self.last_state, rnode, None, memlet, tmp, self) + ind_name = add_indirection_subgraph(self.sdfg, self.current_state, rnode, None, memlet, tmp, self) rtarget = ind_name else: rtarget = (new_name, new_rng) @@ -3405,8 +3393,8 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): # Connect states properly when there is output indirection if output_indirection: - self.sdfg.add_edge(self.last_state, output_indirection, dace.sdfg.InterstateEdge()) - self.last_state = output_indirection + self.cfg_target.add_edge(self.last_block, output_indirection, dace.sdfg.InterstateEdge()) + self.last_block = output_indirection def visit_AugAssign(self, node: ast.AugAssign): self._visit_assign(node, node.target, augassign_ops[type(node.op).__name__]) @@ -3823,7 +3811,7 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no output_slices = set() for arg in itertools.chain(node.args, [kw.value for kw in node.keywords]): if isinstance(arg, ast.Subscript): - slice_state = self.last_state + slice_state = self.current_state break # Make sure that any scope vars in the arguments are substituted @@ -3850,8 +3838,8 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no for sym, local in mapping.items(): if isinstance(local, str) and local in self.sdfg.arrays: # Add assignment state and inter-state edge - symassign_state = self.sdfg.add_state_before(state) - isedge = self.sdfg.edges_between(symassign_state, state)[0] + symassign_state = self.cfg_target.add_state_before(state) + isedge = self.cfg_target.edges_between(symassign_state, state)[0] newsym = self.sdfg.find_new_symbol(f'sym_{local}') desc = self.sdfg.arrays[local] self.sdfg.add_symbol(newsym, desc.dtype) @@ -3915,7 +3903,7 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no # Delete the old read descriptor if not isinput: conn_used = False - for s in self.sdfg.nodes(): + for s in self.sdfg.states(): for n in s.data_nodes(): if n.data == aname: conn_used = True @@ -4229,11 +4217,11 @@ def parse_target(t: Union[ast.Name, ast.Subscript]): # Create a state with a tasklet and the right arguments self._add_state('callback_%d' % node.lineno) - self.last_state.set_default_lineinfo(self.current_lineinfo) + self.last_block.set_default_lineinfo(self.current_lineinfo) if callback_type.is_scalar_function() and len(callback_type.return_types) > 0: call_args = ', '.join(str(s) for s in allargs[:-1]) - tasklet = self.last_state.add_tasklet(f'callback_{node.lineno}', {f'__in_{name}' + tasklet = self.last_block.add_tasklet(f'callback_{node.lineno}', {f'__in_{name}' for name in args} | {'__istate'}, {f'__out_{name}' for name in outargs} | {'__ostate'}, @@ -4241,7 +4229,7 @@ def parse_target(t: Union[ast.Name, ast.Subscript]): side_effects=True) else: call_args = ', '.join(str(s) for s in allargs) - tasklet = self.last_state.add_tasklet(f'callback_{node.lineno}', {f'__in_{name}' + tasklet = self.last_block.add_tasklet(f'callback_{node.lineno}', {f'__in_{name}' for name in args} | {'__istate'}, {f'__out_{name}' for name in outargs} | {'__ostate'}, @@ -4255,15 +4243,15 @@ def parse_target(t: Union[ast.Name, ast.Subscript]): # Setup arguments in graph for arg in dtypes.deduplicate(args): - r = self.last_state.add_read(arg) - self.last_state.add_edge(r, None, tasklet, f'__in_{arg}', Memlet(arg)) + r = self.current_state.add_read(arg) + self.current_state.add_edge(r, None, tasklet, f'__in_{arg}', Memlet(arg)) for arg in dtypes.deduplicate(outargs): - w = self.last_state.add_write(arg) - self.last_state.add_edge(tasklet, f'__out_{arg}', w, None, Memlet(arg)) + w = self.current_state.add_write(arg) + self.current_state.add_edge(tasklet, f'__out_{arg}', w, None, Memlet(arg)) # Connect Python state - self._connect_pystate(tasklet, self.last_state, '__istate', '__ostate') + self._connect_pystate(tasklet, self.current_state, '__istate', '__ostate') if return_type is None: return [] @@ -4449,17 +4437,17 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): keywords = {arg.arg: self._parse_function_arg(arg.value) for arg in node.keywords} self._add_state('call_%d' % node.lineno) - self.last_state.set_default_lineinfo(self.current_lineinfo) + self.last_block.set_default_lineinfo(self.current_lineinfo) if found_ufunc: - result = func(self, node, self.sdfg, self.last_state, ufunc_name, args, keywords) + result = func(self, node, self.sdfg, self.last_block, ufunc_name, args, keywords) else: - result = func(self, self.sdfg, self.last_state, *args, **keywords) + result = func(self, self.sdfg, self.last_block, *args, **keywords) - self.last_state.set_default_lineinfo(None) + self.last_block.set_default_lineinfo(None) if isinstance(result, tuple) and type(result[0]) is nested_call.NestedCall: - self.last_state = result[0].last_state + self.last_block = result[0].last_block result = result[1] if not isinstance(result, (tuple, list)): @@ -4659,9 +4647,9 @@ def visit_Attribute(self, node: ast.Attribute): if func is not None: # A new state is likely needed here, e.g., for transposition (ndarray.T) self._add_state('%s_%d' % (type(node).__name__, node.lineno)) - self.last_state.set_default_lineinfo(self.current_lineinfo) - result = func(self, self.sdfg, self.last_state, result) - self.last_state.set_default_lineinfo(None) + self.last_block.set_default_lineinfo(self.current_lineinfo) + result = func(self, self.sdfg, self.last_block, result) + self.last_block.set_default_lineinfo(None) return result # Otherwise, try to find compile-time attribute (such as shape) @@ -4767,9 +4755,9 @@ def _visit_op(self, node: Union[ast.UnaryOp, ast.BinOp, ast.BoolOp], op1: ast.AS raise DaceSyntaxError(self, node, f'Operator {opname} is not defined for types {op1name} and {op2name}') self._add_state('%s_%d' % (type(node).__name__, node.lineno)) - self.last_state.set_default_lineinfo(self.current_lineinfo) + self.last_block.set_default_lineinfo(self.current_lineinfo) try: - result = func(self, self.sdfg, self.last_state, operand1, operand2) + result = func(self, self.sdfg, self.last_block, operand1, operand2) except SyntaxError as ex: raise DaceSyntaxError(self, node, str(ex)) if not isinstance(result, (list, tuple)): @@ -4782,7 +4770,7 @@ def _visit_op(self, node: Union[ast.UnaryOp, ast.BinOp, ast.BoolOp], op1: ast.AS raise DaceSyntaxError(self, node, "Variable {v} has been already defined".format(v=r)) self.variables[r] = r - self.last_state.set_default_lineinfo(None) + self.last_block.set_default_lineinfo(None) return result @@ -4826,7 +4814,7 @@ def _add_read_slice(self, array: str, node: ast.Subscript, expr: MemletExpr): self._add_state('slice_%s_%d' % (array, node.lineno)) if has_array_indirection: # Make copy slicing state - rnode = self.last_state.add_read(array, debuginfo=self.current_lineinfo) + rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo) return self._array_indirection_subgraph(rnode, expr) else: is_index = False @@ -4867,9 +4855,9 @@ def _add_read_slice(self, array: str, node: ast.Subscript, expr: MemletExpr): wcr=expr.wcr)) self.variables[tmp] = tmp if not isinstance(tmparr, data.View): - rnode = self.last_state.add_read(array, debuginfo=self.current_lineinfo) - wnode = self.last_state.add_write(tmp, debuginfo=self.current_lineinfo) - self.last_state.add_nedge( + rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo) + wnode = self.current_state.add_write(tmp, debuginfo=self.current_lineinfo) + self.current_state.add_nedge( rnode, wnode, Memlet(f'{array}[{expr.subset}]->{other_subset}', volume=expr.accesses, wcr=expr.wcr)) return tmp @@ -4902,7 +4890,7 @@ def _promote(node: ast.AST) -> Union[Any, str, symbolic.symbol]: # `not sym` returns True. This exception is benign. pass state = self._add_state(f'promote_{scalar}_to_{str(sym)}') - edge = self.sdfg.in_edges(state)[0] + edge = state.parent.in_edges(state)[0] edge.data.assignments = {str(sym): scalar} return sym return scalar @@ -5082,17 +5070,17 @@ def make_slice(self, arrname: str, rng: subsets.Range): # Add slicing state # TODO: naming issue, we don't have the linenumber here self._add_state('slice_%s' % (array)) - rnode = self.last_state.add_read(array, debuginfo=self.current_lineinfo) + rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo) other_subset = copy.deepcopy(rng) other_subset.squeeze() if _subset_has_indirection(rng, self): memlet = Memlet.simple(array, rng) tmp = self.sdfg.temp_data_name() - tmp = add_indirection_subgraph(self.sdfg, self.last_state, rnode, None, memlet, tmp, self) + tmp = add_indirection_subgraph(self.sdfg, self.current_state, rnode, None, memlet, tmp, self) else: tmp, tmparr = self.sdfg.add_temp_transient(other_subset.size(), arrobj.dtype, arrobj.storage) - wnode = self.last_state.add_write(tmp, debuginfo=self.current_lineinfo) - self.last_state.add_nedge( + wnode = self.current_state.add_write(tmp, debuginfo=self.current_lineinfo) + self.current_state.add_nedge( rnode, wnode, Memlet.simple(array, rng, num_accesses=rng.num_elements(), other_subset_str=other_subset)) return tmp, other_subset @@ -5161,7 +5149,7 @@ def _array_indirection_subgraph(self, rnode: nodes.AccessNode, expr: MemletExpr) # output shape dimensions are len(output_shape) # Make map with output shape - state: SDFGState = self.last_state + state = self.current_state wnode = state.add_write(outname) maprange = [(f'__i{i}', f'0:{s}') for i, s in enumerate(output_shape)] me, mx = state.add_map('indirect_slice', maprange, debuginfo=self.current_lineinfo) diff --git a/dace/memlet.py b/dace/memlet.py index d448ca1134..ea81dbebc4 100644 --- a/dace/memlet.py +++ b/dace/memlet.py @@ -400,11 +400,11 @@ def try_initialize(self, sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState', self.subset = subsets.Range.from_array(sdfg.arrays[self.data]) def get_src_subset(self, edge: 'dace.sdfg.graph.MultiConnectorEdge', state: 'dace.sdfg.SDFGState'): - self.try_initialize(state.parent, state, edge) + self.try_initialize(state.sdfg, state, edge) return self.src_subset def get_dst_subset(self, edge: 'dace.sdfg.graph.MultiConnectorEdge', state: 'dace.sdfg.SDFGState'): - self.try_initialize(state.parent, state, edge) + self.try_initialize(state.sdfg, state, edge) return self.dst_subset @staticmethod diff --git a/dace/sdfg/__init__.py b/dace/sdfg/__init__.py index 183cf841c7..9f48433bd5 100644 --- a/dace/sdfg/__init__.py +++ b/dace/sdfg/__init__.py @@ -1,7 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. from dace.sdfg.sdfg import SDFG, InterstateEdge, LogicalGroup -from dace.sdfg.state import SDFGState +from dace.sdfg.state import SDFGState, ControlFlowBlock, ScopeBlock, LoopScopeBlock, BranchScopeBlock from dace.sdfg.scope import (scope_contains_scope, is_devicelevel_gpu, devicelevel_block_size, ScopeSubgraphView) diff --git a/dace/sdfg/analysis/cfg.py b/dace/sdfg/analysis/cfg.py index 9021a79439..3dcfeb6f00 100644 --- a/dace/sdfg/analysis/cfg.py +++ b/dace/sdfg/analysis/cfg.py @@ -2,7 +2,7 @@ """ Various analyses related to control flow in SDFG states. """ from collections import defaultdict from dace.sdfg import SDFG, SDFGState, InterstateEdge, graph as gr, utils as sdutil -from dace.symbolic import pystr_to_symbolic +from dace.sdfg.state import ScopeBlock, ControlFlowBlock import networkx as nx import sympy as sp from typing import Dict, Iterator, List, Set @@ -67,32 +67,32 @@ def back_edges(sdfg: SDFG, return [e for e in sdfg.edges() if e.dst in alldoms[e.src]] -def state_parent_tree(sdfg: SDFG) -> Dict[SDFGState, SDFGState]: +def state_parent_tree(graph: ScopeBlock) -> Dict[ControlFlowBlock, ControlFlowBlock]: """ Computes an upward-pointing tree of each state, pointing to the "parent - state" it belongs to (in terms of structured control flow). More formally, - each state is either mapped to its immediate dominator with out degree > 2, - one state upwards if state occurs after a loop, or the start state if - no such states exist. + block" it belongs to (in terms of structured control flow). More formally, + each block is either mapped to its immediate dominator with out degree > 2, + one block upwards if the block occurs after a loop, or the start block if + no such block exist. :param sdfg: The SDFG to analyze. :return: A dictionary that maps each state to a parent state, or None if the root (start) state. """ - idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state) - alldoms = all_dominators(sdfg, idom) - loopexits: Dict[SDFGState, SDFGState] = defaultdict(lambda: None) + idom = nx.immediate_dominators(graph.nx, graph.start_block) + alldoms = all_dominators(graph, idom) + loopexits: Dict[ControlFlowBlock, ControlFlowBlock] = defaultdict(lambda: None) # First, annotate loops - for be in back_edges(sdfg, idom, alldoms): + for be in back_edges(graph, idom, alldoms): guard = be.dst laststate = be.src if loopexits[guard] is not None: continue # Natural loops = one edge leads back to loop, another leads out - in_edges = sdfg.in_edges(guard) - out_edges = sdfg.out_edges(guard) + in_edges = graph.in_edges(guard) + out_edges = graph.out_edges(guard) # A loop guard has two or more incoming edges (1 increment and # n init, all identical), and exactly two outgoing edges (loop and @@ -148,8 +148,8 @@ def cond_b(parent, child): return False return True # Keep traversing - list(sdutil.dfs_conditional(sdfg, (oa, ), cond_a)) - list(sdutil.dfs_conditional(sdfg, (ob, ), cond_b)) + list(sdutil.dfs_conditional(graph, (oa, ), cond_a)) + list(sdutil.dfs_conditional(graph, (ob, ), cond_b)) # Check which candidate states led back to guard is_a_begin = a_reached_guard and reachable_a @@ -168,20 +168,20 @@ def cond_b(parent, child): loopexits[guard] = exit_state # Get dominators - parents: Dict[SDFGState, SDFGState] = {} - step_up: Set[SDFGState] = set() - for state in sdfg.nodes(): + parents: Dict[ControlFlowBlock, ControlFlowBlock] = {} + step_up: Set[ControlFlowBlock] = set() + for state in graph.nodes(): curdom = idom[state] if curdom == state: parents[state] = None continue while curdom != idom[curdom]: - if sdfg.out_degree(curdom) > 1: + if graph.out_degree(curdom) > 1: break curdom = idom[curdom] - if sdfg.out_degree(curdom) == 2 and loopexits[curdom] is not None: + if graph.out_degree(curdom) == 2 and loopexits[curdom] is not None: p = state while p != curdom and p != loopexits[curdom]: p = idom[p] @@ -199,22 +199,22 @@ def cond_b(parent, child): return parents -def _stateorder_topological_sort(sdfg: SDFG, - start: SDFGState, - ptree: Dict[SDFGState, SDFGState], - branch_merges: Dict[SDFGState, SDFGState], - stop: SDFGState = None, - visited: Set[SDFGState] = None) -> Iterator[SDFGState]: +def _stateorder_topological_sort(graph: ScopeBlock, + start: ControlFlowBlock, + ptree: Dict[ControlFlowBlock, ControlFlowBlock], + branch_merges: Dict[ControlFlowBlock, ControlFlowBlock], + stop: ControlFlowBlock = None, + visited: Set[ControlFlowBlock] = None) -> Iterator[ControlFlowBlock]: """ Helper function for ``stateorder_topological_sort``. - :param sdfg: SDFG. - :param start: Starting state for traversal. + :param graph: Control flow graph or SDFG. + :param start: Starting block for traversal. :param ptree: State parent tree (computed from ``state_parent_tree``). - :param branch_merges: Dictionary mapping from branch state to its merge state. - :param stop: Stopping state to not traverse through (merge state of a - branch or guard state of a loop). - :return: Generator that yields states in state-order from ``start`` to + :param branch_merges: Dictionary mapping from branch blocks to their merge blocks. + :param stop: Stopping block to not traverse through (merge block of a + branch or guard block of a loop). + :return: Generator that yields blocks in state-order from ``start`` to ``stop``. """ # Traverse states in custom order @@ -227,7 +227,7 @@ def _stateorder_topological_sort(sdfg: SDFG, yield node visited.add(node) - oe = sdfg.out_edges(node) + oe = graph.out_edges(node) if len(oe) == 0: # End state continue elif len(oe) == 1: # No traversal change @@ -236,14 +236,14 @@ def _stateorder_topological_sort(sdfg: SDFG, elif len(oe) == 2: # Loop or branch # If loop, traverse body, then exit if ptree[oe[0].dst] == node and ptree[oe[1].dst] != node: - for s in _stateorder_topological_sort(sdfg, oe[0].dst, ptree, branch_merges, stop=node, + for s in _stateorder_topological_sort(graph, oe[0].dst, ptree, branch_merges, stop=node, visited=visited): yield s visited.add(s) stack.append(oe[1].dst) continue elif ptree[oe[1].dst] == node and ptree[oe[0].dst] != node: - for s in _stateorder_topological_sort(sdfg, oe[1].dst, ptree, branch_merges, stop=node, + for s in _stateorder_topological_sort(graph, oe[1].dst, ptree, branch_merges, stop=node, visited=visited): yield s visited.add(s) @@ -258,7 +258,7 @@ def _stateorder_topological_sort(sdfg: SDFG, try: # Otherwise (e.g., with return/break statements), traverse through each branch, # stopping at the end of the current tree level. - mergestate = next(e.dst for e in sdfg.out_edges(stop) if ptree[e.dst] != stop) + mergestate = next(e.dst for e in graph.out_edges(stop) if ptree[e.dst] != stop) except StopIteration: # If that fails, simply traverse branches in arbitrary order mergestate = stop @@ -267,7 +267,7 @@ def _stateorder_topological_sort(sdfg: SDFG, if branch.dst is mergestate: # If we hit the merge state (if without else), defer to end of branch traversal continue - for s in _stateorder_topological_sort(sdfg, + for s in _stateorder_topological_sort(graph, branch.dst, ptree, branch_merges, @@ -278,23 +278,23 @@ def _stateorder_topological_sort(sdfg: SDFG, stack.append(mergestate) -def stateorder_topological_sort(sdfg: SDFG) -> Iterator[SDFGState]: +def stateorder_topological_sort(graph: ScopeBlock) -> Iterator[ControlFlowBlock]: """ - Returns a generator that produces states in the order that they will be + Returns a generator that produces control flow blocks in the order that they will be executed, disregarding multiple loop iterations and employing topological sort for branches. - :param sdfg: The SDFG to iterate over. - :return: Generator that yields states in state-order. + :param graph: The SDFG / control flow graph to iterate over. + :return: Generator that yields control flow blocks in state-order. """ # Get parent states - ptree = state_parent_tree(sdfg) + ptree = state_parent_tree(graph) # Annotate branches - branch_merges: Dict[SDFGState, SDFGState] = {} - adf = acyclic_dominance_frontier(sdfg) - for state in sdfg.nodes(): - oedges = sdfg.out_edges(state) + branch_merges: Dict[ControlFlowBlock, ControlFlowBlock] = {} + adf = acyclic_dominance_frontier(graph) + for state in graph.nodes(): + oedges = graph.out_edges(state) # Skip if not branch if len(oedges) <= 1: continue @@ -312,4 +312,4 @@ def stateorder_topological_sort(sdfg: SDFG) -> Iterator[SDFGState]: if len(common_frontier) == 1: branch_merges[state] = next(iter(common_frontier)) - yield from _stateorder_topological_sort(sdfg, sdfg.start_state, ptree, branch_merges) + yield from _stateorder_topological_sort(graph, graph.start_block, ptree, branch_merges) diff --git a/dace/sdfg/analysis/cutout.py b/dace/sdfg/analysis/cutout.py index a72a6d7e54..781073a099 100644 --- a/dace/sdfg/analysis/cutout.py +++ b/dace/sdfg/analysis/cutout.py @@ -193,7 +193,7 @@ def singlestate_cutout(cls, if reduce_input_config: nodes = _reduce_in_configuration(state, nodes, use_alibi_nodes, symbols_map) create_element = copy.deepcopy if make_copy else (lambda x: x) - sdfg = state.parent + sdfg = state.sdfg subgraph: StateSubgraphView = StateSubgraphView(state, nodes) subgraph = _extend_subgraph_with_access_nodes(state, subgraph, use_alibi_nodes) @@ -341,8 +341,8 @@ def multistate_cutout(cls, create_element = copy.deepcopy # Check that all states are inside the same SDFG. - sdfg = list(states)[0].parent - if any(i.parent != sdfg for i in states): + sdfg = list(states)[0].sdfg + if any(i.sdfg != sdfg for i in states): raise Exception('Not all cutout states reside in the same SDFG') cutout_states: Set[SDFGState] = set(states) @@ -423,13 +423,13 @@ def multistate_cutout(cls, in_translation[is_edge.src] = new_el out_translation[new_el] = is_edge.src cutout.add_node(new_el, is_start_state=(is_edge.src == start_state)) - new_el.parent = cutout + new_el.sdfg = cutout if is_edge.dst not in in_translation: new_el: SDFGState = create_element(is_edge.dst) in_translation[is_edge.dst] = new_el out_translation[new_el] = is_edge.dst cutout.add_node(new_el, is_start_state=(is_edge.dst == start_state)) - new_el.parent = cutout + new_el.sdfg = cutout new_isedge: InterstateEdge = create_element(is_edge.data) in_translation[is_edge.data] = new_isedge out_translation[new_isedge] = is_edge.data @@ -442,7 +442,7 @@ def multistate_cutout(cls, in_translation[state] = new_el out_translation[new_el] = state cutout.add_node(new_el, is_start_state=(state == start_state)) - new_el.parent = cutout + new_el.sdfg = cutout in_translation[sdfg.sdfg_id] = cutout.sdfg_id out_translation[cutout.sdfg_id] = sdfg.sdfg_id @@ -574,8 +574,8 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use # For the given state, determine what should count as the input configuration if we were to cut out the entire # state. - state_reachability_dict = StateReachability().apply_pass(state.parent, None) - state_reach = state_reachability_dict[state.parent.sdfg_id] + state_reachability_dict = StateReachability().apply_pass(state.sdfg, None) + state_reach = state_reachability_dict[state.sdfg.sdfg_id] reaching_cutout: Set[SDFGState] = set() for k, v in state_reach.items(): if state in v: @@ -586,7 +586,7 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use if state.out_degree(dn) > 0: # This is read from, add to the system state if it is written anywhere else in the graph. # Except if it is also written to at the same time and is scalar or of size 1. - array = state.parent.arrays[dn.data] + array = state.sdfg.arrays[dn.data] if state.in_degree(dn) > 0 and (array.total_size == 1 or isinstance(array, data.Scalar)): continue elif not array.transient: @@ -608,8 +608,8 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use # about symbol values. Not sure how to do that yet. if symbols_map is None: symbols_map = dict() - consts = state.parent.constants - for s in state.parent.symbols: + consts = state.sdfg.constants + for s in state.sdfg.symbols: if s in consts: symbols_map[s] = consts[s] else: @@ -730,8 +730,8 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use for node in scope_nodes: if isinstance(node, nd.AccessNode) and node.data in state_input_configuration: - if not proxy_graph.has_edge(source, node) and node.data in state.parent.arrays: - vol = state.parent.arrays[node.data].total_size + if not proxy_graph.has_edge(source, node) and node.data in state.sdfg.arrays: + vol = state.sdfg.arrays[node.data].total_size if isinstance(vol, sp.Expr): vol = vol.subs(symbols_map) proxy_graph.add_edge(source, node, capacity=vol) @@ -767,7 +767,7 @@ def _stateset_predecessor_frontier(states: Set[SDFGState]) -> Tuple[Set[SDFGStat pred_frontier = set() pred_frontier_edges = set() for state in states: - for iedge in state.parent.in_edges(state): + for iedge in state.sdfg.in_edges(state): if iedge.src not in states: if iedge.src not in pred_frontier: pred_frontier.add(iedge.src) @@ -819,7 +819,7 @@ def _create_alibi_access_node_for_edge(target_sdfg: SDFG, target_state: SDFGStat def _extend_subgraph_with_access_nodes(state: SDFGState, subgraph: StateSubgraphView, use_alibi_nodes: bool) -> StateSubgraphView: """ Expands a subgraph view to include necessary input/output access nodes, using memlet paths. """ - sdfg = state.parent + sdfg = state.sdfg result: List[nd.Node] = copy.copy(subgraph.nodes()) queue: Deque[nd.Node] = deque(subgraph.nodes()) @@ -1014,7 +1014,7 @@ def _cutout_determine_output_configuration(ct: SDFG, cutout_reach: Set[SDFGState check_for_read_after.add(dn.data) original_state: SDFGState = out_translation[state] - for edge in original_state.parent.out_edges(original_state): + for edge in original_state.sdfg.out_edges(original_state): if edge.dst in cutout_reach: border_out_edges.add(edge.data) diff --git a/dace/sdfg/graph.py b/dace/sdfg/graph.py index 5c93149529..42fb228c3f 100644 --- a/dace/sdfg/graph.py +++ b/dace/sdfg/graph.py @@ -365,8 +365,7 @@ def sink_nodes(self) -> List[NodeT]: return [n for n in self.nodes() if self.out_degree(n) == 0] def topological_sort(self, source: NodeT = None) -> Sequence[NodeT]: - """Returns nodes in topological order iff the graph contains exactly - one node with no incoming edges.""" + """Returns nodes in topological order iff the graph contains exactly one node with no incoming edges.""" if source is not None: sources = [source] else: diff --git a/dace/sdfg/infer_types.py b/dace/sdfg/infer_types.py index 105e1d12e9..c4ba3cdcf2 100644 --- a/dace/sdfg/infer_types.py +++ b/dace/sdfg/infer_types.py @@ -61,7 +61,7 @@ def infer_connector_types(sdfg: SDFG): :param sdfg: The SDFG to infer. """ # Loop over states, and in a topological sort over each state's nodes - for state in sdfg.nodes(): + for state in sdfg.all_states_recursive(): for node in dfs_topological_sort(state): # Try to infer input connector type from node type or previous edges for e in state.in_edges(node): @@ -167,7 +167,7 @@ def set_default_schedule_and_storage_types(scope: Union[SDFG, SDFGState, nodes.E if isinstance(scope, SDFG): # Set device for default top-level schedules and storages - for state in scope.nodes(): + for state in scope.all_states_recursive(): set_default_schedule_and_storage_types(state, parent_schedules, use_parent_schedule=use_parent_schedule, @@ -257,7 +257,7 @@ def _determine_schedule_from_storage(state: SDFGState, node: nodes.Node) -> Opti # From memlets, use non-scalar data descriptors for decision constraints: Set[dtypes.ScheduleType] = set() - sdfg = state.parent + sdfg = state.sdfg for dname in memlets: if isinstance(sdfg.arrays[dname], data.Scalar): continue # Skip scalars @@ -276,7 +276,7 @@ def _determine_schedule_from_storage(state: SDFGState, node: nodes.Node) -> Opti raise validation.InvalidSDFGNodeError( f'Cannot determine default schedule for node {node}. ' 'Multiple arrays that point to it say that it should be the following schedules: ' - f'{constraints}', state.parent, state.parent.node_id(state), state.node_id(node)) + f'{constraints}', state.sdfg, state.sdfg.node_id(state), state.node_id(node)) else: child_schedule = next(iter(constraints)) @@ -338,7 +338,7 @@ def _set_default_storage_in_scope(state: SDFGState, parent_node: Optional[nodes. parent_schedules = parent_schedules + [dtypes.ScheduleType.GPU_ThreadBlock] # End of special case - sdfg = state.parent + sdfg = state.sdfg child_storage = _determine_child_storage(parent_schedules) if child_storage is None: child_storage = dtypes.SCOPEDEFAULT_STORAGE[None] @@ -378,7 +378,7 @@ def _get_storage_from_parent(data_name: str, sdfg: SDFG) -> dtypes.StorageType: """ nsdfg_node = sdfg.parent_nsdfg_node parent_state = sdfg.parent - parent_sdfg = parent_state.parent + parent_sdfg = parent_state.sdfg # Find data descriptor in parent SDFG if data_name in nsdfg_node.in_connectors: diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 32369a19a3..608579e3b8 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -262,10 +262,9 @@ def label(self): def __label__(self, sdfg, state): return self.data - def desc(self, sdfg): - from dace.sdfg import SDFGState, ScopeSubgraphView - if isinstance(sdfg, (SDFGState, ScopeSubgraphView)): - sdfg = sdfg.parent + def desc(self, sdfg: Union['dace.sdfg.SDFG', 'dace.sdfg.SDFGState', 'dace.sdfg.ScopeSubgraphView']): + if not isinstance(sdfg, dace.sdfg.SDFG): + sdfg = sdfg.sdfg return sdfg.arrays[self.data] def validate(self, sdfg, state): diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 0554775dcd..52060e0edd 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1332,7 +1332,7 @@ def propagate_memlet(dfg_state, if memlet.is_empty(): return Memlet() - sdfg = dfg_state.parent + sdfg = dfg_state.sdfg scope_node_symbols = set(conn for conn in entry_node.in_connectors if not conn.startswith('IN_')) defined_vars = [ symbolic.pystr_to_symbolic(s) for s in (dfg_state.symbols_defined_at(entry_node).keys() diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 4b36fad4fe..80f0eced82 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -175,17 +175,18 @@ def replace_datadesc_names(sdfg, repl: Dict[str, str]): sdfg.constants_prop[repl[aname]] = sdfg.constants_prop[aname] del sdfg.constants_prop[aname] - # Replace in interstate edges - for e in sdfg.edges(): - e.data.replace_dict(repl, replace_keys=False) - - for state in sdfg.nodes(): - # Replace in access nodes - for node in state.data_nodes(): - if node.data in repl: - node.data = repl[node.data] - - # Replace in memlets - for edge in state.edges(): - if edge.data.data in repl: - edge.data.data = repl[edge.data.data] + for cf in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False): + # Replace in interstate edges + for e in cf.edges(): + e.data.replace_dict(repl, replace_keys=False) + + for block in cf.nodes(): + if isinstance(block, dace.SDFGState): + # Replace in access nodes + for node in block.data_nodes(): + if node.data in repl: + node.data = repl[node.data] + # Replace in memlets + for edge in block.edges(): + if edge.data.data in repl: + edge.data.data = repl[edge.data.data] diff --git a/dace/sdfg/scope.py b/dace/sdfg/scope.py index 95f278b06a..2b2dd6a1e0 100644 --- a/dace/sdfg/scope.py +++ b/dace/sdfg/scope.py @@ -104,12 +104,10 @@ def _scope_dict_inner(graph, node_queue, current_scope, node_to_children, result # If this is an Entry Node, we need to recurse further if isinstance(node, nd.EntryNode): node_queue.extend(_scope_dict_inner(graph, collections.deque(successors), node, node_to_children, result)) - # If this is an Exit Node, we push the successors to the external - # queue + # If this is an Exit Node, we push the successors to the external queue elif isinstance(node, nd.ExitNode): external_queue.extend(successors) - # Otherwise, it is a plain node, and we push its successors to the - # same queue + # Otherwise, it is a plain node, and we push its successors to the same queue else: node_queue.extend(successors) @@ -248,7 +246,7 @@ def is_devicelevel_gpu_kernel(sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGStat if is_parent_nested: return is_devicelevel_gpu(sdfg.parent.parent, sdfg.parent, sdfg.parent_nsdfg_node, with_gpu_default=True) else: - return is_devicelevel_gpu(state.parent, state, node, with_gpu_default=True) + return is_devicelevel_gpu(state.sdfg, state, node, with_gpu_default=True) def is_devicelevel_fpga(sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState', node: NodeType) -> bool: @@ -296,7 +294,7 @@ def devicelevel_block_size(sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState', # Traverse up nested SDFGs if sdfg.parent is not None: if isinstance(sdfg.parent, SDFGState): - parent = sdfg.parent.parent + parent = sdfg.parent.sdfg else: parent = sdfg.parent state, node = next((s, n) for s in parent.nodes() for n in s.nodes() diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index a85e773337..a04bf2a620 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -3,40 +3,47 @@ import collections import copy import ctypes -import itertools import gzip -from numbers import Integral +import itertools +import json import os -import pickle, json -from hashlib import md5, sha256 -from pydoc import locate +import pickle import random import re import shutil import sys import time -from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union import warnings +from hashlib import md5, sha256 +from numbers import Integral +from pydoc import locate +from typing import (TYPE_CHECKING, Any, AnyStr, BinaryIO, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, + Union) + import numpy as np import sympy as sp import dace import dace.serialize -from dace import (data as dt, hooks, memlet as mm, subsets as sbs, dtypes, properties, symbolic) -from dace.sdfg.scope import ScopeTree -from dace.sdfg.replace import replace, replace_properties, replace_properties_dict -from dace.sdfg.validation import (InvalidSDFGError, validate_sdfg) +from dace import data as dt +from dace import dtypes, hooks +from dace import memlet as mm +from dace import properties +from dace import subsets as sbs +from dace import symbolic from dace.config import Config +from dace.distr_types import ProcessGrid, RedistrArray, SubArray +from dace.dtypes import validate_name from dace.frontend.python import astutils, wrappers +from dace.properties import (CodeBlock, CodeProperty, DebugInfoProperty, DictProperty, EnumProperty, ListProperty, + OptionalSDFGReferenceProperty, Property, TransformationHistProperty, make_properties) from dace.sdfg import nodes as nd -from dace.sdfg.graph import OrderedDiGraph, Edge, SubgraphView -from dace.sdfg.state import SDFGState +from dace.sdfg.graph import Edge, OrderedDiGraph, SubgraphView from dace.sdfg.propagation import propagate_memlets_sdfg -from dace.distr_types import ProcessGrid, SubArray, RedistrArray -from dace.dtypes import validate_name -from dace.properties import (DebugInfoProperty, EnumProperty, ListProperty, make_properties, Property, CodeProperty, - TransformationHistProperty, OptionalSDFGReferenceProperty, DictProperty, CodeBlock) -from typing import BinaryIO +from dace.sdfg.replace import replace, replace_properties, replace_properties_dict +from dace.sdfg.scope import ScopeTree +from dace.sdfg.state import SDFGState, ScopeBlock, LoopScopeBlock +from dace.sdfg.validation import InvalidSDFGError, validate_sdfg # NOTE: In shapes, we try to convert strings to integers. In ranks, a string should be interpreted as data (scalar). ShapeType = Sequence[Union[Integral, str, symbolic.symbol, symbolic.SymExpr, symbolic.sympy.Basic]] @@ -402,7 +409,7 @@ def label(self): @make_properties -class SDFG(OrderedDiGraph[SDFGState, InterstateEdge]): +class SDFG(ScopeBlock): """ The main intermediate representation of code in DaCe. A Stateful DataFlow multiGraph (SDFG) is a directed graph of directed @@ -494,15 +501,11 @@ def __init__(self, self.add_constant(cstname, cstval, cst_dtype) self._propagate = propagate - self._parent = parent self.symbols = {} self._parent_sdfg = None self._parent_nsdfg_node = None self._sdfg_list = [self] - self._start_state: Optional[int] = None - self._cached_start_state: Optional[SDFGState] = None self._arrays = NestedDict() # type: Dict[str, dt.Array] - self._labels: Set[str] = set() self.global_code = {'frame': CodeBlock("", dtypes.Language.CPP)} self.init_code = {'frame': CodeBlock("", dtypes.Language.CPP)} self.exit_code = {'frame': CodeBlock("", dtypes.Language.CPP)} @@ -531,14 +534,14 @@ def __deepcopy__(self, memo): memo[id(self)] = result for k, v in self.__dict__.items(): # Skip derivative attributes - if k in ('_cached_start_state', '_edges', '_nodes', '_parent', '_parent_sdfg', '_parent_nsdfg_node', + if k in ('_cached_start_block', '_edges', '_nodes', '_parent', '_parent_sdfg', '_parent_nsdfg_node', '_sdfg_list', '_transformation_hist'): continue setattr(result, k, copy.deepcopy(v, memo)) # Copy edges and nodes result._edges = copy.deepcopy(self._edges, memo) result._nodes = copy.deepcopy(self._nodes, memo) - result._cached_start_state = copy.deepcopy(self._cached_start_state, memo) + result._cached_start_block = copy.deepcopy(self._cached_start_block, memo) # Copy parent attributes for k in ('_parent', '_parent_sdfg', '_parent_nsdfg_node'): if id(getattr(self, k)) in memo: @@ -551,7 +554,8 @@ def __deepcopy__(self, memo): result._sdfg_list = [] if self._parent_sdfg is None: # Avoid import loops - from dace.transformation.passes.fusion_inline import FixNestedSDFGReferences + from dace.transformation.passes.fusion_inline import \ + FixNestedSDFGReferences result._sdfg_list = result.reset_sdfg_list() fixed = FixNestedSDFGReferences().apply_pass(result, {}) @@ -583,7 +587,7 @@ def to_json(self, hash=False): tmp['attributes']['constants_prop'] = json.loads(dace.serialize.dumps(tmp['attributes']['constants_prop'])) tmp['sdfg_list_id'] = int(self.sdfg_id) - tmp['start_state'] = self._start_state + tmp['start_state'] = self._start_block tmp['attributes']['name'] = self.name if hash: @@ -627,7 +631,7 @@ def from_json(cls, json_obj, context_info=None): ret.add_edge(nodelist[int(e.src)], nodelist[int(e.dst)], e.data) if 'start_state' in json_obj: - ret._start_state = json_obj['start_state'] + ret._start_block = json_obj['start_state'] return ret @@ -753,14 +757,7 @@ def replace_dict(self, for array in self.arrays.values(): replace_properties_dict(array, repldict, symrepl) - if replace_in_graph: - # Replace in inter-state edges - for edge in self.edges(): - edge.data.replace_dict(repldict, replace_keys=replace_keys) - - # Replace in states - for state in self.nodes(): - state.replace_dict(repldict, symrepl) + super().replace_dict(repldict, symrepl, replace_in_graph, replace_keys) def add_symbol(self, name, stype): """ Adds a symbol to the SDFG. @@ -787,34 +784,11 @@ def remove_symbol(self, name): @property def start_state(self): - """ Returns the starting state of this SDFG. """ - if self._cached_start_state is not None: - return self._cached_start_state - - source_nodes = self.source_nodes() - if len(source_nodes) == 1: - self._cached_start_state = source_nodes[0] - return source_nodes[0] - # If starting state is ambiguous (i.e., loop to initial state or more - # than one possible start state), allow manually overriding start state - if self._start_state is not None: - self._cached_start_state = self.node(self._start_state) - return self._cached_start_state - raise ValueError('Ambiguous or undefined starting state for SDFG, ' - 'please use "is_start_state=True" when adding the ' - 'starting state with "add_state"') + return self.start_block @start_state.setter def start_state(self, state_id): - """ Manually sets the starting state of this SDFG. - - :param state_id: The node ID (use `node_id(state)`) of the - state to set. - """ - if state_id < 0 or state_id >= self.number_of_nodes(): - raise ValueError("Invalid state ID") - self._start_state = state_id - self._cached_start_state = self.node(state_id) + self.start_block = state_id def set_global_code(self, cpp_code: str, location: str = 'frame'): """ @@ -1017,7 +991,8 @@ def get_instrumented_data(self, timestamp: Optional[int] = None) -> Optional['In :return: An InstrumentedDataReport object, or None if one does not exist. """ # Avoid import loops - from dace.codegen.instrumentation.data.data_report import InstrumentedDataReport + from dace.codegen.instrumentation.data.data_report import \ + InstrumentedDataReport if timestamp is None: reports = self.available_data_reports() @@ -1060,7 +1035,8 @@ def call_with_instrumented_data(self, dreport: 'InstrumentedDataReport', *args, :param kwargs: Keyword arguments to call SDFG with. :return: The return value(s) of this SDFG. """ - from dace.codegen.compiled_sdfg import CompiledSDFG # Avoid import loop + from dace.codegen.compiled_sdfg import \ + CompiledSDFG # Avoid import loop binaryobj: CompiledSDFG = self.compile() set_report = binaryobj.get_exported_function('__dace_set_instrumented_data_report') @@ -1127,7 +1103,7 @@ def remove_data(self, name, validate=True): # Verify that there are no access nodes that use this data if validate: - for state in self.nodes(): + for state in self.states(): for node in state.nodes(): if isinstance(node, nd.AccessNode) and node.data == name: raise ValueError(f"Cannot remove data descriptor " @@ -1216,102 +1192,32 @@ def propagate(self): def propagate(self, propagate: bool): self._propagate = propagate - @property - def parent(self) -> SDFGState: - """ Returns the parent SDFG state of this SDFG, if exists. """ - return self._parent - - @property - def parent_sdfg(self) -> 'SDFG': - """ Returns the parent SDFG of this SDFG, if exists. """ - return self._parent_sdfg - @property def parent_nsdfg_node(self) -> nd.NestedSDFG: """ Returns the parent NestedSDFG node of this SDFG, if exists. """ return self._parent_nsdfg_node - @parent.setter - def parent(self, value): - self._parent = value - - @parent_sdfg.setter - def parent_sdfg(self, value): - self._parent_sdfg = value - @parent_nsdfg_node.setter def parent_nsdfg_node(self, value): self._parent_nsdfg_node = value - def add_node(self, node, is_start_state=False): - """ Adds a new node to the SDFG. Must be an SDFGState or a subclass - thereof. + @property + def parent_sdfg(self) -> 'SDFG': + """ Returns the parent SDFG of this SDFG, if exists. """ + return self._parent_sdfg - :param node: The node to add. - :param is_start_state: If True, sets this node as the starting - state. - """ - if not isinstance(node, SDFGState): - raise TypeError("Expected SDFGState, got " + str(type(node))) - super(SDFG, self).add_node(node) - self._cached_start_state = None - if is_start_state is True: - self.start_state = len(self.nodes()) - 1 - self._cached_start_state = node + @parent_sdfg.setter + def parent_sdfg(self, value): + self._parent_sdfg = value def remove_node(self, node: SDFGState): - if node is self._cached_start_state: - self._cached_start_state = None + if node is self._cached_start_block: + self._cached_start_block = None return super().remove_node(node) - def add_edge(self, u, v, edge): - """ Adds a new edge to the SDFG. Must be an InterstateEdge or a - subclass thereof. - - :param u: Source node. - :param v: Destination node. - :param edge: The edge to add. - """ - if not isinstance(u, SDFGState): - raise TypeError("Expected SDFGState, got: {}".format(type(u).__name__)) - if not isinstance(v, SDFGState): - raise TypeError("Expected SDFGState, got: {}".format(type(v).__name__)) - if not isinstance(edge, InterstateEdge): - raise TypeError("Expected InterstateEdge, got: {}".format(type(edge).__name__)) - if v is self._cached_start_state: - self._cached_start_state = None - return super(SDFG, self).add_edge(u, v, edge) - - def states(self): - """ Alias that returns the nodes (states) in this SDFG. """ - return self.nodes() - - def all_nodes_recursive(self) -> Iterator[Tuple[nd.Node, Union['SDFG', 'SDFGState']]]: - """ Iterate over all nodes in this SDFG, including states, nodes in - states, and recursive states and nodes within nested SDFGs, - returning tuples on the form (node, parent), where the parent is - either the SDFG (for states) or a DFG (nodes). """ - for node in self.nodes(): - yield node, self - yield from node.all_nodes_recursive() - - def all_sdfgs_recursive(self): - """ Iterate over this and all nested SDFGs. """ - yield self - for state in self.nodes(): - for node in state.nodes(): - if isinstance(node, nd.NestedSDFG): - yield from node.sdfg.all_sdfgs_recursive() - - def all_edges_recursive(self): - """ Iterate over all edges in this SDFG, including state edges, - inter-state edges, and recursively edges within nested SDFGs, - returning tuples on the form (edge, parent), where the parent is - either the SDFG (for states) or a DFG (nodes). """ - for e in self.edges(): - yield e, self - for node in self.nodes(): - yield from node.all_edges_recursive() + def states(self) -> Iterator[SDFGState]: + """ Returns the states in this SDFG, recursing into state scope blocks. """ + return self.all_states_recursive() def arrays_recursive(self): """ Iterate over all arrays in this SDFG, including arrays within @@ -1323,19 +1229,15 @@ def arrays_recursive(self): if isinstance(node, nd.NestedSDFG): yield from node.sdfg.arrays_recursive() - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: - """ - Returns a set of symbol names that are used by the SDFG, but not - defined within it. This property is used to determine the symbolic - parameters of the SDFG. - - :param all_symbols: If False, only returns the set of symbols that will be used - in the generated code and are needed as arguments. - :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping - will be removed from the set of defined symbols. - """ - defined_syms = set() - free_syms = set() + def used_symbols(self, + all_symbols: bool, + defined_syms: Optional[Set]=None, + free_syms: Optional[Set]=None, + used_before_assignment: Optional[Set]=None, + keep_defined_in_mapping: bool=False) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms = set() if defined_syms is None else defined_syms + free_syms = set() if free_syms is None else free_syms + used_before_assignment = set() if used_before_assignment is None else used_before_assignment # Exclude data descriptor names and constants for name in self.arrays.keys(): @@ -1349,54 +1251,10 @@ def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) - for code in self.exit_code.values(): free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) - # Add free state symbols - used_before_assignment = set() - - try: - ordered_states = self.topological_sort(self.start_state) - except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) - ordered_states = self.nodes() - - for state in ordered_states: - state_fsyms = state.used_symbols(all_symbols) - free_syms |= state_fsyms - - # Add free inter-state symbols - for e in self.out_edges(state): - # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols by - # subracting the (true) free symbols from the edge's assignment keys. This way we can correctly - # compute the symbols that are used before being assigned. - efsyms = e.data.used_symbols(all_symbols) - defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_fsyms) - used_before_assignment.update(efsyms - defined_syms) - free_syms |= efsyms - - # Remove symbols that were used before they were assigned - defined_syms -= used_before_assignment - - # Remove from defined symbols those that are in the symbol mapping - if self.parent_nsdfg_node is not None and keep_defined_in_mapping: - defined_syms -= set(self.parent_nsdfg_node.symbol_mapping.keys()) - - # Add the set of SDFG symbol parameters - # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets - if all_symbols: - free_syms |= set(self.symbols.keys()) - - # Subtract symbols defined in inter-state edges and constants - return free_syms - defined_syms - - @property - def free_symbols(self) -> Set[str]: - """ - Returns a set of symbol names that are used by the SDFG, but not - defined within it. This property is used to determine the symbolic - parameters of the SDFG and verify that ``SDFG.symbols`` is complete. - - :note: Assumes that the graph is valid (i.e., without undefined or - overlapping symbols). - """ - return self.used_symbols(all_symbols=True) + return super().used_symbols( + all_symbols=all_symbols, keep_defined_in_mapping=keep_defined_in_mapping, + defined_syms=defined_syms, free_syms=free_syms, used_before_assignment=used_before_assignment + ) def get_all_toplevel_symbols(self) -> Set[str]: """ @@ -1470,7 +1328,7 @@ def arglist(self, scalars_only=False, free_symbols=None) -> Dict[str, dt.Data]: } # Add global free symbols used in the generated code to scalar arguments - free_symbols = free_symbols if free_symbols is not None else self.used_symbols(all_symbols=False) + free_symbols = free_symbols if free_symbols is not None else self.used_symbols(all_symbols=False)[0] scalar_args.update({k: dt.Scalar(self.symbols[k]) for k in free_symbols if not k.startswith('__dace')}) # Fill up ordered dictionary @@ -1602,22 +1460,21 @@ def transients(self): return result def shared_transients(self, check_toplevel=True) -> List[str]: - """ Returns a list of transient data that appears in more than one - state. """ + """ Returns a list of transient data that appears in more than one state. """ seen = {} shared = [] # If a transient is present in an inter-state edge, it is shared - for interstate_edge in self.edges(): + for interstate_edge in self.all_interstate_edges_recursive(): for sym in interstate_edge.data.free_symbols: if sym in self.arrays and self.arrays[sym].transient: seen[sym] = interstate_edge shared.append(sym) # If transient is accessed in more than one state, it is shared - for state in self.nodes(): - for node in state.nodes(): - if isinstance(node, nd.AccessNode) and node.desc(self).transient: + for state in self.all_states_recursive(): + for node in state.data_nodes(): + if node.desc(self).transient: if (check_toplevel and node.desc(self).toplevel) or (node.data in seen and seen[node.data] != state): shared.append(node.data) @@ -1706,62 +1563,6 @@ def from_file(filename: str) -> 'SDFG': # Dynamic SDFG creation API ############################## - def add_state(self, label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state to this graph and returns it. - - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - if self._labels is None or len(self._labels) != self.number_of_nodes(): - self._labels = set(s.label for s in self.nodes()) - label = label or 'state' - existing_labels = self._labels - label = dt.find_new_name(label, existing_labels) - state = SDFGState(label, self) - self._labels.add(label) - - self.add_node(state, is_start_state=is_start_state) - return state - - def add_state_before(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state before an existing state, reconnecting - predecessors to it instead. - - :param state: The state to prepend the new state before. - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - new_state = self.add_state(label, is_start_state) - # Reconnect - for e in self.in_edges(state): - self.remove_edge(e) - self.add_edge(e.src, new_state, e.data) - # Add unconditional connection between the new state and the current - self.add_edge(new_state, state, InterstateEdge()) - return new_state - - def add_state_after(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state after an existing state, reconnecting - it to the successors instead. - - :param state: The state to append the new state after. - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - new_state = self.add_state(label, is_start_state) - # Reconnect - for e in self.out_edges(state): - self.remove_edge(e) - self.add_edge(new_state, e.dst, e.data) - # Add unconditional connection between the current and the new state - self.add_edge(state, new_state, InterstateEdge()) - return new_state def _find_new_name(self, name: str): """ Tries to find a new name by adding an underscore and a number. """ @@ -2208,83 +2009,6 @@ def add_rdistrarray(self, array_a: str, array_b: str): self.append_exit_code(self._rdistrarrays[rdistrarray_name].exit_code(self)) return rdistrarray_name - def add_loop( - self, - before_state, - loop_state, - after_state, - loop_var: str, - initialize_expr: str, - condition_expr: str, - increment_expr: str, - loop_end_state=None, - ): - """ - Helper function that adds a looping state machine around a - given state (or sequence of states). - - :param before_state: The state after which the loop should - begin, or None if the loop is the first - state (creates an empty state). - :param loop_state: The state that begins the loop. See also - ``loop_end_state`` if the loop is multi-state. - :param after_state: The state that should be invoked after - the loop ends, or None if the program - should terminate (creates an empty state). - :param loop_var: A name of an inter-state variable to use - for the loop. If None, ``initialize_expr`` - and ``increment_expr`` must be None. - :param initialize_expr: A string expression that is assigned - to ``loop_var`` before the loop begins. - If None, does not define an expression. - :param condition_expr: A string condition that occurs every - loop iteration. If None, loops forever - (undefined behavior). - :param increment_expr: A string expression that is assigned to - ``loop_var`` after every loop iteration. - If None, does not define an expression. - :param loop_end_state: If the loop wraps multiple states, the - state where the loop iteration ends. - If None, sets the end state to - ``loop_state`` as well. - :return: A 3-tuple of (``before_state``, generated loop guard state, - ``after_state``). - """ - from dace.frontend.python.astutils import negate_expr # Avoid import loops - - # Argument checks - if loop_var is None and (initialize_expr or increment_expr): - raise ValueError("Cannot initalize or increment an empty loop variable") - - # Handling empty states - if loop_end_state is None: - loop_end_state = loop_state - if before_state is None: - before_state = self.add_state() - if after_state is None: - after_state = self.add_state() - - # Create guard state - guard = self.add_state("guard") - - # Loop initialization - init = None if initialize_expr is None else {loop_var: initialize_expr} - self.add_edge(before_state, guard, InterstateEdge(assignments=init)) - - # Loop condition - if condition_expr: - cond_ast = CodeBlock(condition_expr).code - else: - cond_ast = CodeBlock('True').code - self.add_edge(guard, loop_state, InterstateEdge(cond_ast)) - self.add_edge(guard, after_state, InterstateEdge(negate_expr(cond_ast))) - - # Loop incrementation - incr = None if increment_expr is None else {loop_var: increment_expr} - self.add_edge(loop_end_state, guard, InterstateEdge(assignments=incr)) - - return before_state, guard, after_state - # SDFG queries ############################## @@ -2329,7 +2053,8 @@ def is_loaded(self) -> bool: process. """ # Avoid import loops - from dace.codegen import compiled_sdfg as cs, compiler + from dace.codegen import compiled_sdfg as cs + from dace.codegen import compiler binary_filename = compiler.get_binary_name(self.build_folder, self.name) dll = cs.ReloadableDLL(binary_filename, self.name) @@ -2347,6 +2072,7 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG': # Importing these outside creates an import loop from dace.codegen import codegen, compiler + from dace.sdfg import utils as sdutils # Compute build folder path before running codegen build_folder = self.build_folder @@ -2367,6 +2093,10 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG': # if the codegen modifies the SDFG (thereby changing its hash) sdfg.build_folder = build_folder + # Convert any scope blocks to old-school state machines for now. + # TODO: Adapt codegen to deal wiht scope blocks instead. + sdutils.inline_loop_blocks(sdfg) + # Rename SDFG to avoid runtime issues with clashing names index = 0 while sdfg.is_loaded(): @@ -2482,8 +2212,10 @@ def __call__(self, *args, **kwargs): def fill_scope_connectors(self): """ Fills missing scope connectors (i.e., "IN_#"/"OUT_#" on entry/exit nodes) according to data on the memlets. """ - for state in self.nodes(): - state.fill_scope_connectors() + for cf in self.all_state_scopes_recursive(): + for block in cf.nodes(): + if isinstance(block, SDFGState): + block.fill_scope_connectors() def predecessor_state_transitions(self, state): """ Yields paths (lists of edges) that the SDFG can pass through @@ -2542,7 +2274,8 @@ def _initialize_transformations_from_type( :param options: Zero or more transformation initialization option dictionaries. :return: List of PatternTransformation objects inititalized with their properties. """ - from dace.transformation import PatternTransformation # Avoid import loops + from dace.transformation import \ + PatternTransformation # Avoid import loops if isinstance(xforms, (PatternTransformation, type)): xforms = [xforms] @@ -2606,7 +2339,8 @@ def apply_transformations(self, [MapTiling, MapFusion, GPUTransformSDFG], options=[{'tile_size': 16}, {}, {}]) """ - from dace.transformation.passes.pattern_matching import PatternMatchAndApply # Avoid import loops + from dace.transformation.passes.pattern_matching import \ + PatternMatchAndApply # Avoid import loops xforms = self._initialize_transformations_from_type(xforms, options) @@ -2660,7 +2394,8 @@ def apply_transformations_repeated(self, # Applies InlineSDFG until no more subgraphs can be inlined sdfg.apply_transformations_repeated(InlineSDFG) """ - from dace.transformation.passes.pattern_matching import PatternMatchAndApplyRepeated + from dace.transformation.passes.pattern_matching import \ + PatternMatchAndApplyRepeated xforms = self._initialize_transformations_from_type(xforms, options) @@ -2711,7 +2446,8 @@ def apply_transformations_once_everywhere(self, # Tiles all maps once sdfg.apply_transformations_once_everywhere(MapTiling, options=dict(tile_size=16)) """ - from dace.transformation.passes.pattern_matching import PatternApplyOnceEverywhere + from dace.transformation.passes.pattern_matching import \ + PatternApplyOnceEverywhere xforms = self._initialize_transformations_from_type(xforms, options) @@ -2778,7 +2514,7 @@ def expand_library_nodes(self, recursive=True): including library nodes that expand to library nodes. """ - states = list(self.states()) + states = list(self.all_states_recursive()) while len(states) > 0: state = states.pop() expanded_something = False diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 1ff8fe4cf1..8017965b3c 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2,6 +2,7 @@ """ Contains classes of a single SDFG state and dataflow subgraphs. """ import ast +import abc import collections import copy import inspect @@ -19,11 +20,17 @@ from dace.properties import (CodeBlock, DictProperty, EnumProperty, Property, SubsetProperty, SymbolicProperty, CodeProperty, make_properties) from dace.sdfg import nodes as nd -from dace.sdfg.graph import MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView +from dace.sdfg.graph import MultiConnectorEdge, SubgraphView, OrderedDiGraph, OrderedMultiDiConnectorGraph, Edge from dace.sdfg.propagation import propagate_memlet from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset + +SomeNodeT = Union[nd.Node, 'ControlFlowBlock'] +SomeEdgeT = Union[MultiConnectorEdge[mm.Memlet], Edge['dace.sdfg.InterstateEdge']] +SomeGraphT = Union['ScopeBlock', 'SDFGState'] + + if TYPE_CHECKING: import dace.sdfg.scope @@ -66,14 +73,250 @@ def _make_iterators(ndrange): return params, map_range -class StateGraphView(object): +class BlockGraphView(abc.ABC): """ - Read-only view interface of an SDFG state, containing methods for memlet - tracking, traversal, subgraph creation, queries, and replacements. - ``SDFGState`` and ``StateSubgraphView`` inherit from this class to share + Read-only view interface of an SDFG control flow block, containing methods for memlet tracking, traversal, subgraph + creation, queries, and replacements. ``ControlFlowBlock`` and ``StateSubgraphView`` inherit from this class to share methods. """ + ################################################################### + # Typing overrides + + @overload + def nodes(self) -> List[SomeNodeT]: + ... + + @overload + def edges(self) -> List[SomeEdgeT]: + ... + + @overload + def in_degree(self, node: SomeNodeT) -> int: + ... + + @overload + def out_degree(self, node: SomeNodeT) -> int: + ... + + ################################################################### + # Traversal methods + + @abc.abstractmethod + def all_nodes_recursive(self) -> Iterator[Tuple[SomeNodeT, SomeGraphT]]: + """ + Iterate over all nodes in this graph or subgraph. + This includes control flow blocks, nodes in those blocks, and recursive control flow blocks and nodes within + nested SDFGs. It returns tuples of the form (node, parent), where the node is either a dataflow node, in which + case the parent is an SDFG state, or a control flow block, in which case the parent is a control flow graph + (i.e., an SDFG or a scope block). + """ + raise NotImplementedError() + + @abc.abstractmethod + def all_edges_recursive(self) -> Iterator[Tuple[SomeEdgeT, SomeGraphT]]: + """ + Iterate over all edges in this graph or subgraph. + This includes dataflow edges, inter-state edges, and recursive edges within nested SDFGs. It returns tuples of + the form (edge, parent), where the edge is either a dataflow edge, in which case the parent is an SDFG state, or + an inter-stte edge, in which case the parent is a control flow graph (i.e., an SDFG or a scope block). + """ + raise NotImplementedError() + + @abc.abstractmethod + def data_nodes(self) -> List[nd.AccessNode]: + """ + Returns all data nodes (i.e., AccessNodes, arrays) present in this graph or subgraph. + Note: This does not recurse into nested SDFGs. + """ + raise NotImplementedError() + + @abc.abstractmethod + def entry_node(self, node: nd.Node) -> nd.EntryNode: + """ Returns the entry node that wraps the current node, or None if it is top-level in a state. """ + raise NotImplementedError() + + @abc.abstractmethod + def exit_node(self, entry_node: nd.EntryNode) -> nd.ExitNode: + """ Returns the exit node leaving the context opened by the given entry node. """ + raise NotImplementedError() + + ################################################################### + # Memlet-tracking methods + + @abc.abstractmethod + def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnectorEdge[mm.Memlet]]: + """ + Given one edge, returns a list of edges representing a path between its source and sink nodes. + Used for memlet tracking. + + :note: Behavior is undefined when there is more than one path involving this edge. + :param edge: An edge within a state (memlet). + :return: A list of edges from a source node to a destination node. + """ + raise NotImplementedError() + + @abc.abstractmethod + def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: + """ + Given one edge, returns a tree of edges between its node source(s) and sink(s). + Used for memlet tracking. + + :param edge: An edge within a state (memlet). + :return: A tree of edges whose root is the source/sink node (depending on direction) and associated children + edges. + """ + raise NotImplementedError() + + @abc.abstractmethod + def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges entering the given connector of the given node. + + :param node: Destination node of edges. + :param connector: Destination connector of edges. + """ + raise NotImplementedError() + + @abc.abstractmethod + def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges exiting the given connector of the given node. + + :param node: Source node of edges. + :param connector: Source connector of edges. + """ + raise NotImplementedError() + + @abc.abstractmethod + def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + """ + Returns a generator over edges entering or exiting the given connector of the given node. + + :param node: Source/destination node of edges. + :param connector: Source/destination connector of edges. + """ + raise NotImplementedError() + + ################################################################### + # Query, subgraph, and replacement methods + + @abc.abstractmethod + def used_symbols(self, all_symbols: bool) -> Set[str]: + """ + Returns a set of symbol names that are used in the graph. + + :param all_symbols: If False, only returns symbols that are needed as arguments (only used in generated code). + """ + raise NotImplementedError() + + @property + def free_symbols(self) -> Set[str]: + """ + Returns a set of symbol names that are used, but not defined, in this graph view. + In the case of an SDFG, this property is used to determine the symbolic parameters of the SDFG and + verify that ``SDFG.symbols`` is complete. + + :note: Assumes that the graph is valid (i.e., without undefined or overlapping symbols). + """ + return self.used_symbols(all_symbols=True) + + @abc.abstractmethod + def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: + """ + Determines what data is read and written in this graph. + Does not include reads to subsets of containers that have previously been written within the same state. + + :return: A two-tuple of sets of things denoting ({data read}, {data written}). + """ + raise NotImplementedError() + + @abc.abstractmethod + def unordered_arglist(self, + defined_syms=None, + shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]: + raise NotImplementedError() + + def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Data]: + """ + Returns an ordered dictionary of arguments (names and types) required to invoke this subgraph. + + The arguments differ from SDFG.arglist, but follow the same order, + namely: , . + + Data arguments contain: + * All used non-transient data containers in the subgraph + * All used transient data containers that were allocated outside. + This includes data from memlets, transients shared across multiple states, and transients that could not + be allocated within the subgraph (due to their ``AllocationLifetime`` or according to the + ``dtypes.can_allocate`` function). + + Scalar arguments contain: + * Free symbols in this state/subgraph. + * All transient and non-transient scalar data containers used in this subgraph. + + This structure will create a sorted list of pointers followed by a sorted list of PoDs and structs. + + :return: An ordered dictionary of (name, data descriptor type) of all the arguments, sorted as defined here. + """ + data_args, scalar_args = self.unordered_arglist(defined_syms, shared_transients) + + # Fill up ordered dictionary + result = collections.OrderedDict() + for k, v in itertools.chain(sorted(data_args.items()), sorted(scalar_args.items())): + result[k] = v + + return result + + def signature_arglist(self, with_types=True, for_call=False): + """ Returns a list of arguments necessary to call this state or subgraph, formatted as a list of C definitions. + + :param with_types: If True, includes argument types in the result. + :param for_call: If True, returns arguments that can be used when calling the SDFG. + :return: A list of strings. For example: `['float *A', 'int b']`. + """ + return [v.as_arg(name=k, with_types=with_types, for_call=for_call) for k, v in self.arglist().items()] + + @abc.abstractmethod + def scope_subgraph(self, entry_node, include_entry=True, include_exit=True): + raise NotImplementedError() + + @abc.abstractmethod + def top_level_transients(self) -> Set[str]: + """Iterate over top-level transients of this graph.""" + raise NotImplementedError() + + @abc.abstractmethod + def all_transients(self) -> List[str]: + """Iterate over all transients in this graph.""" + raise NotImplementedError() + + @abc.abstractmethod + def replace(self, name: str, new_name: str): + """ + Finds and replaces all occurrences of a symbol or array in this graph. + + :param name: Name to find. + :param new_name: Name to replace. + """ + raise NotImplementedError() + + @abc.abstractmethod + def replace_dict(self, + repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): + """ + Finds and replaces all occurrences of a set of symbols or arrays in this graph. + + :param repl: Mapping from names to replacements. + :param symrepl: Optional symbolic version of ``repl``. + """ + raise NotImplementedError() + + +@make_properties +class DataflowGraphView(BlockGraphView, abc.ABC): + def __init__(self, *args, **kwargs): self._clear_scopedict_cache() @@ -91,31 +334,26 @@ def edges(self) -> List[MultiConnectorEdge[mm.Memlet]]: ################################################################### # Traversal methods - def all_nodes_recursive(self): + def all_nodes_recursive(self) -> Iterator[Tuple[SomeNodeT, SomeGraphT]]: for node in self.nodes(): yield node, self if isinstance(node, nd.NestedSDFG): yield from node.sdfg.all_nodes_recursive() - def all_edges_recursive(self): + def all_edges_recursive(self) -> Iterator[Tuple[SomeEdgeT, SomeGraphT]]: for e in self.edges(): yield e, self for node in self.nodes(): if isinstance(node, nd.NestedSDFG): yield from node.sdfg.all_edges_recursive() - def data_nodes(self): - """ Returns all data_nodes (arrays) present in this state. """ + def data_nodes(self) -> List[nd.AccessNode]: return [n for n in self.nodes() if isinstance(n, nd.AccessNode)] - def entry_node(self, node: nd.Node) -> nd.EntryNode: - """ Returns the entry node that wraps the current node, or None if - it is top-level in a state. """ + def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]: return self.scope_dict()[node] - def exit_node(self, entry_node: nd.EntryNode) -> nd.ExitNode: - """ Returns the exit node leaving the context opened by - the given entry node. """ + def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]: node_to_children = self.scope_children() return next(v for v in node_to_children[entry_node] if isinstance(v, nd.ExitNode)) @@ -152,7 +390,7 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto result.insert(0, next_edge) curedge = next_edge - # Prepend outgoing edges until reaching the sink node + # Append outgoing edges until reaching the sink node curedge = edge while not isinstance(curedge.dst, (nd.CodeNode, nd.AccessNode)): # Trace through scope entry using IN_# -> OUT_# @@ -168,13 +406,6 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto return result def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: - """ Given one edge, returns a tree of edges between its node source(s) - and sink(s). Used for memlet tracking. - - :param edge: An edge within this state. - :return: A tree of edges whose root is the source/sink node - (depending on direction) and associated children edges. - """ propagate_forward = False propagate_backward = False if ((isinstance(edge.src, nd.EntryNode) and edge.src_conn is not None) or @@ -246,30 +477,12 @@ def traverse(node): return traverse(tree_root) def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges entering the given connector of the - given node. - - :param node: Destination node of edges. - :param connector: Destination connector of edges. - """ return (e for e in self.in_edges(node) if e.dst_conn == connector) def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges exiting the given connector of the - given node. - - :param node: Source node of edges. - :param connector: Source connector of edges. - """ return (e for e in self.out_edges(node) if e.src_conn == connector) def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: - """ Returns a generator over edges entering or exiting the given - connector of the given node. - - :param node: Source/destination node of edges. - :param connector: Source/destination connector of edges. - """ return itertools.chain(self.in_edges_by_connector(node, connector), self.out_edges_by_connector(node, connector)) @@ -297,8 +510,6 @@ def scope_tree(self) -> 'dace.sdfg.scope.ScopeTree': result = {} - sdfg_symbols = self.parent.symbols.keys() - # Get scopes for node, scopenodes in sdc.items(): if node is None: @@ -325,20 +536,11 @@ def scope_leaves(self) -> List['dace.sdfg.scope.ScopeTree']: self._scope_leaves_cached = [scope for scope in st.values() if len(scope.children) == 0] return copy.copy(self._scope_leaves_cached) - def scope_dict(self, return_ids: bool = False, validate: bool = True) -> Dict[nd.Node, Optional[nd.Node]]: - """ Returns a dictionary that maps each SDFG node to its parent entry - node, or to None if the node is not in any scope. - - :param return_ids: Return node ID numbers instead of node objects. - :param validate: Ensure that the graph is not malformed when - computing dictionary. - :return: The mapping from a node to its parent scope entry node. - """ - from dace.sdfg.scope import _scope_dict_inner, _scope_dict_to_ids - result = None - result = copy.copy(self._scope_dict_toparent_cached) - - if result is None: + def scope_dict(self, validate: bool = True) -> Dict[nd.Node, Union['SDFGState', nd.Node]]: + from dace.sdfg.scope import _scope_dict_inner + if self._scope_dict_toparent_cached is not None: + return copy.copy(self._scope_dict_toparent_cached) + else: result = {} node_queue = collections.deque(self.source_nodes()) eq = _scope_dict_inner(self, node_queue, None, False, result) @@ -359,30 +561,13 @@ def scope_dict(self, return_ids: bool = False, validate: bool = True) -> Dict[nd # Cache result self._scope_dict_toparent_cached = result - result = copy.copy(result) - - if return_ids: - return _scope_dict_to_ids(self, result) - return result - - def scope_children(self, - return_ids: bool = False, - validate: bool = True) -> Dict[Optional[nd.EntryNode], List[nd.Node]]: - """ Returns a dictionary that maps each SDFG entry node to its children, - not including the children of children entry nodes. The key `None` - contains a list of top-level nodes (i.e., not in any scope). + return copy.copy(result) - :param return_ids: Return node ID numbers instead of node objects. - :param validate: Ensure that the graph is not malformed when - computing dictionary. - :return: The mapping from a node to a list of children nodes. - """ - from dace.sdfg.scope import _scope_dict_inner, _scope_dict_to_ids - result = None + def scope_children(self, validate: bool = True) -> Dict[Union[nd.Node, 'SDFGState'], List[nd.Node]]: + from dace.sdfg.scope import _scope_dict_inner if self._scope_dict_tochildren_cached is not None: - result = copy.copy(self._scope_dict_tochildren_cached) - - if result is None: + return copy.copy(self._scope_dict_tochildren_cached) + else: result = {} node_queue = collections.deque(self.source_nodes()) eq = _scope_dict_inner(self, node_queue, None, True, result) @@ -403,11 +588,7 @@ def scope_children(self, # Cache result self._scope_dict_tochildren_cached = result - result = copy.copy(result) - - if return_ids: - return _scope_dict_to_ids(self, result) - return result + return copy.copy(result) ################################################################### # Query, subgraph, and replacement methods @@ -427,7 +608,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: in the generated code and are needed as arguments. """ state = self.graph if isinstance(self, SubgraphView) else self - sdfg = state.parent + sdfg = state.sdfg new_symbols = set() freesyms = set() @@ -485,12 +666,8 @@ def free_symbols(self) -> Set[str]: return self.used_symbols(all_symbols=True) def defined_symbols(self) -> Dict[str, dt.Data]: - """ - Returns a dictionary that maps currently-defined symbols in this SDFG - state or subgraph to their types. - """ state = self.graph if isinstance(self, SubgraphView) else self - sdfg = state.parent + sdfg: dace.SDFG = state.sdfg # Start with SDFG global symbols defined_syms = {k: v for k, v in sdfg.symbols.items()} @@ -524,10 +701,6 @@ def update_if_not_none(dic, update): return defined_syms def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, List[Subset]]]: - """ - Determines what data is read and written in this subgraph, returning - dictionaries from data containers to all subsets that are read/written. - """ read_set = collections.defaultdict(list) write_set = collections.defaultdict(list) from dace.sdfg import utils # Avoid cyclic import @@ -535,8 +708,7 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, for sg in subgraphs: rs = collections.defaultdict(list) ws = collections.defaultdict(list) - # Traverse in topological order, so data that is written before it - # is read is not counted in the read set + # Traverse in topological order, so data that is written before it is read is not counted in the read set for n in utils.dfs_topological_sort(sg, sources=sg.source_nodes()): if isinstance(n, nd.AccessNode): in_edges = sg.in_edges(n) @@ -560,9 +732,8 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, if e.data.is_empty(): continue rs[n.data].append(e.data.subset) - # Union all subgraphs, so an array that was excluded from the read - # set because it was written first is still included if it is read - # in another subgraph + # Union all subgraphs, so an array that was excluded from the read set because it was written first is still + # included if it is read in another subgraph for data, accesses in rs.items(): read_set[data] += accesses for data, accesses in ws.items(): @@ -570,42 +741,12 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, return read_set, write_set def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: - """ - Determines what data is read and written in this subgraph. - - :return: A two-tuple of sets of things denoting - ({data read}, {data written}). - """ read_set, write_set = self._read_and_write_sets() return set(read_set.keys()), set(write_set.keys()) - def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Data]: - """ - Returns an ordered dictionary of arguments (names and types) required - to invoke this SDFG state or subgraph thereof. - - The arguments differ from SDFG.arglist, but follow the same order, - namely: , . - - Data arguments contain: - * All used non-transient data containers in the subgraph - * All used transient data containers that were allocated outside. - This includes data from memlets, transients shared across multiple - states, and transients that could not be allocated within the - subgraph (due to their ``AllocationLifetime`` or according to the - ``dtypes.can_allocate`` function). - - Scalar arguments contain: - * Free symbols in this state/subgraph. - * All transient and non-transient scalar data containers used in - this subgraph. - - This structure will create a sorted list of pointers followed by a - sorted list of PoDs and structs. - - :return: An ordered dictionary of (name, data descriptor type) of all - the arguments, sorted as defined here. - """ + def unordered_arglist(self, + defined_syms=None, + shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]: sdfg: 'dace.sdfg.SDFG' = self.parent shared_transients = shared_transients or sdfg.shared_transients() sdict = self.scope_dict() @@ -622,8 +763,7 @@ def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Dat if isinstance(node.desc(sdfg), dt.Scalar): scalars_with_nodes.add(node.data) - # If a subgraph, and a node appears outside the subgraph as well, - # it is externally allocated + # If a subgraph, and a node appears outside the subgraph as well, it is externally allocated if isinstance(self, SubgraphView): outer_nodes = set(self.graph.nodes()) - set(self.nodes()) for node in outer_nodes: @@ -634,8 +774,7 @@ def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Dat else: data_args[node.data] = desc - # Add data arguments from memlets, if do not appear in any of the nodes - # (i.e., originate externally) + # Add data arguments from memlets, if do not appear in any of the nodes (i.e., originate externally) for edge in self.edges(): if edge.data.data is not None and edge.data.data not in descs: desc = sdfg.arrays[edge.data.data] @@ -670,8 +809,7 @@ def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Dat elif isinstance(self, SubgraphView): if (desc.lifetime != dtypes.AllocationLifetime.Scope): data_args[name] = desc - # Check for allocation constraints that would - # enforce array to be allocated outside subgraph + # Check for allocation constraints that would enforce array to be allocated outside subgraph elif desc.lifetime == dtypes.AllocationLifetime.Scope: curnode = sdict[node] while curnode is not None: @@ -679,8 +817,7 @@ def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Dat break curnode = sdict[curnode] else: - # If no internal scope can allocate node, - # mark as external + # If no internal scope can allocate node, mark as external data_args[name] = desc # End of data descriptor loop @@ -699,12 +836,7 @@ def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Dat if not str(k).startswith('__dace') and str(k) not in sdfg.constants }) - # Fill up ordered dictionary - result = collections.OrderedDict() - for k, v in itertools.chain(sorted(data_args.items()), sorted(scalar_args.items())): - result[k] = v - - return result + return data_args, scalar_args def signature_arglist(self, with_types=True, for_call=False): """ Returns a list of arguments necessary to call this state or @@ -721,50 +853,250 @@ def scope_subgraph(self, entry_node, include_entry=True, include_exit=True): from dace.sdfg.scope import _scope_subgraph return _scope_subgraph(self, entry_node, include_entry, include_exit) - def top_level_transients(self): - """Iterate over top-level transients of this state.""" + def top_level_transients(self) -> Set[str]: schildren = self.scope_children() sdfg = self.parent result = set() - for node in schildren[None]: + for node in schildren[self]: if isinstance(node, nd.AccessNode) and node.desc(sdfg).transient: result.add(node.data) return result def all_transients(self) -> List[str]: - """Iterate over all transients in this state.""" return dtypes.deduplicate( [n.data for n in self.nodes() if isinstance(n, nd.AccessNode) and n.desc(self.parent).transient]) def replace(self, name: str, new_name: str): - """ Finds and replaces all occurrences of a symbol or array in this - state. - - :param name: Name to find. - :param new_name: Name to replace. - """ from dace.sdfg.replace import replace replace(self, name, new_name) def replace_dict(self, repl: Dict[str, str], symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): - """ Finds and replaces all occurrences of a set of symbols or arrays in this state. - - :param repl: Mapping from names to replacements. - :param symrepl: Optional symbolic version of ``repl``. - """ from dace.sdfg.replace import replace_dict replace_dict(self, repl, symrepl) @make_properties -class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], StateGraphView): +class ControlGraphView(BlockGraphView, abc.ABC): + + ################################################################### + # Typing overrides + + @overload + def nodes(self) -> List['ControlFlowBlock']: + ... + + @overload + def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: + ... + + ################################################################### + # Traversal methods + + def all_nodes_recursive(self) -> Iterator[Tuple[SomeNodeT, SomeGraphT]]: + for node in self.nodes(): + yield node, self + yield from node.all_nodes_recursive() + + def all_edges_recursive(self) -> Iterator[Tuple[SomeEdgeT, SomeGraphT]]: + for e in self.edges(): + yield e, self + for node in self.nodes(): + yield from node.all_edges_recursive() + + def data_nodes(self) -> List[nd.AccessNode]: + data_nodes = [] + for node in self.nodes(): + data_nodes.extend(node.data_nodes()) + return data_nodes + + def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]: + for block in self.nodes(): + if node in block.nodes(): + return block.exit_node(node) + return None + + def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]: + for block in self.nodes(): + if entry_node in block.nodes(): + return block.exit_node(entry_node) + return None + + ################################################################### + # Memlet-tracking methods + + def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if edge in block.edges(): + return block.memlet_path(edge) + return [] + + def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: + for block in self.nodes(): + if edge in block.edges(): + return block.memlet_tree(edge) + return mm.MemletTree(edge) + + def in_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if node in block.nodes(): + return block.in_edges_by_connector(node, connector) + return [] + + def out_edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if node in block.nodes(): + return block.out_edges_by_connector(node, connector) + return [] + + def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[MultiConnectorEdge[mm.Memlet]]: + for block in self.nodes(): + if node in block.nodes(): + return block.edges_by_connector(node, connector) + + ################################################################### + # Query, subgraph, and replacement methods + + def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: + read_set = set() + write_set = set() + for block in self.nodes(): + for edge in self.in_edges(block): + read_set |= edge.data.free_symbols & self.sdfg.arrays.keys() + rs, ws = block.read_and_write_sets() + read_set.update(rs) + write_set.update(ws) + return read_set, write_set + + def unordered_arglist(self, + defined_syms=None, + shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]: + data_args = {} + scalar_args = {} + for block in self.nodes(): + n_data_args, n_scalar_args = block.unordered_arglist(defined_syms, shared_transients) + data_args.update(n_data_args) + scalar_args.update(n_scalar_args) + return data_args, scalar_args + + def scope_subgraph(self, entry_node, include_entry=True, include_exit=True): + # TODO: Not sure if this makes sense here. + from dace.sdfg.scope import _scope_subgraph + return _scope_subgraph(self, entry_node, include_entry, include_exit) + + def top_level_transients(self) -> Set[str]: + res = set() + for block in self.nodes(): + res.update(block.top_level_transients()) + return res + + def all_transients(self) -> List[str]: + res = [] + for block in self.nodes(): + res.extend(block.all_transients()) + return dtypes.deduplicate(res) + + def replace(self, name: str, new_name: str): + for n in self.nodes(): + n.replace(name, new_name) + + def replace_dict(self, + repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, + replace_in_graph: bool = True, replace_keys: bool = False): + symrepl = symrepl or { + symbolic.symbol(k): symbolic.pystr_to_symbolic(v) if isinstance(k, str) else v + for k, v in repl.items() + } + + if replace_in_graph: + # Replace in inter-state edges + for edge in self.edges(): + edge.data.replace_dict(repl, replace_keys=replace_keys) + + # Replace in states + for state in self.nodes(): + state.replace_dict(repl, symrepl) + + +@make_properties +class ControlFlowBlock(BlockGraphView, abc.ABC): + + is_collapsed = Property(dtype=bool, desc='Show this block as collapsed', default=False) + + _sdfg: Optional['dace.SDFG'] = None + _parent: Optional['ScopeBlock'] = None + _label: str + + def __init__(self, + label: str='', + parent: Optional['ScopeBlock']=None, + sdfg: Optional['dace.SDFG'] = None): + super(ControlFlowBlock, self).__init__() + self._label = label + self._parent = parent + self._sdfg = sdfg + self._default_lineinfo = None + self.is_collapsed = False + + def set_default_lineinfo(self, lineinfo: dace.dtypes.DebugInfo): + """ + Sets the default source line information to be lineinfo, or None to + revert to default mode. + """ + self._default_lineinfo = lineinfo + + def to_json(self, parent=None): + tmp = { + 'type': self.__class__.__name__, + 'collapsed': self.is_collapsed, + 'label': self._label, + 'id': parent.node_id(self) if parent is not None else None, + } + return tmp + + def __str__(self): + return self._label + + def __repr__(self) -> str: + return f'ControlFlowBlock ({self.label})' + + @property + def label(self) -> str: + return self._label + + @label.setter + def label(self, label: str): + self._label = label + + @property + def name(self) -> str: + return self._label + + @property + def parent(self) -> Optional['ScopeBlock']: + """ Returns the parent block of this block. """ + return self._parent + + @parent.setter + def parent(self, block: Optional['ScopeBlock']): + self._parent = block + + @property + def sdfg(self) -> Optional['dace.SDFG']: + return self._sdfg + + @sdfg.setter + def sdfg(self, sdfg: Optional['dace.SDFG']): + self._sdfg = sdfg + + +@make_properties +class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], ControlFlowBlock, DataflowGraphView): """ An acyclic dataflow multigraph in an SDFG, corresponding to a single state in the SDFG state machine. """ - is_collapsed = Property(dtype=bool, desc="Show this node/scope/state as collapsed", default=False) - nosync = Property(dtype=bool, default=False, desc="Do not synchronize at the end of the state") instrument = EnumProperty(dtype=dtypes.InstrumentationType, @@ -795,7 +1127,7 @@ class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], StateGraphView def __repr__(self) -> str: return f"SDFGState ({self.label})" - def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): + def __init__(self, label=None, sdfg=None, debuginfo=None, location=None, parent=None): """ Constructs an SDFG state. :param label: Name for the state (optional). @@ -803,13 +1135,11 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): :param debuginfo: Source code locator for debugging. """ from dace.sdfg.sdfg import SDFG # Avoid import loop - super(SDFGState, self).__init__() - self._label = label - self._parent: SDFG = sdfg + OrderedMultiDiConnectorGraph.__init__(self) + ControlFlowBlock.__init__(self, label, parent, sdfg) self._graph = self # Allowing MemletTrackingView mixin to work self._clear_scopedict_cache() self._debuginfo = debuginfo - self.is_collapsed = False self.nosync = False self.location = location if location is not None else {} self._default_lineinfo = None @@ -830,42 +1160,12 @@ def __deepcopy__(self, memo): pass return result - @property - def parent(self): - """ Returns the parent SDFG of this state. """ - return self._parent - - @parent.setter - def parent(self, value): - self._parent = value - - def __str__(self): - return self._label - - @property - def label(self): - return self._label - - @property - def name(self): - return self._label - - def set_label(self, label): - self._label = label - def is_empty(self): return self.number_of_nodes() == 0 def validate(self) -> None: validate_state(self) - def set_default_lineinfo(self, lineinfo: dtypes.DebugInfo): - """ - Sets the default source line information to be lineinfo, or None to - revert to default mode. - """ - self._default_lineinfo = lineinfo - def nodes(self) -> List[nd.Node]: # Added for type hints return super().nodes() @@ -886,7 +1186,7 @@ def add_node(self, node): # Correct nested SDFG's parent attributes if isinstance(node, nd.NestedSDFG): node.sdfg.parent = self - node.sdfg.parent_sdfg = self.parent + node.sdfg.parent_sdfg = self.sdfg node.sdfg.parent_nsdfg_node = node self._clear_scopedict_cache() return super(SDFGState, self).add_node(node) @@ -914,7 +1214,7 @@ def add_edge(self, u, u_connector, v, v_connector, memlet): self._clear_scopedict_cache() result = super(SDFGState, self).add_edge(u, u_connector, v, v_connector, memlet) - memlet.try_initialize(self.parent, self, result) + memlet.try_initialize(self.sdfg, self, result) return result def remove_edge(self, edge): @@ -930,15 +1230,18 @@ def remove_edge_and_connectors(self, edge): edge.dst.remove_in_connector(edge.dst_conn) def to_json(self, parent=None): + from dace.sdfg.scope import _scope_dict_to_ids # Create scope dictionary with a failsafe try: - scope_dict = {k: sorted(v) for k, v in sorted(self.scope_children(return_ids=True).items())} + scope_dict_regular = self.scope_children() + scope_dict_ids = _scope_dict_to_ids(self, scope_dict_regular) + scope_dict = {k: sorted(v) for k, v in sorted(scope_dict_ids.items())} except (RuntimeError, ValueError): scope_dict = {} # Try to initialize edges before serialization for edge in self.edges(): - edge.data.try_initialize(self.parent, self, edge) + edge.data.try_initialize(self.sdfg, self, edge) ret = { 'type': type(self).__name__, @@ -1010,7 +1313,7 @@ def _repr_html_(self): from dace.sdfg import SDFG arrays = set(n.data for n in self.data_nodes()) sdfg = SDFG(self.label) - sdfg._arrays = {k: self._parent.arrays[k] for k in arrays} + sdfg._arrays = {k: self.sdfg.arrays[k] for k in arrays} sdfg.add_node(self) return sdfg._repr_html_() @@ -1030,7 +1333,7 @@ def symbols_defined_at(self, node: nd.Node) -> Dict[str, dtypes.typeclass]: if node is None: return collections.OrderedDict() - sdfg: SDFG = self.parent + sdfg: SDFG = self.sdfg # Start with global symbols symbols = collections.OrderedDict(sdfg.symbols) @@ -1172,7 +1475,7 @@ def add_nested_sdfg( debuginfo = _getdebuginfo(debuginfo or self._default_lineinfo) sdfg.parent = self - sdfg.parent_sdfg = self.parent + sdfg.parent_sdfg = self.sdfg sdfg.update_sdfg_list([]) @@ -1438,7 +1741,7 @@ def add_mapped_tasklet(self, # Try to initialize memlets for edge in edges: - edge.data.try_initialize(self.parent, self, edge) + edge.data.try_initialize(self.sdfg, self, edge) return tasklet, map_entry, map_exit @@ -1620,8 +1923,8 @@ def add_edge_pair( ) # Try to initialize memlets - iedge.data.try_initialize(self.parent, self, iedge) - eedge.data.try_initialize(self.parent, self, eedge) + iedge.data.try_initialize(self.sdfg, self, iedge) + eedge.data.try_initialize(self.sdfg, self, eedge) return (iedge, eedge) @@ -1724,7 +2027,7 @@ def add_memlet_path(self, *path_nodes, memlet=None, src_conn=None, dst_conn=None cur_memlet = propagate_memlet(self, cur_memlet, snode, True) # Try to initialize memlets for edge in edges: - edge.data.try_initialize(self.parent, self, edge) + edge.data.try_initialize(self.sdfg, self, edge) def remove_memlet_path(self, edge: MultiConnectorEdge, remove_orphans: bool = True) -> None: """ Removes all memlets and associated connectors along a path formed @@ -1821,9 +2124,9 @@ def add_array(self, 'The "SDFGState.add_array" API is deprecated, please ' 'use "SDFG.add_array" and "SDFGState.add_access"', DeprecationWarning) # Workaround to allow this legacy API - if name in self.parent._arrays: - del self.parent._arrays[name] - self.parent.add_array(name, + if name in self.sdfg._arrays: + del self.sdfg._arrays[name] + self.sdfg.add_array(name, shape, dtype, storage=storage, @@ -1854,9 +2157,9 @@ def add_stream( 'The "SDFGState.add_stream" API is deprecated, please ' 'use "SDFG.add_stream" and "SDFGState.add_access"', DeprecationWarning) # Workaround to allow this legacy API - if name in self.parent._arrays: - del self.parent._arrays[name] - self.parent.add_stream( + if name in self.sdfg._arrays: + del self.sdfg._arrays[name] + self.sdfg.add_stream( name, dtype, buffer_size, @@ -1883,9 +2186,9 @@ def add_scalar( 'The "SDFGState.add_scalar" API is deprecated, please ' 'use "SDFG.add_scalar" and "SDFGState.add_access"', DeprecationWarning) # Workaround to allow this legacy API - if name in self.parent._arrays: - del self.parent._arrays[name] - self.parent.add_scalar(name, dtype, storage, transient, lifetime, debuginfo) + if name in self.sdfg._arrays: + del self.sdfg._arrays[name] + self.sdfg.add_scalar(name, dtype, storage, transient, lifetime, debuginfo) return self.add_access(name, debuginfo) def add_transient(self, @@ -1981,8 +2284,396 @@ def fill_scope_connectors(self): node.add_in_connector(edge.dst_conn) -class StateSubgraphView(SubgraphView, StateGraphView): +class StateSubgraphView(SubgraphView, DataflowGraphView): """ A read-only subgraph view of an SDFG state. """ def __init__(self, graph, subgraph_nodes): super().__init__(graph, subgraph_nodes) + + +@make_properties +class ScopeBlock(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, ControlFlowBlock): + + def __init__(self, + label: str='', + parent: Optional['ScopeBlock']=None, + sdfg: Optional['dace.SDFG'] = None): + OrderedDiGraph.__init__(self) + ControlGraphView.__init__(self) + ControlFlowBlock.__init__(self, label, parent, sdfg) + + self._labels: Set[str] = set() + self._start_block: Optional[int] = None + self._cached_start_block: Optional[ControlFlowBlock] = None + + def add_edge(self, src: ControlFlowBlock, dst: ControlFlowBlock, data: 'dace.sdfg.InterstateEdge'): + """ Adds a new edge to the graph. Must be an InterstateEdge or a subclass thereof. + + :param u: Source node. + :param v: Destination node. + :param edge: The edge to add. + """ + if not isinstance(src, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(src))) + if not isinstance(dst, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(dst))) + if not isinstance(data, dace.sdfg.InterstateEdge): + raise TypeError('Expected InterstateEdge, got ' + str(type(data))) + if dst is self._cached_start_block: + self._cached_start_block = None + return super().add_edge(src, dst, data) + + def add_node(self, node, is_start_block=False): + if not isinstance(node, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(node))) + super().add_node(node) + node.parent = self + node.sdfg = self if isinstance(self, dace.SDFG) else self.sdfg + self._cached_start_block = None + if is_start_block is True: + self.start_block = len(self.nodes()) - 1 + self._cached_start_block = node + + def add_state(self, label=None, is_start_block=False) -> SDFGState: + if self._labels is None or len(self._labels) != self.number_of_nodes(): + self._labels = set(s.label for s in self.nodes()) + label = label or 'state' + existing_labels = self._labels + label = dt.find_new_name(label, existing_labels) + state = SDFGState(label) + self._labels.add(label) + self.add_node(state, is_start_block=is_start_block) + return state + + def add_state_before(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + """ Adds a new SDFG state before an existing state, reconnecting predecessors to it instead. + + :param state: The state to prepend the new state before. + :param label: State label. + :param is_start_state: If True, resets scope block starting state to this state. + :return: A new SDFGState object. + """ + new_state = self.add_state(label, is_start_state) + # Reconnect + for e in self.in_edges(state): + self.remove_edge(e) + self.add_edge(e.src, new_state, e.data) + # Add unconditional connection between the new state and the current + self.add_edge(new_state, state, dace.sdfg.InterstateEdge()) + return new_state + + def add_state_after(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + """ Adds a new SDFG state after an existing state, reconnecting it to the successors instead. + + :param state: The state to append the new state after. + :param label: State label. + :param is_start_state: If True, resets SDFG starting state to this state. + :return: A new SDFGState object. + """ + new_state = self.add_state(label, is_start_state) + # Reconnect + for e in self.out_edges(state): + self.remove_edge(e) + self.add_edge(new_state, e.dst, e.data) + # Add unconditional connection between the current and the new state + self.add_edge(state, new_state, dace.sdfg.InterstateEdge()) + return new_state + + def add_loop( + self, + before_state: SDFGState, + after_state: SDFGState, + loop_var: str, + initialize_expr: str, + condition_expr: str, + increment_expr: str, + inverted: bool = False, + ): + """ + Helper function that adds a looping state machine around a given state (or sequence of states). + + :param before_state: The state after which the loop should begin, or None if the loop is the first state + (creates an empty state). + :param loop_state: The state that begins the loop. See also ``loop_end_state`` if the loop is multi-state. + :param after_state: The state that should be invoked after the loop ends, or None if the program should + terminate (creates an empty state). + :param loop_var: A name of an inter-state variable to use for the loop. If None, ``initialize_expr`` and + ``increment_expr`` must be None. + :param initialize_expr: A string expression that is assigned to ``loop_var`` before the loop begins. If None, + does not define an expression. + :param condition_expr: A string condition that occurs every loop iteration. If None, loops forever (undefined + behavior). + :param increment_expr: A string expression that is assigned to ``loop_var`` after every loop iteration. If None, + does not define an expression. + :param loop_end_state: If the loop wraps multiple states, the state where the loop iteration ends. If None, sets + the end state to ``loop_state`` as well. + :return: A 3-tuple of (``before_state``, generated loop guard state, ``after_state``). + """ + # Argument checks + if loop_var is None and (initialize_expr or increment_expr): + raise ValueError("Cannot initalize or increment an empty loop variable") + + loop_scope = LoopScopeBlock(loop_var=loop_var, + initialize_expr=initialize_expr, + update_expr=increment_expr, + condition_expr=condition_expr, + inverted=inverted) + + # Handling empty states + if before_state is None: + before_state = self.add_state() + if after_state is None: + after_state = self.add_state() + + self.add_node(loop_scope) + self.add_edge(before_state, loop_scope, dace.sdfg.InterstateEdge()) + self.add_edge(loop_scope, after_state, dace.sdfg.InterstateEdge()) + + return before_state, loop_scope, after_state + + @abc.abstractmethod + def used_symbols(self, + all_symbols: bool, + defined_syms: Optional[Set]=None, + free_syms: Optional[Set]=None, + used_before_assignment: Optional[Set]=None, + keep_defined_in_mapping: bool=False) -> Tuple[Set[str], Set[str], Set[str]]: + """ + Returns a set of symbol names that are used by the scope, but not defined within it. + + :param all_symbols: If False, only returns the set of symbols that will be used + in the generated code and are needed as arguments. + :param defined_syms: Set of already defined symbols, if any. Otherwise None. + :param free_syms: Set of already found free symbols, if any. Otherwise None. + :param used_before_assignment: Set of already found symbols that are used before they are assigned to, if any. + Otherwise None. + :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping + will be removed from the set of defined symbols. + :returns: Three-Tuple (Set of free symbols, set of defined symbols, set of symbols used before assignment). + """ + defined_syms = set() if defined_syms is None else defined_syms + free_syms = set() if free_syms is None else free_syms + used_before_assignment = set() if used_before_assignment is None else used_before_assignment + + try: + ordered_blocks = self.topological_sort(self.start_block) + except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) + ordered_blocks = self.nodes() + + for block in ordered_blocks: + state_symbols = set() + if isinstance(block, ScopeBlock): + b_free_syms, b_defined_syms, b_used_before_syms = block.used_symbols(all_symbols) + free_syms |= b_free_syms + defined_syms |= b_defined_syms + used_before_assignment |= b_used_before_syms + state_symbols = b_free_syms + else: + state_symbols = block.used_symbols(all_symbols) + free_syms |= state_symbols + + # Add free inter-state symbols + for e in self.out_edges(block): + # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols by + # subracting the (true) free symbols from the edge's assignment keys. This way we can correctly + # compute the symbols that are used before being assigned. + efsyms = e.data.used_symbols(all_symbols) + defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_symbols) + used_before_assignment.update(efsyms - defined_syms) + free_syms |= efsyms + + # Remove symbols that were used before they were assigned. + defined_syms -= used_before_assignment + + if isinstance(self, dace.SDFG): + # Remove from defined symbols those that are in the symbol mapping + if self.parent_nsdfg_node is not None and keep_defined_in_mapping: + defined_syms -= set(self.parent_nsdfg_node.symbol_mapping.keys()) + + # Add the set of SDFG symbol parameters + # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets + if all_symbols: + free_syms |= set(self.symbols.keys()) + + # Subtract symbols defined in inter-state edges and constants from the list of free symbols. + free_syms -= defined_syms + + return free_syms, defined_syms, used_before_assignment + + @property + def free_symbols(self) -> Set[str]: + return self.used_symbols(all_symbols=True)[0] + + def to_json(self, parent=None): + graph_json = OrderedDiGraph.to_json(self) + block_json = ControlFlowBlock.to_json(self, parent) + graph_json.update(block_json) + return graph_json + + ################################################################### + # Traversal methods + + def all_state_scopes_recursive(self, recurse_into_sdfgs=True) -> Iterator['ScopeBlock']: + """ Iterate over this and all nested state scopes. """ + yield self + for block in self.nodes(): + if isinstance(block, SDFGState) and recurse_into_sdfgs: + for node in block.nodes(): + if isinstance(node, nd.NestedSDFG): + yield from node.sdfg.all_state_scopes_recursive(recurse_into_sdfgs=recurse_into_sdfgs) + elif isinstance(block, ScopeBlock): + yield from block.all_state_scopes_recursive(recurse_into_sdfgs=recurse_into_sdfgs) + + def all_sdfgs_recursive(self) -> Iterator['dace.SDFG']: + """ Iterate over this and all nested SDFGs. """ + for cfg in self.all_state_scopes_recursive(recurse_into_sdfgs=True): + if isinstance(cfg, dace.SDFG): + yield cfg + + def all_states_recursive(self) -> Iterator[SDFGState]: + """ Iterate over all states in this control flow graph. """ + for block in self.nodes(): + if isinstance(block, SDFGState): + yield block + elif isinstance(block, ScopeBlock): + yield from block.all_states_recursive() + + def all_control_flow_blocks_recursive(self, recurse_into_sdfgs=True) -> Iterator[ControlFlowBlock]: + """ Iterate over all control flow blocks in this control flow graph. """ + for cfg in self.all_state_scopes_recursive(recurse_into_sdfgs=recurse_into_sdfgs): + for block in cfg.nodes(): + yield block + + def all_interstate_edges_recursive(self, recurse_into_sdfgs=True) -> Iterator[Edge['dace.sdfg.InterstateEdge']]: + """ Iterate over all interstate edges in this control flow graph. """ + for cfg in self.all_state_scopes_recursive(recurse_into_sdfgs=recurse_into_sdfgs): + for edge in cfg.edges(): + yield edge + + ################################################################### + # Getters & setters, overrides + + def __str__(self): + return ControlFlowBlock.__str__(self) + + def __repr__(self) -> str: + return f'{self.__class__.__name__} ({self.label})' + + @property + def start_block(self): + """ Returns the starting block of this ControlFlowGraph. """ + if self._cached_start_block is not None: + return self._cached_start_block + + source_nodes = self.source_nodes() + if len(source_nodes) == 1: + self._cached_start_block = source_nodes[0] + return source_nodes[0] + # If the starting block is ambiguous allow manual override. + if self._start_block is not None: + self._cached_start_block = self.node(self._start_block) + return self._cached_start_block + raise ValueError('Ambiguous or undefined starting block for ControlFlowGraph, ' + 'please use "is_start_block=True" when adding the ' + 'starting block with "add_state" or "add_node"') + + @start_block.setter + def start_block(self, block_id): + """ Manually sets the starting block of this ControlFlowGraph. + + :param block_id: The node ID (use `node_id(block)`) of the block to set. + """ + if block_id < 0 or block_id >= self.number_of_nodes(): + raise ValueError('Invalid state ID') + self._start_block = block_id + self._cached_start_block = self.node(block_id) + + +@make_properties +class LoopScopeBlock(ScopeBlock): + + update_statement = CodeProperty(optional=True, allow_none=True, default=None) + init_statement = CodeProperty(optional=True, allow_none=True, default=None) + scope_condition = CodeProperty(allow_none=True, default=None) + inverted = Property(dtype=bool, default=False) + loop_variable = Property(dtype=str, default='') + + def __init__(self, + loop_var: str, + initialize_expr: str, + condition_expr: str, + update_expr: str, + label: str = '', + parent: Optional[ScopeBlock] = None, + sdfg: Optional['dace.SDFG'] = None, + inverted: bool = False): + super(LoopScopeBlock, self).__init__(label, parent, sdfg) + + if initialize_expr is not None: + self.init_statement = CodeBlock('%s = %s' % (loop_var, initialize_expr)) + else: + self.init_statement = None + + if condition_expr: + self.scope_condition = CodeBlock(condition_expr) + else: + self.scope_condition = CodeBlock('True') + + if update_expr is not None: + self.update_statement = CodeBlock('%s = %s' % (loop_var, update_expr)) + else: + self.update_statement = None + + self.loop_variable = loop_var or '' + self.inverted = inverted + + def used_symbols(self, + all_symbols: bool, + defined_syms: Optional[Set]=None, + free_syms: Optional[Set]=None, + used_before_assignment: Optional[Set]=None, + keep_defined_in_mapping: bool=False) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms = set() if defined_syms is None else defined_syms + free_syms = set() if free_syms is None else free_syms + used_before_assignment = set() if used_before_assignment is None else used_before_assignment + + defined_syms.add(self.loop_variable) + if self.init_statement is not None: + free_syms |= self.init_statement.get_free_symbols() + if self.update_statement is not None: + free_syms |= self.update_statement.get_free_symbols() + free_syms |= self.scope_condition.get_free_symbols() + + b_free_symbols, b_defined_symbols, b_used_before_assignment = super().used_symbols( + all_symbols, keep_defined_in_mapping=keep_defined_in_mapping + ) + free_syms |= b_free_symbols + defined_syms |= b_defined_symbols + used_before_assignment |= b_used_before_assignment + + defined_syms -= used_before_assignment + free_syms -= defined_syms + + return free_syms, defined_syms, used_before_assignment + + def replace_dict(self, repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, + replace_in_graph: bool = True, replace_keys: bool = True): + if replace_keys: + from dace.sdfg.replace import replace_properties_dict + replace_properties_dict(self, repl, symrepl) + + if self.loop_variable and self.loop_variable in repl: + self.loop_variable = repl[self.loop_variable] + + super().replace_dict(repl, symrepl, replace_in_graph) + + def to_json(self, parent=None): + return super().to_json(parent) + + +@make_properties +class BranchScopeBlock(ScopeBlock): + + def __init__(self): + super(BranchScopeBlock, self).__init__() diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 1078414161..b62799fc0a 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -13,7 +13,7 @@ from dace.sdfg.graph import MultiConnectorEdge from dace.sdfg.sdfg import SDFG from dace.sdfg.nodes import Node, NestedSDFG -from dace.sdfg.state import SDFGState, StateSubgraphView +from dace.sdfg.state import SDFGState, StateSubgraphView, LoopScopeBlock, ScopeBlock from dace.sdfg.scope import ScopeSubgraphView from dace.sdfg import nodes as nd, graph as gr, propagation from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs, symbolic @@ -668,7 +668,7 @@ def consolidate_edges(sdfg: SDFG, starting_scope=None) -> int: from dace.sdfg.propagation import propagate_memlets_scope total_consolidated = 0 - for state in sdfg.nodes(): + for state in sdfg.states(): # Start bottom-up if starting_scope and starting_scope.entry not in state.nodes(): continue @@ -765,7 +765,7 @@ def get_last_view_node(state: SDFGState, view: nd.AccessNode) -> nd.AccessNode: Given a view access node, returns the last viewed access node if existent, else None """ - sdfg = state.parent + sdfg = state.sdfg node = view desc = sdfg.arrays[node.data] while isinstance(desc, dt.View): @@ -781,7 +781,7 @@ def get_all_view_nodes(state: SDFGState, view: nd.AccessNode) -> List[nd.AccessN Given a view access node, returns a list of viewed access nodes if existent, else None """ - sdfg = state.parent + sdfg = state.sdfg node = view desc = sdfg.arrays[node.data] result = [node] @@ -910,10 +910,10 @@ def is_parallel(state: SDFGState, node: Optional[nd.Node] = None) -> bool: curnode = sdict[curnode] if curnode.schedule != dtypes.ScheduleType.Sequential: return True - if state.parent.parent is not None: + if state.sdfg.parent is not None: # Find nested SDFG node and continue recursion - nsdfg_node = next(n for n in state.parent.parent if isinstance(n, nd.NestedSDFG) and n.sdfg == state.parent) - return is_parallel(state.parent.parent, nsdfg_node) + nsdfg_node = next(n for n in state.sdfg.parent if isinstance(n, nd.NestedSDFG) and n.sdfg == state.sdfg) + return is_parallel(state.sdfg.parent, nsdfg_node) return False @@ -1206,8 +1206,8 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> counter = 0 if progress is True or progress is None: fusible_states = 0 - for sd in sdfg.all_sdfgs_recursive(): - fusible_states += sd.number_of_edges() + for cfg in sdfg.all_state_scopes_recursive(): + fusible_states += cfg.number_of_edges() if progress is True: pbar = tqdm(total=fusible_states, desc='Fusing states') @@ -1216,36 +1216,61 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> for sd in sdfg.all_sdfgs_recursive(): id = sd.sdfg_id - - while True: - edges = list(sd.nx.edges) - applied = 0 - skip_nodes = set() - for u, v in edges: - if (progress is None and tqdm is not None and (time.time() - start) > 5): - progress = True - pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter) - - if u in skip_nodes or v in skip_nodes: - continue - candidate = {StateFusion.first_state: u, StateFusion.second_state: v} - sf = StateFusion() - sf.setup_match(sd, id, -1, candidate, 0, override=True) - if sf.can_be_applied(sd, 0, sd, permissive=permissive): - sf.apply(sd, sd) - applied += 1 - counter += 1 - if progress: - pbar.update(1) - skip_nodes.add(u) - skip_nodes.add(v) - if applied == 0: - break + for cfg in sd.all_state_scopes_recursive(recurse_into_sdfgs=False): + while True: + edges = list(cfg.nx.edges) + applied = 0 + skip_nodes = set() + for u, v in edges: + if (progress is None and tqdm is not None and (time.time() - start) > 5): + progress = True + pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter) + + if u in skip_nodes or v in skip_nodes or not isinstance(v, SDFGState) or not isinstance(u, SDFGState): + continue + candidate = {StateFusion.first_state: u, StateFusion.second_state: v} + sf = StateFusion() + sf.setup_match(cfg, id, -1, candidate, 0, override=True) + if sf.can_be_applied(cfg, 0, sd, permissive=permissive): + sf.apply(cfg, sd) + applied += 1 + counter += 1 + if progress: + pbar.update(1) + skip_nodes.add(u) + skip_nodes.add(v) + if applied == 0: + break if progress: pbar.close() return counter +def inline_loop_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: + # Avoid import loops + from dace.transformation.interstate import LoopScopeInline + + counter = 0 + blocks = [(n, p) for n, p in sdfg.all_nodes_recursive() if isinstance(n, LoopScopeBlock)] + + for block, graph in optional_progressbar(reversed(blocks), title='Inlining Loops', n=len(blocks), progress=progress): + id = block.sdfg.sdfg_id + + # We have to reevaluate every time due to changing IDs + block_id = graph.node_id(block) + + candidate = { + LoopScopeInline.block: block, + } + inliner = LoopScopeInline() + inliner.setup_match(graph, id, block_id, candidate, 0, override=True) + if inliner.can_be_applied(graph, 0, block.sdfg, permissive=permissive): + inliner.apply(graph, block.sdfg) + counter += 1 + + return counter + + def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, multistate: bool = True) -> int: """ Inlines all possible nested SDFGs (or sub-SDFGs) using an optimized @@ -1269,18 +1294,18 @@ def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, mu for node, state in optional_progressbar(reversed(nsdfgs), title='Inlining SDFGs', n=len(nsdfgs), progress=progress): id = node.sdfg.sdfg_id - sd = state.parent + graph = state.parent # We have to reevaluate every time due to changing IDs - state_id = sd.node_id(state) + state_id = graph.node_id(state) if multistate: candidate = { InlineMultistateSDFG.nested_sdfg: node, } inliner = InlineMultistateSDFG() - inliner.setup_match(sd, id, state_id, candidate, 0, override=True) - if inliner.can_be_applied(state, 0, sd, permissive=permissive): - inliner.apply(state, sd) + inliner.setup_match(graph, id, state_id, candidate, 0, override=True) + if inliner.can_be_applied(state, 0, graph, permissive=permissive): + inliner.apply(state, state.sdfg) counter += 1 continue @@ -1288,9 +1313,9 @@ def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, mu InlineSDFG.nested_sdfg: node, } inliner = InlineSDFG() - inliner.setup_match(sd, id, state_id, candidate, 0, override=True) - if inliner.can_be_applied(state, 0, sd, permissive=permissive): - inliner.apply(state, sd) + inliner.setup_match(graph, id, state_id, candidate, 0, override=True) + if inliner.can_be_applied(state, 0, graph, permissive=permissive): + inliner.apply(state, state.sdfg) counter += 1 return counter @@ -1379,7 +1404,7 @@ def unique_node_repr(graph: Union[SDFGState, ScopeSubgraphView], node: Node) -> """ # Build a unique representation - sdfg = graph.parent + sdfg = graph.sdfg state = graph if isinstance(graph, SDFGState) else graph._graph return str(sdfg.sdfg_id) + "_" + str(sdfg.node_id(state)) + "_" + str(state.node_id(node)) @@ -1413,7 +1438,7 @@ def is_nonfree_sym_dependent(node: nd.AccessNode, desc: dt.Data, state: SDFGStat # is the View. n = get_view_node(state, node) if n and isinstance(n, nd.AccessNode): - d = state.parent.arrays[n.data] + d = state.sdfg.arrays[n.data] return is_nonfree_sym_dependent(n, d, state, fsymbols) elif isinstance(desc, dt.Array): if any(str(s) not in fsymbols for s in desc.free_symbols): @@ -1451,31 +1476,26 @@ def _traverse(scope: Node, symbols: Dict[str, dtypes.typeclass]): yield from _traverse(None, symbols) -def traverse_sdfg_with_defined_symbols( - sdfg: SDFG, - recursive: bool = False) -> Generator[Tuple[SDFGState, Node, Dict[str, dtypes.typeclass]], None, None]: - """ - Traverses the SDFG, its states and nodes, yielding the defined symbols and their types at each node. - - :return: A generator that yields tuples of (state, node in state, currently-defined symbols) - """ - # Start with global symbols - symbols = copy.copy(sdfg.symbols) - symbols.update({k: dt.create_datadescriptor(v).dtype for k, v in sdfg.constants.items()}) - for desc in sdfg.arrays.values(): - symbols.update({str(s): s.dtype for s in desc.free_symbols}) - +def _tswds_scope_block( + sdfg: SDFG, + scope: ScopeBlock, + symbols: Dict[str, dtypes.typeclass], + recursive: bool, +) -> Generator[Tuple[SDFGState, Node, Dict[str, dtypes.typeclass]], None, None]: # Add symbols from inter-state edges along the state machine - start_state = sdfg.start_state + start_block = scope.start_block visited = set() visited_edges = set() - for edge in sdfg.dfs_edges(start_state): + for edge in scope.dfs_edges(start_block): # Source -> inter-state definition -> Destination visited_edges.add(edge) # Source if edge.src not in visited: visited.add(edge.src) - yield from _tswds_state(sdfg, edge.src, symbols, recursive) + if isinstance(edge.src, SDFGState): + yield from _tswds_state(sdfg, edge.src, symbols, recursive) + else: + yield from _tswds_scope_block(sdfg, edge.src, symbols, recursive) # Add edge symbols into defined symbols issyms = edge.data.new_symbols(sdfg, symbols) @@ -1484,11 +1504,34 @@ def traverse_sdfg_with_defined_symbols( # Destination if edge.dst not in visited: visited.add(edge.dst) - yield from _tswds_state(sdfg, edge.dst, symbols, recursive) + if isinstance(edge.dst, SDFGState): + yield from _tswds_state(sdfg, edge.dst, symbols, recursive) + else: + yield from _tswds_scope_block(sdfg, edge.dst, symbols, recursive) # If there is only one state, the DFS will miss it - if start_state not in visited: - yield from _tswds_state(sdfg, start_state, symbols, recursive) + if start_block not in visited: + if isinstance(start_block, SDFGState): + yield from _tswds_state(sdfg, start_block, symbols, recursive) + else: + yield from _tswds_scope_block(sdfg, start_block, symbols, recursive) + + +def traverse_sdfg_with_defined_symbols( + sdfg: SDFG, + recursive: bool = False) -> Generator[Tuple[SDFGState, Node, Dict[str, dtypes.typeclass]], None, None]: + """ + Traverses the SDFG, its states and nodes, yielding the defined symbols and their types at each node. + + :return: A generator that yields tuples of (state, node in state, currently-defined symbols) + """ + # Start with global symbols + symbols = copy.copy(sdfg.symbols) + symbols.update({k: dt.create_datadescriptor(v).dtype for k, v in sdfg.constants.items()}) + for desc in sdfg.arrays.values(): + symbols.update({str(s): s.dtype for s in desc.free_symbols}) + + yield from _tswds_scope_block(sdfg, sdfg, symbols, recursive) def is_fpga_kernel(sdfg, state): diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 0bb3e9a64e..a2976c760a 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -12,6 +12,7 @@ import dace from dace.sdfg import SDFG from dace.sdfg import graph as gr + from dace.sdfg.state import ControlFlowGraph from dace.memlet import Memlet ########################################### @@ -27,6 +28,145 @@ def validate(graph: 'dace.sdfg.graph.SubgraphView'): validate_state(graph) +def validate_cf_scope(sdfg: 'dace.sdfg.SDFG', + scope: 'ControlFlowGraph', + initialized_transients: Set[str], + symbols: dict, + references: Set[int] = None, + **context: bool): + from dace.sdfg import SDFGState + from dace.sdfg.scope import is_in_scope + + if len(scope.source_nodes()) > 1 and scope.start_block is None: + raise InvalidSDFGError("Starting block undefined", sdfg, None) + + in_default_scope = None + + # Check every state separately + start_block = scope.start_block + visited = set() + visited_edges = set() + # Run through blocks via DFS, ensuring that only the defined symbols are available for validation + for edge in scope.dfs_edges(start_block): + # Source -> inter-state definition -> Destination + ########################################## + visited_edges.add(edge) + + # Reference check + if id(edge) in references: + raise InvalidSDFGInterstateEdgeError( + f'Duplicate inter-state edge object detected: "{edge}". Please ' + 'copy objects rather than using multiple references to the same one', sdfg, scope.edge_id(edge)) + references.add(id(edge)) + if id(edge.data) in references: + raise InvalidSDFGInterstateEdgeError( + f'Duplicate inter-state edge object detected: "{edge}". Please ' + 'copy objects rather than using multiple references to the same one', sdfg, scope.edge_id(edge)) + references.add(id(edge.data)) + + # Source + if edge.src not in visited: + visited.add(edge.src) + validate_state(edge.src, scope.node_id(edge.src), sdfg, symbols, initialized_transients, references, + **context) + + ########################################## + # Edge + # Check inter-state edge for undefined symbols + undef_syms = set(edge.data.free_symbols) - set(symbols.keys()) + if len(undef_syms) > 0: + eid = scope.edge_id(edge) + raise InvalidSDFGInterstateEdgeError( + f'Undefined symbols in edge: {undef_syms}. Add those with ' + '`sdfg.add_symbol()` or define outside with `dace.symbol()`', sdfg, eid) + + # Validate inter-state edge names + issyms = edge.data.new_symbols(sdfg, symbols) + if any(not dtypes.validate_name(s) for s in issyms): + invalid = next(s for s in issyms if not dtypes.validate_name(s)) + eid = scope.edge_id(edge) + raise InvalidSDFGInterstateEdgeError("Invalid interstate symbol name %s" % invalid, sdfg, eid) + + # Ensure accessed data containers in assignments and conditions are accessible in this context + ise_memlets = edge.data.get_read_memlets(sdfg.arrays) + for memlet in ise_memlets: + container = memlet.data + if not _accessible(sdfg, container, context): + # Check context w.r.t. maps + if in_default_scope is None: # Lazy-evaluate in_default_scope + in_default_scope = False + if sdfg.parent_nsdfg_node is not None: + if is_in_scope(sdfg.parent_sdfg, sdfg.parent, sdfg.parent_nsdfg_node, + [dtypes.ScheduleType.Default]): + in_default_scope = True + if in_default_scope is False: + eid = scope.edge_id(edge) + raise InvalidSDFGInterstateEdgeError( + f'Trying to read an inaccessible data container "{container}" ' + f'(Storage: {sdfg.arrays[container].storage}) in host code interstate edge', sdfg, eid) + + # Add edge symbols into defined symbols + symbols.update(issyms) + + ########################################## + # Destination + if edge.dst not in visited: + visited.add(edge.dst) + if isinstance(edge.dst, SDFGState): + validate_state(edge.dst, scope.node_id(edge.dst), sdfg, symbols, initialized_transients, references, + **context) + else: + validate_cf_scope(sdfg, edge.dst, initialized_transients, symbols, references, **context) + # End of block DFS + + # If there is only one block, the DFS will miss it + if start_block not in visited: + if isinstance(start_block, SDFGState): + validate_state(start_block, scope.node_id(start_block), sdfg, symbols, initialized_transients, references, + **context) + else: + validate_cf_scope(sdfg, start_block, initialized_transients, symbols, references, **context) + + # Validate all inter-state edges (including self-loops not found by DFS) + for eid, edge in enumerate(scope.edges()): + if edge in visited_edges: + continue + + # Reference check + if id(edge) in references: + raise InvalidSDFGInterstateEdgeError( + f'Duplicate inter-state edge object detected: "{edge}". Please ' + 'copy objects rather than using multiple references to the same one', sdfg, eid) + references.add(id(edge)) + if id(edge.data) in references: + raise InvalidSDFGInterstateEdgeError( + f'Duplicate inter-state edge object detected: "{edge}". Please ' + 'copy objects rather than using multiple references to the same one', sdfg, eid) + references.add(id(edge.data)) + + issyms = edge.data.assignments.keys() + if any(not dtypes.validate_name(s) for s in issyms): + invalid = next(s for s in issyms if not dtypes.validate_name(s)) + raise InvalidSDFGInterstateEdgeError("Invalid interstate symbol name %s" % invalid, sdfg, eid) + + # Ensure accessed data containers in assignments and conditions are accessible in this context + ise_memlets = edge.data.get_read_memlets(sdfg.arrays) + for memlet in ise_memlets: + container = memlet.data + if not _accessible(sdfg, container, context): + # Check context w.r.t. maps + if in_default_scope is None: # Lazy-evaluate in_default_scope + in_default_scope = False + if sdfg.parent_nsdfg_node is not None: + if is_in_scope(sdfg.parent_sdfg, sdfg.parent, sdfg.parent_nsdfg_node, + [dtypes.ScheduleType.Default]): + in_default_scope = True + if in_default_scope is False: + raise InvalidSDFGInterstateEdgeError( + f'Trying to read an inaccessible data container "{container}" ' + f'(Storage: {sdfg.arrays[container].storage}) in host code interstate edge', sdfg, eid) + + def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context: bool): """ Verifies the correctness of an SDFG by applying multiple tests. @@ -42,7 +182,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context """ # Avoid import loop from dace.codegen.targets import fpga - from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga, is_in_scope + from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga references = references or set() @@ -58,11 +198,9 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context if not dtypes.validate_name(sdfg.name): raise InvalidSDFGError("Invalid name", sdfg, None) - if len(sdfg.source_nodes()) > 1 and sdfg.start_state is None: - raise InvalidSDFGError("Starting state undefined", sdfg, None) - - if len(set([s.label for s in sdfg.nodes()])) != len(sdfg.nodes()): - raise InvalidSDFGError("Found multiple states with the same name", sdfg, None) + all_blocks = set(sdfg.all_control_flow_blocks_recursive(recurse_into_sdfgs=False)) + if len(set([s.label for s in all_blocks])) != len(all_blocks): + raise InvalidSDFGError("Found multiple blocks with the same name", sdfg, None) # Validate data descriptors for name, desc in sdfg._arrays.items(): @@ -111,10 +249,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context # Check if SDFG is located within a GPU kernel context['in_gpu'] = is_devicelevel_gpu(sdfg, None, None) context['in_fpga'] = is_devicelevel_fpga(sdfg, None, None) - in_default_scope = None - # Check every state separately - start_state = sdfg.start_state initialized_transients = {'__pystate'} initialized_transients.update(sdfg.constants_prop.keys()) symbols = copy.deepcopy(sdfg.symbols) @@ -123,122 +258,8 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context for desc in sdfg.arrays.values(): for sym in desc.free_symbols: symbols[str(sym)] = sym.dtype - visited = set() - visited_edges = set() - # Run through states via DFS, ensuring that only the defined symbols - # are available for validation - for edge in sdfg.dfs_edges(start_state): - # Source -> inter-state definition -> Destination - ########################################## - visited_edges.add(edge) - - # Reference check - if id(edge) in references: - raise InvalidSDFGInterstateEdgeError( - f'Duplicate inter-state edge object detected: "{edge}". Please ' - 'copy objects rather than using multiple references to the same one', sdfg, sdfg.edge_id(edge)) - references.add(id(edge)) - if id(edge.data) in references: - raise InvalidSDFGInterstateEdgeError( - f'Duplicate inter-state edge object detected: "{edge}". Please ' - 'copy objects rather than using multiple references to the same one', sdfg, sdfg.edge_id(edge)) - references.add(id(edge.data)) - - # Source - if edge.src not in visited: - visited.add(edge.src) - validate_state(edge.src, sdfg.node_id(edge.src), sdfg, symbols, initialized_transients, references, - **context) - ########################################## - # Edge - # Check inter-state edge for undefined symbols - undef_syms = set(edge.data.free_symbols) - set(symbols.keys()) - if len(undef_syms) > 0: - eid = sdfg.edge_id(edge) - raise InvalidSDFGInterstateEdgeError( - f'Undefined symbols in edge: {undef_syms}. Add those with ' - '`sdfg.add_symbol()` or define outside with `dace.symbol()`', sdfg, eid) - - # Validate inter-state edge names - issyms = edge.data.new_symbols(sdfg, symbols) - if any(not dtypes.validate_name(s) for s in issyms): - invalid = next(s for s in issyms if not dtypes.validate_name(s)) - eid = sdfg.edge_id(edge) - raise InvalidSDFGInterstateEdgeError("Invalid interstate symbol name %s" % invalid, sdfg, eid) - - # Ensure accessed data containers in assignments and conditions are accessible in this context - ise_memlets = edge.data.get_read_memlets(sdfg.arrays) - for memlet in ise_memlets: - container = memlet.data - if not _accessible(sdfg, container, context): - # Check context w.r.t. maps - if in_default_scope is None: # Lazy-evaluate in_default_scope - in_default_scope = False - if sdfg.parent_nsdfg_node is not None: - if is_in_scope(sdfg.parent_sdfg, sdfg.parent, sdfg.parent_nsdfg_node, - [dtypes.ScheduleType.Default]): - in_default_scope = True - if in_default_scope is False: - eid = sdfg.edge_id(edge) - raise InvalidSDFGInterstateEdgeError( - f'Trying to read an inaccessible data container "{container}" ' - f'(Storage: {sdfg.arrays[container].storage}) in host code interstate edge', sdfg, eid) - - # Add edge symbols into defined symbols - symbols.update(issyms) - - ########################################## - # Destination - if edge.dst not in visited: - visited.add(edge.dst) - validate_state(edge.dst, sdfg.node_id(edge.dst), sdfg, symbols, initialized_transients, references, - **context) - # End of state DFS - - # If there is only one state, the DFS will miss it - if start_state not in visited: - validate_state(start_state, sdfg.node_id(start_state), sdfg, symbols, initialized_transients, references, - **context) - - # Validate all inter-state edges (including self-loops not found by DFS) - for eid, edge in enumerate(sdfg.edges()): - if edge in visited_edges: - continue - - # Reference check - if id(edge) in references: - raise InvalidSDFGInterstateEdgeError( - f'Duplicate inter-state edge object detected: "{edge}". Please ' - 'copy objects rather than using multiple references to the same one', sdfg, eid) - references.add(id(edge)) - if id(edge.data) in references: - raise InvalidSDFGInterstateEdgeError( - f'Duplicate inter-state edge object detected: "{edge}". Please ' - 'copy objects rather than using multiple references to the same one', sdfg, eid) - references.add(id(edge.data)) - - issyms = edge.data.assignments.keys() - if any(not dtypes.validate_name(s) for s in issyms): - invalid = next(s for s in issyms if not dtypes.validate_name(s)) - raise InvalidSDFGInterstateEdgeError("Invalid interstate symbol name %s" % invalid, sdfg, eid) - - # Ensure accessed data containers in assignments and conditions are accessible in this context - ise_memlets = edge.data.get_read_memlets(sdfg.arrays) - for memlet in ise_memlets: - container = memlet.data - if not _accessible(sdfg, container, context): - # Check context w.r.t. maps - if in_default_scope is None: # Lazy-evaluate in_default_scope - in_default_scope = False - if sdfg.parent_nsdfg_node is not None: - if is_in_scope(sdfg.parent_sdfg, sdfg.parent, sdfg.parent_nsdfg_node, - [dtypes.ScheduleType.Default]): - in_default_scope = True - if in_default_scope is False: - raise InvalidSDFGInterstateEdgeError( - f'Trying to read an inaccessible data container "{container}" ' - f'(Storage: {sdfg.arrays[container].storage}) in host code interstate edge', sdfg, eid) + validate_cf_scope(sdfg, sdfg, initialized_transients, symbols, references, **context) except InvalidSDFGError as ex: # If the SDFG is invalid, save it @@ -314,8 +335,7 @@ def validate_state(state: 'dace.sdfg.SDFGState', from dace.sdfg import utils as sdutil from dace.sdfg.scope import scope_contains_scope, is_devicelevel_gpu, is_devicelevel_fpga - sdfg = sdfg or state.parent - state_id = state_id or sdfg.node_id(state) + state_id = state_id or state.parent.node_id(state) symbols = symbols or {} initialized_transients = (initialized_transients if initialized_transients is not None else {'__pystate'}) references = references or set() @@ -337,13 +357,13 @@ def validate_state(state: 'dace.sdfg.SDFGState', if not dtypes.validate_name(state._label): raise InvalidSDFGError("Invalid state name", sdfg, state_id) - if state._parent != sdfg: - raise InvalidSDFGError("State does not point to the correct " - "parent", sdfg, state_id) + if state.sdfg != sdfg: + raise InvalidSDFGError("State does not point to the correct sdfg", sdfg, state_id) # Unreachable ######################################## - if (sdfg.number_of_nodes() > 1 and sdfg.in_degree(state) == 0 and sdfg.out_degree(state) == 0): + parent = state.sdfg + if (parent.number_of_nodes() > 1 and parent.in_degree(state) == 0 and parent.out_degree(state) == 0): raise InvalidSDFGError("Unreachable state", sdfg, state_id) if state.has_cycles(): diff --git a/dace/sdfg/work_depth_analysis/helpers.py b/dace/sdfg/work_depth_analysis/helpers.py index e592fd11b5..37de3298e4 100644 --- a/dace/sdfg/work_depth_analysis/helpers.py +++ b/dace/sdfg/work_depth_analysis/helpers.py @@ -34,7 +34,7 @@ def get_uuid(element, state=None): if isinstance(element, SDFG): return ids_to_string(element.sdfg_id) elif isinstance(element, SDFGState): - return ids_to_string(element.parent.sdfg_id, element.parent.node_id(element)) + return ids_to_string(element.sdfg.sdfg_id, element.sdfg.node_id(element)) elif isinstance(element, nodes.Node): return ids_to_string(state.parent.sdfg_id, state.parent.node_id(state), state.node_id(element)) else: diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 54dbc8d4ac..50e8e021e5 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -73,7 +73,7 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, # we are in graph or subgraph sdfg, graph, subgraph = None, None, None if isinstance(graph_or_subgraph, SDFGState): - sdfg = graph_or_subgraph.parent + sdfg = graph_or_subgraph.sdfg sdfg.apply_transformations_repeated(MapFusion, validate_all=validate_all) graph = graph_or_subgraph subgraph = SubgraphView(graph, graph.nodes()) @@ -196,7 +196,7 @@ def tile_wcrs(graph_or_subgraph: GraphViewType, validate_all: bool, prefer_parti return if not isinstance(graph, SDFGState): raise TypeError('Graph must be a state, an SDFG, or a subgraph of either') - sdfg = graph.parent + sdfg = graph.sdfg edges_to_consider: Set[Tuple[gr.MultiConnectorEdge[Memlet], nodes.MapEntry]] = set() for edge in graph_or_subgraph.edges(): diff --git a/dace/transformation/dataflow/double_buffering.py b/dace/transformation/dataflow/double_buffering.py index 8ff70a6355..6efe6543ca 100644 --- a/dace/transformation/dataflow/double_buffering.py +++ b/dace/transformation/dataflow/double_buffering.py @@ -128,7 +128,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): ############################## # Add initial reads to initial nested state initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state - initial_state.set_label('%s_init' % map_entry.map.label) + initial_state.label = '%s_init' % map_entry.map.label for edge in edges_to_replace: initial_state.add_node(edge.src) rnode = edge.src @@ -152,7 +152,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): # Add the main state's contents to the last state, modifying # memlets appropriately. final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0] - final_state.set_label('%s_final_computation' % map_entry.map.label) + final_state.label = '%s_final_computation' % map_entry.map.label dup_nstate = copy.deepcopy(nstate) final_state.add_nodes_from(dup_nstate.nodes()) for e in dup_nstate.edges(): @@ -183,7 +183,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): nstate.add_edge(rnode, edge.src_conn, wnode, edge.dst_conn, new_memlet) - nstate.set_label('%s_double_buffered' % map_entry.map.label) + nstate.label = '%s_double_buffered' % map_entry.map.label # Divide by loop stride new_expr = symbolic.pystr_to_symbolic('((%s / %s) + 1) %% 2' % (map_param, map_rstride)) sd.replace(nstate, '__dace_db_param', new_expr) diff --git a/dace/transformation/dataflow/hbm_transform.py b/dace/transformation/dataflow/hbm_transform.py index 15076a0ca6..e4e2aa277c 100644 --- a/dace/transformation/dataflow/hbm_transform.py +++ b/dace/transformation/dataflow/hbm_transform.py @@ -83,7 +83,7 @@ def _update_memlet_hbm(state: SDFGState, inner_edge: graph.MultiConnectorEdge, # If the memlet already contains the distributed subset, ignore it # That's helpful because of inconsistencies when nesting and because # one can 'hint' the correct bank assignment when using this function - if len(mem.subset) == len(state.parent.arrays[this_node.data].shape): + if len(mem.subset) == len(state.sdfg.arrays[this_node.data].shape): return new_subset = subsets.Range([[inner_subset_index, inner_subset_index, 1]] + [x for x in mem.subset]) @@ -100,7 +100,7 @@ def _update_memlet_hbm(state: SDFGState, inner_edge: graph.MultiConnectorEdge, other_node, nd.NestedSDFG): # Ignore those and update them via propagation new_subset = subsets.Range.from_array( - state.parent.arrays[this_node.data]) + state.sdfg.arrays[this_node.data]) if isinstance(other_node, nd.AccessNode): fwtasklet = state.add_tasklet("fwtasklet", set(["_in"]), set(["_out"]), diff --git a/dace/transformation/dataflow/streaming_memory.py b/dace/transformation/dataflow/streaming_memory.py index 4cf40b30bf..93c86a0d0e 100644 --- a/dace/transformation/dataflow/streaming_memory.py +++ b/dace/transformation/dataflow/streaming_memory.py @@ -160,10 +160,10 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi while curstate is not None: if curstate.entry_node(node) is not None: return False - if curstate.parent.parent_nsdfg_node is None: + if curstate.sdfg.parent_nsdfg_node is None: break - node = curstate.parent.parent_nsdfg_node - curstate = curstate.parent.parent + node = curstate.sdfg.parent_nsdfg_node + curstate = curstate.sdfg.parent # Only one memlet path is allowed per outgoing/incoming edge edges = (graph.out_edges(access) if expr_index == 0 else graph.in_edges(access)) @@ -628,10 +628,10 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi while curstate is not None: if curstate.entry_node(node) is not None: return False - if curstate.parent.parent_nsdfg_node is None: + if curstate.sdfg.parent_nsdfg_node is None: break - node = curstate.parent.parent_nsdfg_node - curstate = curstate.parent.parent + node = curstate.sdfg.parent_nsdfg_node + curstate = curstate.sdfg.parent # Array must not be used anywhere else in the state if any(n is not access and n.data == access.data for n in graph.data_nodes()): diff --git a/dace/transformation/dataflow/warp_tiling.py b/dace/transformation/dataflow/warp_tiling.py index 211910eebf..810848b677 100644 --- a/dace/transformation/dataflow/warp_tiling.py +++ b/dace/transformation/dataflow/warp_tiling.py @@ -55,7 +55,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG) -> nodes.MapEntry: # Stride and offset all internal maps maps_to_stride = xfh.get_internal_scopes(graph, new_me, immediate=True) for nstate, nmap in maps_to_stride: - nsdfg = nstate.parent + nsdfg = nstate.sdfg nsdfg_node = nsdfg.parent_nsdfg_node # Map cannot be partitioned across a warp @@ -123,7 +123,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG) -> nodes.MapEntry: write = nstate.add_write(name) edge = nstate.add_nedge(read, write, copy.deepcopy(out_edge.data)) edge.data.wcr = None - xfh.state_fission(nsdfg, SubgraphView(nstate, [read, write])) + xfh.state_fission(SubgraphView(nstate, [read, write])) newnode = nstate.add_access(name) nstate.remove_edge(out_edge) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 8986c4e37f..19490061ca 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -376,7 +376,7 @@ def nest_state_subgraph(sdfg: SDFG, SDFG. :raise ValueError: The subgraph is contained in more than one scope. """ - if state.parent != sdfg: + if state.sdfg != sdfg: raise KeyError('State does not belong to given SDFG') if subgraph is not state and subgraph.graph is not state: raise KeyError('Subgraph does not belong to given state') @@ -646,7 +646,7 @@ def nest_state_subgraph(sdfg: SDFG, return nested_sdfg -def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView, label: Optional[str] = None) -> SDFGState: +def state_fission(subgraph: graph.SubgraphView, label: Optional[str] = None) -> SDFGState: """ Given a subgraph, adds a new SDFG state before the state that contains it, removes the subgraph from the original state, and connects the two states. @@ -656,7 +656,7 @@ def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView, label: Optional[str] """ state: SDFGState = subgraph.graph - newstate = sdfg.add_state_before(state, label=label) + newstate = state.parent.add_state_before(state, label=label) # Save edges before removing nodes orig_edges = subgraph.edges() @@ -1002,7 +1002,7 @@ def simplify_state(state: SDFGState, remove_views: bool = False) -> MultiDiGraph :return: The MultiDiGraph object. """ - sdfg = state.parent + sdfg = state.sdfg # Copy the whole state G = MultiDiGraph() @@ -1226,7 +1226,7 @@ def contained_in(state: SDFGState, node: nodes.Node, scope: nodes.EntryNode) -> # A node is contained within itself if node is scope: return True - cursdfg = state.parent + cursdfg = state.sdfg curstate = state curscope = state.entry_node(node) while cursdfg is not None: @@ -1249,7 +1249,7 @@ def get_parent_map(state: SDFGState, node: Optional[nodes.Node] = None) -> Optio :param node: The node to test (optional). :return: A tuple of (entry node, state) or None. """ - cursdfg = state.parent + cursdfg = state.sdfg curstate = state curscope = node while cursdfg is not None: @@ -1335,8 +1335,8 @@ def can_run_state_on_fpga(state: SDFGState): return False # Streams have strict conditions due to code generator limitations - if (isinstance(node, nodes.AccessNode) and isinstance(graph.parent.arrays[node.data], data.Stream)): - nodedesc = graph.parent.arrays[node.data] + if (isinstance(node, nodes.AccessNode) and isinstance(graph.sdfg.arrays[node.data], data.Stream)): + nodedesc = graph.sdfg.arrays[node.data] sdict = graph.scope_dict() if nodedesc.storage in [ dtypes.StorageType.CPU_Heap, dtypes.StorageType.CPU_Pinned, dtypes.StorageType.CPU_ThreadLocal @@ -1348,7 +1348,7 @@ def can_run_state_on_fpga(state: SDFGState): return False # Arrays of streams cannot have symbolic size on FPGA - if symbolic.issymbolic(nodedesc.total_size, graph.parent.constants): + if symbolic.issymbolic(nodedesc.total_size, graph.sdfg.constants): return False # Streams cannot be unbounded on FPGA diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index 0bd168751c..f3fc18e273 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -12,6 +12,7 @@ from .loop_unroll import LoopUnroll from .loop_peeling import LoopPeeling from .loop_to_map import LoopToMap +from .scope_inline import LoopScopeInline from .move_loop_into_map import MoveLoopIntoMap from .trivial_loop_elimination import TrivialLoopElimination from .multistate_inline import InlineMultistateSDFG diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 8fb6600b76..3ed2bd4bea 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -423,7 +423,7 @@ def apply(self, _, sdfg: sd.SDFG): new_body = sdfg.add_state('single_state_body') nsdfg = SDFG("loop_body", constants=sdfg.constants_prop, parent=new_body) nsdfg.add_node(body, is_start_state=True) - body.parent = nsdfg + body.sdfg = nsdfg exit_state = nsdfg.add_state('exit') nsymbols = dict() for state in states: diff --git a/dace/transformation/interstate/loop_unroll.py b/dace/transformation/interstate/loop_unroll.py index 47d438a2fc..d6b08f4401 100644 --- a/dace/transformation/interstate/loop_unroll.py +++ b/dace/transformation/interstate/loop_unroll.py @@ -116,8 +116,8 @@ def instantiate_loop( # Replace iterate with value in each state for state in new_states: - state.set_label(state.label + '_' + itervar + '_' + - (state_suffix if state_suffix is not None else str(value))) + new_label = state.label + '_' + itervar + '_' + (state_suffix if state_suffix is not None else str(value)) + state.label = new_label state.replace(itervar, value) # Add subgraph to original SDFG diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 74dd51a483..a4af63b482 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -158,15 +158,15 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sdfg.append_exit_code(code.code, loc) # Environments - for nstate in nsdfg.nodes(): - for node in nstate.nodes(): - if isinstance(node, nodes.CodeNode): - node.environments |= nsdfg_node.environments + for node, _ in nsdfg.all_nodes_recursive(): + if isinstance(node, nodes.CodeNode): + node.environments |= nsdfg_node.environments # Symbols outer_symbols = {str(k): v for k, v in sdfg.symbols.items()} - for ise in sdfg.edges(): - outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols)) + for cf in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False): + for ise in cf.edges(): + outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols)) # Find original source/destination edges (there is only one edge per # connector, according to match) @@ -189,12 +189,14 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # Collect and modify interstate edges as necessary outer_assignments = set() - for e in sdfg.edges(): - outer_assignments |= e.data.assignments.keys() + for cf in sdfg.all_state_scopes_recursive(): + for e in cf.edges(): + outer_assignments |= e.data.assignments.keys() inner_assignments = set() - for e in nsdfg.edges(): - inner_assignments |= e.data.assignments.keys() + for cf in nsdfg.all_state_scopes_recursive(): + for e in cf.edges(): + inner_assignments |= e.data.assignments.keys() allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys()) assignments_to_replace = inner_assignments & (outer_assignments | allnames) @@ -235,8 +237,8 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # All transients become transients of the parent (if data already # exists, find new name) - for nstate in nsdfg.nodes(): - for node in nstate.nodes(): + for state in nsdfg.states(): + for node in state.nodes(): if isinstance(node, nodes.AccessNode): datadesc = nsdfg.arrays[node.data] if node.data not in transients and datadesc.transient: @@ -248,7 +250,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): transients[node.data] = name # All transients of edges between code nodes are also added to parent - for edge in nstate.edges(): + for edge in state.edges(): if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)): if edge.data.data is not None: datadesc = nsdfg.arrays[edge.data.data] @@ -329,12 +331,12 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # e.dst_conn, e.data) # Make unique names for states - statenames = set(s.label for s in sdfg.nodes()) - for nstate in nsdfg.nodes(): - if nstate.label in statenames: - newname = data.find_new_name(nstate.label, statenames) - statenames.add(newname) - nstate.set_label(newname) + blocknames = set(s.label for s in sdfg.all_control_flow_blocks_recursive(recurse_into_sdfgs=False)) + for nblock in nsdfg.all_control_flow_blocks_recursive(recurse_into_sdfgs=False): + if nblock.label in blocknames: + newname = data.find_new_name(nblock.label, blocknames) + blocknames.add(newname) + nblock.label = newname ####################################################### # Add nested SDFG states into top-level SDFG @@ -352,19 +354,20 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sinks = nsdfg.sink_nodes() # Reconnect state machine - for e in sdfg.in_edges(outer_state): - sdfg.add_edge(e.src, source, e.data) - for e in sdfg.out_edges(outer_state): + parent_graph = outer_state.parent + for e in parent_graph.in_edges(outer_state): + parent_graph.add_edge(e.src, source, e.data) + for e in parent_graph.out_edges(outer_state): for sink in sinks: - sdfg.add_edge(sink, e.dst, dc(e.data)) + parent_graph.add_edge(sink, e.dst, dc(e.data)) # Redirect sink incoming edges with a `False` condition to e.dst (return statements) - for e2 in sdfg.in_edges(sink): + for e2 in parent_graph.in_edges(sink): if e2.data.condition_sympy() == False: - sdfg.add_edge(e2.src, e.dst, InterstateEdge()) + parent_graph.add_edge(e2.src, e.dst, InterstateEdge()) # Modify start state as necessary if outer_start_state is outer_state: - sdfg.start_state = sdfg.node_id(source) + parent_graph.start_block = parent_graph.node_id(source) # TODO: Modify memlets by offsetting # If both source and sink nodes are inputs/outputs, reconnect once @@ -408,13 +411,16 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # e._data = helpers.unsqueeze_memlet( # e.data, outer_edge.data) - # Replace nested SDFG parents with new SDFG - for nstate in nsdfg.nodes(): - nstate.parent = sdfg - for node in nstate.nodes(): - if isinstance(node, nodes.NestedSDFG): - node.sdfg.parent_sdfg = sdfg - node.sdfg.parent_nsdfg_node = node + # Replace nested SDFG parents and SDFG pointers. + for n in nsdfg.nodes(): + n.parent = outer_state.parent + for block in nsdfg.all_control_flow_blocks_recursive(recurse_into_sdfgs=False): + block.sdfg = outer_state.sdfg + if isinstance(block, SDFGState): + for node in block.nodes(): + if isinstance(node, nodes.NestedSDFG): + node.sdfg.parent_sdfg = sdfg + node.sdfg.parent_nsdfg_node = node ####################################################### # Remove nested SDFG and state diff --git a/dace/transformation/interstate/scope_inline.py b/dace/transformation/interstate/scope_inline.py new file mode 100644 index 0000000000..c05516400a --- /dev/null +++ b/dace/transformation/interstate/scope_inline.py @@ -0,0 +1,81 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Inline all scope blocks in SDFGs. """ + +from typing import Any, Set, Optional + +from dace.frontend.python import astutils +from dace.sdfg import SDFG, InterstateEdge, SDFGState, ScopeBlock +from dace.sdfg import utils as sdutil +from dace.sdfg.nodes import CodeBlock +from dace.sdfg.state import LoopScopeBlock, ScopeBlock +from dace.transformation import transformation + + +class LoopScopeInline(transformation.MultiStateTransformation): + """ + Inlines a loop scope block into a legacy-style state machine. + """ + + block = transformation.PatternNode(LoopScopeBlock) + + @staticmethod + def annotates_memlets(): + return False + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.block)] + + def can_be_applied(self, graph: ScopeBlock, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + return True + + def apply(self, graph: ScopeBlock, sdfg: SDFG) -> Optional[int]: + parent: ScopeBlock = graph + + internal_start = self.block.start_block + + # Construct the basic loop state structure. + init_state = parent.add_state(self.block.label + '_init') + for b_edge in parent.in_edges(self.block): + parent.add_edge(b_edge.src, init_state, b_edge.data) + parent.remove_edge(b_edge) + + guard_state = parent.add_state(self.block.label + '_guard') + init_edge = InterstateEdge() + if self.block.init_statement is not None: + init_edge.assignments = { + self.block.loop_variable: self.block.init_statement.as_string.rpartition('=')[2].strip() + } + parent.add_edge(init_state, guard_state, init_edge) + + end_state = parent.add_state(self.block.label + '_end') + cond_expr = self.block.scope_condition.code + parent.add_edge(guard_state, end_state, + InterstateEdge(CodeBlock(astutils.negate_expr(cond_expr)).code)) + for a_edge in parent.out_edges(self.block): + parent.add_edge(end_state, a_edge.dst, a_edge.data) + parent.remove_edge(a_edge) + + last_loop_state = parent.add_state(self.block.label + '_loop') + loop_edge = InterstateEdge() + if self.block.update_statement is not None: + loop_edge.assignments = { + self.block.loop_variable: self.block.update_statement.as_string.rpartition('=')[2].strip() + } + parent.add_edge(last_loop_state, guard_state, loop_edge) + + to_connect: Set[SDFGState] = set() + for node in self.block.nodes(): + parent.add_node(node) + if self.block.out_degree(node) == 0: + to_connect.add(node) + for edge in self.block.edges(): + parent.add_edge(edge.src, edge.dst, edge.data) + + # Connect the loop states + parent.add_edge(guard_state, internal_start, + InterstateEdge(CodeBlock(cond_expr).code)) + for node in to_connect: + parent.add_edge(node, last_loop_state, InterstateEdge()) + + parent.remove_node(self.block) diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index fc3ebfbdca..8081472199 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -565,7 +565,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): # Fission state if necessary cc = utils.weakly_connected_component(state, node) if not any(n in cc for n in subgraph.nodes()): - helpers.state_fission(state.parent, cc) + helpers.state_fission(cc) for edge in removed_out_edges: # Find last access node that refers to this edge try: @@ -580,7 +580,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): cc = utils.weakly_connected_component(state, node) if not any(n in cc for n in subgraph.nodes()): cc2 = SubgraphView(state, [n for n in state.nodes() if n not in cc]) - state = helpers.state_fission(sdfg, cc2) + state = helpers.state_fission(cc2) ####################################################### # Remove nested SDFG node @@ -604,7 +604,7 @@ def _modify_access_to_access(self, """ Deals with access->access edges where both sides are non-transient. """ - nsdfg_node = nstate.parent.parent_nsdfg_node + nsdfg_node = nstate.sdfg.parent_nsdfg_node edges_to_ignore = edges_to_ignore or set() result = set() edges = input_edges diff --git a/dace/transformation/interstate/state_elimination.py b/dace/transformation/interstate/state_elimination.py index cbb5d7b957..70fbb2978b 100644 --- a/dace/transformation/interstate/state_elimination.py +++ b/dace/transformation/interstate/state_elimination.py @@ -2,14 +2,15 @@ """ State elimination transformations """ import networkx as nx -from typing import Dict, List, Set +from typing import Dict, Set -from dace import data as dt, dtypes, registry, sdfg, symbolic +from dace import data as dt, sdfg, symbolic +from dace.sdfg.graph import Edge from dace.properties import CodeBlock from dace.sdfg import nodes, SDFG, SDFGState, InterstateEdge +from dace.sdfg.state import ScopeBlock from dace.sdfg import utils as sdutil from dace.transformation import transformation -from dace.sdfg.analysis import cfg class EndStateElimination(transformation.MultiStateTransformation): @@ -47,12 +48,12 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph, sdfg): state = self.end_state # Handle orphan symbols (due to the deletion the incoming edge) - edge = sdfg.in_edges(state)[0] + edge = graph.in_edges(state)[0] sym_assign = edge.data.assignments.keys() - sdfg.remove_node(state) + graph.remove_node(state) # Remove orphan symbols for sym in sym_assign: if sym in sdfg.free_symbols: @@ -75,7 +76,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): state = self.start_state # The transformation applies only to nested SDFGs - if not graph.parent: + if not sdfg.parent_nsdfg_node: return False # Only empty states can be eliminated @@ -97,22 +98,22 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return False # Assignments that make descriptors into symbols cannot be eliminated for assign in edge.data.assignments.values(): - if graph.arrays.keys() & symbolic.free_symbols_and_functions(assign): + if sdfg.arrays.keys() & symbolic.free_symbols_and_functions(assign): return False return True - def apply(self, _, sdfg): + def apply(self, graph, sdfg): state = self.start_state # Move assignments to the nested SDFG node's symbol mappings node = sdfg.parent_nsdfg_node - edge = sdfg.out_edges(state)[0] + edge = graph.out_edges(state)[0] for k, v in edge.data.assignments.items(): node.symbol_mapping[k] = v - sdfg.remove_node(state) + graph.remove_node(state) -def _assignments_to_consider(sdfg, edge, is_constant=False): +def _assignments_to_consider(sdfg: SDFG, edge: Edge[InterstateEdge], is_constant: bool = False): assignments_to_consider = {} for var, assign in edge.data.assignments.items(): as_symbolic = symbolic.pystr_to_symbolic(assign) @@ -166,14 +167,14 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Otherwise, ensure the symbols are never set/used again in edges akeys = set(assignments_to_consider.keys()) - for e in sdfg.edges(): + for e in graph.edges(): if e is edge: continue if e.data.free_symbols & akeys: return False # If used in any state that is not the current one, fail - for s in sdfg.nodes(): + for s in graph.nodes(): if s is state: continue if s.free_symbols & akeys: @@ -181,9 +182,9 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph, sdfg): state = self.end_state - edge = sdfg.in_edges(state)[0] + edge = graph.in_edges(state)[0] # Since inter-state assignments that use an assigned value leads to # undefined behavior (e.g., {m: n, n: m}), we can replace each # assignment separately. @@ -199,7 +200,7 @@ def apply(self, _, sdfg): # Remove assignments from edge del edge.data.assignments[varname] - for e in sdfg.edges(): + for e in graph.edges(): if varname in e.data.free_symbols: break else: @@ -294,12 +295,12 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph, sdfg): fstate = self.first_state sstate = self.second_state - edge = sdfg.edges_between(fstate, sstate)[0].data - in_edge = sdfg.in_edges(fstate)[0].data + edge = graph.edges_between(fstate, sstate)[0].data + in_edge = graph.in_edges(fstate)[0].data to_consider = _alias_assignments(sdfg, edge) @@ -496,7 +497,7 @@ class TrueConditionElimination(transformation.MultiStateTransformation): def expressions(cls): return [sdutil.node_path_graph(cls.state_a, cls.state_b)] - def can_be_applied(self, graph: SDFG, expr_index, sdfg: SDFG, permissive=False): + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): a: SDFGState = self.state_a b: SDFGState = self.state_b # Directed graph has only one edge between two nodes @@ -512,10 +513,10 @@ def can_be_applied(self, graph: SDFG, expr_index, sdfg: SDFG, permissive=False): return False - def apply(self, _, sdfg: SDFG): + def apply(self, graph, sdfg): a: SDFGState = self.state_a b: SDFGState = self.state_b - edge = sdfg.edges_between(a, b)[0] + edge = graph.edges_between(a, b)[0] edge.data.condition = CodeBlock("1") @@ -531,7 +532,7 @@ class FalseConditionElimination(transformation.MultiStateTransformation): def expressions(cls): return [sdutil.node_path_graph(cls.state_a, cls.state_b)] - def can_be_applied(self, graph: SDFG, expr_index, sdfg: SDFG, permissive=False): + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): a: SDFGState = self.state_a b: SDFGState = self.state_b @@ -556,8 +557,8 @@ def can_be_applied(self, graph: SDFG, expr_index, sdfg: SDFG, permissive=False): return False - def apply(self, _, sdfg: SDFG): + def apply(self, graph, sdfg): a: SDFGState = self.state_a b: SDFGState = self.state_b - edge = sdfg.edges_between(a, b)[0] + edge = graph.edges_between(a, b)[0] sdfg.remove_edge(edge) diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index 6db62a097e..6cff57e339 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -1,15 +1,16 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """ State fusion transformation """ -from typing import Dict, List, Set +from typing import Dict, List, Set, Union import networkx as nx -from dace import data as dt, dtypes, registry, sdfg, subsets +from dace import data as dt +from dace import sdfg, subsets from dace.config import Config -from dace.sdfg import nodes +from dace.sdfg import SDFG, nodes from dace.sdfg import utils as sdutil -from dace.sdfg.state import SDFGState +from dace.sdfg.state import SDFGState, ScopeBlock from dace.transformation import transformation @@ -263,13 +264,13 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): first_input = {node for node in first_state.source_nodes() if isinstance(node, nodes.AccessNode)} first_output = { node - for node in first_state.scope_children()[None] + for node in top_level_nodes(first_state) if isinstance(node, nodes.AccessNode) and node not in first_input } second_input = {node for node in second_state.source_nodes() if isinstance(node, nodes.AccessNode)} second_output = { node - for node in second_state.scope_children()[None] + for node in top_level_nodes(second_state) if isinstance(node, nodes.AccessNode) and node not in second_input } @@ -454,33 +455,33 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph: Union[ScopeBlock, SDFGState], sdfg: SDFG): first_state: SDFGState = self.first_state second_state: SDFGState = self.second_state # Remove interstate edge(s) - edges = sdfg.edges_between(first_state, second_state) + edges = graph.edges_between(first_state, second_state) for edge in edges: if edge.data.assignments: - for src, dst, other_data in sdfg.in_edges(first_state): + for src, dst, other_data in graph.in_edges(first_state): other_data.assignments.update(edge.data.assignments) - sdfg.remove_edge(edge) + graph.remove_edge(edge) # Special case 1: first state is empty if first_state.is_empty(): - sdutil.change_edge_dest(sdfg, first_state, second_state) - sdfg.remove_node(first_state) - if sdfg.start_state == first_state: - sdfg.start_state = sdfg.node_id(second_state) + sdutil.change_edge_dest(graph, first_state, second_state) + graph.remove_node(first_state) + if graph.start_block == first_state: + graph.start_block = graph.node_id(second_state) return # Special case 2: second state is empty if second_state.is_empty(): - sdutil.change_edge_src(sdfg, second_state, first_state) - sdutil.change_edge_dest(sdfg, second_state, first_state) - sdfg.remove_node(second_state) - if sdfg.start_state == second_state: - sdfg.start_state = sdfg.node_id(first_state) + sdutil.change_edge_src(graph, second_state, first_state) + sdutil.change_edge_dest(graph, second_state, first_state) + graph.remove_node(second_state) + if graph.start_block == second_state: + graph.start_block = graph.node_id(first_state) return # Normal case: both states are not empty @@ -488,7 +489,6 @@ def apply(self, _, sdfg): # Find source/sink (data) nodes first_input = [node for node in first_state.source_nodes() if isinstance(node, nodes.AccessNode)] first_output = [node for node in first_state.sink_nodes() if isinstance(node, nodes.AccessNode)] - second_input = [node for node in second_state.source_nodes() if isinstance(node, nodes.AccessNode)] top2 = top_level_nodes(second_state) @@ -562,7 +562,7 @@ def apply(self, _, sdfg): merged_nodes.add(n) # Redirect edges and remove second state - sdutil.change_edge_src(sdfg, second_state, first_state) - sdfg.remove_node(second_state) - if sdfg.start_state == second_state: - sdfg.start_state = sdfg.node_id(first_state) + sdutil.change_edge_src(graph, second_state, first_state) + graph.remove_node(second_state) + if graph.start_block == second_state: + graph.start_block = graph.node_id(first_state) diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 4e16bb6207..cc5b943fd9 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -3,7 +3,7 @@ API for SDFG analysis and manipulation Passes, as well as Pipelines that contain multiple dependent passes. """ from dace import properties, serialize -from dace.sdfg import SDFG, SDFGState, graph as gr, nodes, utils as sdutil +from dace.sdfg import SDFG, SDFGState, graph as gr, nodes, utils as sdutil, ScopeBlock from enum import Flag, auto from typing import Any, Dict, Iterator, List, Optional, Set, Type, Union @@ -307,6 +307,58 @@ def apply(self, scope: nodes.EntryNode, state: SDFGState, pipeline_results: Dict raise NotImplementedError +@properties.make_properties +class ControlFlowScopePass(Pass): + """ + A specialized Pass type that applies to each control flow scope (i.e., CFG) separately. Such a pass is + realized by implementing the ``apply`` method, which accepts a CFG and the SDFG it belongs to. + + :see: Pass + """ + + CATEGORY: str = 'Helper' + + def apply_pass( + self, + sdfg: SDFG, + pipeline_results: Dict[str, Any], + **kwargs, + ) -> Optional[Dict[nodes.EntryNode, Optional[Any]]]: + """ + Applies the pass to the CFGs of the given SDFG by calling ``apply`` on each CFG. + + :param sdfg: The SDFG to apply the pass to. + :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass + results as ``{Pass subclass name: returned object from pass}``. If not run in a + pipeline, an empty dictionary is expected. + :return: A dictionary of ``{entry node: return value}`` for visited CFGs with a non-None return value, or None + if nothing was returned. + """ + result = {} + for scope_block in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False): + retval = self.apply(scope_block, scope_block if isinstance(scope_block, SDFG) else scope_block.sdfg, + pipeline_results, **kwargs) + if retval is not None: + result[scope_block] = retval + + if not result: + return None + return result + + def apply(self, scope_block: ScopeBlock, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Any]: + """ + Applies this pass on the given scope. + + :param scope_block: The control flow scope block to apply the pass to. + :param sdfg: The parent SDFG of the given scope. + :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass + results as ``{Pass subclass name: returned object from pass}``. If not run in a + pipeline, an empty dictionary is expected. + :return: Some object if pass was applied, or None if nothing changed. + """ + raise NotImplementedError + + @dataclass @properties.make_properties class Pipeline(Pass): diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis.py index 86e1cde062..b08206b06e 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis.py @@ -1,10 +1,12 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +import copy from collections import defaultdict from dace.transformation import pass_pipeline as ppl from dace import SDFG, SDFGState, properties, InterstateEdge from dace.sdfg.graph import Edge -from dace.sdfg import nodes as nd +from dace.sdfg import nodes as nd, utils as sdutils +from dace.sdfg.state import ControlFlowBlock, ScopeBlock, LoopScopeBlock from dace.sdfg.analysis import cfg from typing import Dict, Set, Tuple, Any, Optional, Union import networkx as nx @@ -23,6 +25,49 @@ class StateReachability(ppl.Pass): CATEGORY: str = 'Analysis' + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + # If anything was modified, reapply + return modified & ppl.Modifies.States + + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]: + """ + :return: A dictionary mapping each control flow block to its other reachable control flow blocks. + """ + reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = {} + for sdfg in top_sdfg.all_sdfgs_recursive(): + result: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = {} + + # In networkx this is currently implemented naively for directed graphs. + # The implementation below is faster + # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) + for scope in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False): + for n, v in reachable_nodes(scope.nx): + result[n] = set(v) + + for k, _ in result.items(): + parent = k.parent + if parent is not None: + # Everything in a loop can also reach anything else in the loop. + # TODO: Unless it's a break or return state. + if isinstance(parent, LoopScopeBlock): + result[k].update(parent.nodes()) + + reachable[sdfg.sdfg_id] = result + + return reachable + + +@properties.make_properties +class LegacyStateReachability(ppl.Pass): + """ + Evaluates state reachability (which other states can be executed after each state). + """ + + CATEGORY: str = 'Analysis' + def modifies(self) -> ppl.Modifies: return ppl.Modifies.Nothing @@ -35,17 +80,38 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta :return: A dictionary mapping each state to its other reachable states. """ reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {} - for sdfg in top_sdfg.all_sdfgs_recursive(): + + # Convert any scope blocks to old-school state machines for now. + sdfg = copy.deepcopy(top_sdfg) + translation: Dict[int, Dict[str, SDFGState]] = {} + orig_states: Dict[int, Set[str]] = {} + for sd in top_sdfg.all_sdfgs_recursive(): + orig_states[sd.sdfg_id] = set() + translation[sd.sdfg_id] = {} + for scope in sd.all_state_scopes_recursive(recurse_into_sdfgs=False): + for state in scope.nodes(): + if isinstance(state, SDFGState): + orig_states[sd.sdfg_id].add(state.name) + translation[sd.sdfg_id][state.name] = state + + sdutils.inline_loop_blocks(sdfg) + + for sd in sdfg.all_sdfgs_recursive(): result: Dict[SDFGState, Set[SDFGState]] = {} + inlined_states = set([s.name for s in sd.all_states_recursive()]) + difference = inlined_states - orig_states[sd.sdfg_id] + # In networkx this is currently implemented naively for directed graphs. # The implementation below is faster # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) + for n, v in reachable_nodes(sd.nx): + if n.name not in difference: + result[translation[sd.sdfg_id][n.name]] = set( + [translation[sd.sdfg_id][r.name] for r in v.keys() if r.name not in difference] + ) - for n, v in reachable_nodes(sdfg.nx): - result[n] = set(v) - - reachable[sdfg.sdfg_id] = result + reachable[sd.sdfg_id] = result return reachable @@ -62,8 +128,6 @@ def _single_shortest_path_length_no_self(adj, source): Adjacency dict or view firstlevel : dict starting nodes, e.g. {source: 1} or {target: 1} - cutoff : int or float - level at which we stop the process """ firstlevel = {source: 1} @@ -101,7 +165,7 @@ def reachable_nodes(G): @properties.make_properties class SymbolAccessSets(ppl.Pass): """ - Evaluates symbol access sets (which symbols are read/written in each state or interstate edge). + Evaluates symbol access sets (which symbols are read/written in each control flow block or interstate edge). """ CATEGORY: str = 'Analysis' @@ -113,31 +177,33 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes - def apply_pass(self, top_sdfg: SDFG, - _) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: + def apply_pass(self, + top_sdfg: SDFG, + _) -> Dict[int, Dict[Union[ControlFlowBlock, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: """ - :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. + :return: A mapping of control flow blocks and interstate edges to a tuple of used (read, written) symbols. """ - top_result: Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]] = {} + top_result: Dict[int, Dict[Union[ControlFlowBlock, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): adesc = set(sdfg.arrays.keys()) - result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} - for state in sdfg.nodes(): - readset = state.free_symbols - # No symbols may be written to inside states. - result[state] = (readset, set()) - for oedge in sdfg.out_edges(state): - edge_readset = oedge.data.read_symbols() - adesc - edge_writeset = set(oedge.data.assignments.keys()) - result[oedge] = (edge_readset, edge_writeset) - top_result[sdfg.sdfg_id] = result + result: Dict[Union[ControlFlowBlock, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]] = {} + for cfg in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False): + for block in cfg.nodes(): + readset = block.free_symbols + # No symbols may be written to inside states. + result[block] = (readset, set()) + for oedge in sdfg.out_edges(block): + edge_readset = oedge.data.read_symbols() - adesc + edge_writeset = set(oedge.data.assignments.keys()) + result[oedge] = (edge_readset, edge_writeset) + top_result[sdfg.sdfg_id] = result return top_result @properties.make_properties class AccessSets(ppl.Pass): """ - Evaluates memory access sets (which arrays/data descriptors are read/written in each state). + Evaluates memory access sets (which arrays/data descriptors are read/written in each control flow block). """ CATEGORY: str = 'Analysis' @@ -149,26 +215,32 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.AccessNodes - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]]: + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]]: """ - :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. + :return: A dictionary mapping each control flow block to a tuple of its (read, written) data descriptors. """ - top_result: Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]] = {} + top_result: Dict[int, Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} - for state in sdfg.nodes(): + result: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = {} + for state in sdfg.states(): readset, writeset = set(), set() for anode in state.data_nodes(): if state.in_degree(anode) > 0: writeset.add(anode.data) if state.out_degree(anode) > 0: readset.add(anode.data) - result[state] = (readset, writeset) + for scope in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False): + readset, writeset = set(), set() + for substate in scope.all_states_recursive(): + readset.update(result[substate][0]) + writeset.update(result[substate][1]) + result[scope] = (readset, writeset) + # Edges that read from arrays add to both ends' access sets anames = sdfg.arrays.keys() - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges_recursive(recurse_into_sdfgs=False): fsyms = e.data.free_symbols & anames if fsyms: result[e.src][0].update(fsyms) @@ -201,13 +273,13 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[str, Set[SDFGState]] = defaultdict(set) - for state in sdfg.nodes(): + for state in sdfg.states(): for anode in state.data_nodes(): result[anode.data].add(state) # Edges that read from arrays add to both ends' access sets anames = sdfg.arrays.keys() - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges_recursive(recurse_into_sdfgs=False): fsyms = e.data.free_symbols & anames for access in fsyms: result[access].update({e.src, e.dst}) @@ -242,7 +314,7 @@ def apply_pass(self, top_sdfg: SDFG, for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = defaultdict( lambda: defaultdict(lambda: [set(), set()])) - for state in sdfg.nodes(): + for state in sdfg.states(): for anode in state.data_nodes(): if state.in_degree(anode) > 0: result[anode.data][state][1].add(anode) @@ -253,10 +325,10 @@ def apply_pass(self, top_sdfg: SDFG, @properties.make_properties -class SymbolWriteScopes(ppl.Pass): +class SymbolWriteScopes(ppl.Pass): # TODO: adapt """ For each symbol, create a dictionary mapping each writing interstate edge to that symbol to the set of interstate - edges and states reading that symbol that are dominated by that write. + edges and control flow blocks reading that symbol that are dominated by that write. """ CATEGORY: str = 'Analysis' @@ -275,7 +347,7 @@ def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[Interstat state_idom: Dict[SDFGState, SDFGState]) -> Optional[Edge[InterstateEdge]]: last_state: SDFGState = read if isinstance(read, SDFGState) else read.src - in_edges = last_state.parent.in_edges(last_state) + in_edges = last_state.sdfg.in_edges(last_state) deg = len(in_edges) if deg == 0: return None @@ -285,7 +357,7 @@ def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[Interstat write_isedge = None n_state = state_idom[last_state] if state_idom[last_state] != last_state else None while n_state is not None and write_isedge is None: - oedges = n_state.parent.out_edges(n_state) + oedges = n_state.sdfg.out_edges(n_state) odeg = len(oedges) if odeg == 1: if any([sym == k for k in oedges[0].data.assignments.keys()]): @@ -293,7 +365,7 @@ def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[Interstat else: dom_edge = None for cand in oedges: - if nxsp.has_path(n_state.parent.nx, cand.dst, last_state): + if nxsp.has_path(n_state.sdfg.nx, cand.dst, last_state): if dom_edge is not None: dom_edge = None break @@ -311,9 +383,9 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state) all_doms = cfg.all_dominators(sdfg, idom) - symbol_access_sets: Dict[Union[SDFGState, Edge[InterstateEdge]], - Tuple[Set[str], - Set[str]]] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id] + symbol_access_sets: Dict[ + Union[ControlFlowBlock, Edge[InterstateEdge]], Tuple[Set[str], Set[str]] + ] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id] state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.sdfg_id] for read_loc, (reads, _) in symbol_access_sets.items(): @@ -331,7 +403,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, dominators = all_doms[write.dst] reach = state_reach[write.dst] for dom in dominators: - iedges = dom.parent.in_edges(dom) + iedges = dom.sdfg.in_edges(dom) if len(iedges) == 1 and iedges[0] in result[sym]: other_accesses = result[sym][iedges[0]] coarsen = False @@ -357,7 +429,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, @properties.make_properties -class ScalarWriteShadowScopes(ppl.Pass): +class ScalarWriteShadowScopes(ppl.Pass): # TODO: Adapt """ For each scalar or array of size 1, create a dictionary mapping writes to that data container to the set of reads and writes that are dominated by that write. diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 9cec6d11af..60c6a3f81d 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -8,6 +8,7 @@ from dace.transformation import pass_pipeline as ppl from dace.cli.progress import optional_progressbar from dace import SDFG, SDFGState, dtypes, symbolic, properties +from dace.sdfg.state import ScopeBlock, ControlFlowBlock, LoopScopeBlock from typing import Any, Dict, Set, Optional, Tuple @@ -36,11 +37,11 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified != ppl.Modifies.Nothing - def should_apply(self, sdfg: SDFG) -> bool: + def should_apply(self, scope_block: ScopeBlock) -> bool: """ - Fast check (O(m)) whether the pass should early-exit without traversing the SDFG. + Fast check (O(m)) whether the pass should early-exit without traversing the scope. """ - for edge in sdfg.edges(): + for edge in scope_block.edges(): # If there are no assignments, there are no constants to propagate if len(edge.data.assignments) == 0: continue @@ -50,39 +51,36 @@ def should_apply(self, sdfg: SDFG) -> bool: return False - def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = None) -> Optional[Set[str]]: - """ - Propagates constants throughout the SDFG. - - :param sdfg: The SDFG to modify. - :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass - results as ``{Pass subclass name: returned object from pass}``. If not run in a - pipeline, an empty dictionary is expected. - :param initial_symbols: If not None, sets values of initial symbols. - :return: A set of propagated constants, or None if nothing was changed. - """ + def apply_to_scope(self, + sdfg: SDFG, + scope_block: ScopeBlock, + initial_symbols: Optional[Dict[str, Any]] = None, + do_not_remove: Set[str] = None) -> Tuple[Set[str], Set[str], Dict[str, Any]]: initial_symbols = initial_symbols or {} + do_not_remove = do_not_remove or set() + + # Keep track of replaced and ambiguous symbols + symbols_replaced: Dict[str, Any] = {} + remaining_unknowns: Set[str] = set() # Early exit if no constants can be propagated - if not initial_symbols and not self.should_apply(sdfg): + if not initial_symbols and not self.should_apply(scope_block): result = {} else: # Trace all constants and symbols through states - per_state_constants: Dict[SDFGState, Dict[str, Any]] = self.collect_constants(sdfg, initial_symbols) - - # Keep track of replaced and ambiguous symbols - symbols_replaced: Dict[str, Any] = {} - remaining_unknowns: Set[str] = set() + per_block_constants: Dict[ControlFlowBlock, Dict[str, Any]] = self.collect_constants(scope_block, + initial_symbols) # Collect symbols from symbol-dependent data descriptors # If there can be multiple values over the SDFG, the symbols are not propagated - desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, per_state_constants) + desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, per_block_constants) # Replace constants per state - for state, mapping in optional_progressbar(per_state_constants.items(), + for block, mapping in optional_progressbar(per_block_constants.items(), 'Propagating constants', - n=len(per_state_constants), + n=len(per_block_constants), progress=self.progress): + block: ControlFlowBlock = block remaining_unknowns.update( {k for k, v in mapping.items() if v is _UnknownValue or k in multivalue_desc_symbols}) @@ -97,53 +95,77 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = symbols_replaced.update(mapping) # Replace in state contents - state.replace_dict(mapping) + block.replace_dict(mapping) # Replace in outgoing edges as well - for e in sdfg.out_edges(state): + for e in scope_block.out_edges(block): e.data.replace_dict(mapping, replace_keys=False) # Gather initial propagated symbols result = {k: v for k, v in symbols_replaced.items() if k not in remaining_unknowns} # Remove single-valued symbols from data descriptors (e.g., symbolic array size) - sdfg.replace_dict({k: v - for k, v in result.items() if k in desc_symbols}, - replace_in_graph=False, - replace_keys=False) + scope_block.replace_dict({k: v + for k, v in result.items() if k in desc_symbols}, + replace_in_graph=False, + replace_keys=False) # Remove constant symbol assignments in interstate edges - for edge in sdfg.edges(): + for edge in scope_block.edges(): intersection = result & edge.data.assignments.keys() for sym in intersection: - del edge.data.assignments[sym] + if sym not in do_not_remove: + del edge.data.assignments[sym] # If symbols are never unknown any longer, remove from SDFG fsyms = sdfg.used_symbols(all_symbols=False) result = {k: v for k, v in result.items() if k not in fsyms} for sym in result: - if sym in sdfg.symbols: + if sym in sdfg.symbols and not sym in do_not_remove: # Remove from symbol repository and nested SDFG symbol mapipng sdfg.remove_symbol(sym) result = set(result.keys()) - if self.recursive: - # Change result to set of tuples - sid = sdfg.sdfg_id - result = set((sid, sym) for sym in result) - - for state in sdfg.nodes(): - for node in state.nodes(): - if isinstance(node, nodes.NestedSDFG): - nested_id = node.sdfg.sdfg_id - const_syms = {k: v for k, v in node.symbol_mapping.items() if not symbolic.issymbolic(v)} - internal = self.apply_pass(node.sdfg, _, const_syms) - if internal: - for nid, removed in internal: - result.add((nid, removed)) - # Remove symbol mapping if constant was completely propagated - if nid == nested_id and removed in node.symbol_mapping: - del node.symbol_mapping[removed] + return result, remaining_unknowns, symbols_replaced + + def apply_pass(self, + sdfg: SDFG, + pipeline_results: Dict[str, Any], + initial_symbols: Optional[Dict[str, Any]] = None) -> Optional[Set[str]]: + """ + Propagates constants throughout the SDFG. + + :param sdfg: The SDFG to modify. + :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass + results as ``{Pass subclass name: returned object from pass}``. If not run in a + pipeline, an empty dictionary is expected. + :param initial_symbols: If not None, sets values of initial symbols. + :return: A set of propagated constants, or None if nothing was changed. + """ + + result = set() + + for scope in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False): + scope_res, scope_unknown, scope_replaced = self.apply_to_scope(sdfg, scope, initial_symbols) + pass + + #if self.recursive: + # # Change result to set of tuples + # sid = sdfg.sdfg_id + # result = set((sid, sym) for sym in result) + + # for state in scope_block.all_states_recursive(): + # for node in state.nodes(): + # if isinstance(node, nodes.NestedSDFG): + # nested_id = node.sdfg.sdfg_id + # const_syms = {k: v for k, v in node.symbol_mapping.items() if not symbolic.issymbolic(v)} + # internal = self.apply_pass(node.sdfg, pipeline_results, initial_symbols=const_syms) + # if internal: + # for nid, removed in internal: + # result.add((nid, removed)) + # # Remove symbol mapping if constant was completely propagated + # if nid == nested_id and removed in node.symbol_mapping: + # del node.symbol_mapping[removed] # Return result if not result: @@ -154,15 +176,16 @@ def report(self, pass_retval: Set[str]) -> str: return f'Propagated {len(pass_retval)} constants.' def collect_constants(self, - sdfg: SDFG, - initial_symbols: Optional[Dict[str, Any]] = None) -> Dict[SDFGState, Dict[str, Any]]: + scope: ScopeBlock, + initial_symbols: Optional[Dict[str, Any]] = None) -> Dict[ControlFlowBlock, Dict[str, Any]]: """ - Finds all constants and constant-assigned symbols in the SDFG for each state. + Finds all constants and constant-assigned symbols in the scope for each block. - :param sdfg: The SDFG to traverse. + :param scope: The scope to traverse. :param initial_symbols: If not None, sets values of initial symbols. - :return: A dictionary mapping an SDFG state to a mapping of constants and their corresponding values. + :return: A dictionary mapping an control flow blocks to a mapping of constants and their corresponding values. """ + sdfg = scope if isinstance(scope, SDFG) else scope.sdfg arrays: Set[str] = set(sdfg.arrays.keys() | sdfg.constants_prop.keys()) result: Dict[SDFGState, Dict[str, Any]] = {} @@ -172,62 +195,64 @@ def collect_constants(self, # * If unvisited state has more than one incoming edge, consider all paths (use reverse DFS on unvisited paths) # * If value is ambiguous (not the same), set value to UNKNOWN - start_state = sdfg.start_state + start_block = scope.start_block if initial_symbols: - result[start_state] = {} - result[start_state].update(initial_symbols) + result[start_block] = {} + result[start_block].update(initial_symbols) # Traverse SDFG topologically - for state in optional_progressbar(sdfg.topological_sort(start_state), 'Collecting constants', - sdfg.number_of_nodes(), self.progress): + for block in optional_progressbar(scope.topological_sort(start_block), 'Collecting constants', + scope.number_of_nodes(), self.progress): # NOTE: We must always check the start-state regardless if there are initial symbols. This is necessary # when the start-state is a scope's guard instead of a special initialization state, i.e., when the start- # state has incoming edges that may involve the initial symbols. See also: # `tests.passes.constant_propagation_test.test_for_with_external_init_nested_start_with_guard`` - if state in result and state is not start_state: + if block in result and block is not start_block: continue # Get predecessors - in_edges = sdfg.in_edges(state) + in_edges = scope.in_edges(block) + assignments = {} if len(in_edges) == 1: # Special case, propagate as-is - if state not in result: # Condition evaluates to False when state is the start-state - result[state] = {} - + if block not in result: # Condition evaluates to False when state is the start-state + result[block] = {} + # First the prior state if in_edges[0].src in result: # Condition evaluates to False when state is the start-state - self._propagate(result[state], result[in_edges[0].src]) + self._propagate(result[block], result[in_edges[0].src]) # Then assignments on the incoming edge - self._propagate(result[state], self._data_independent_assignments(in_edges[0].data, arrays)) - continue - - # More than one incoming edge: may require reversed traversal - assignments = {} - for edge in in_edges: - # If source was already visited, use its propagated constants - constants: Dict[str, Any] = {} - if edge.src in result: - constants.update(result[edge.src]) - else: # Otherwise, reverse DFS to find constants until a visited state - constants = self._constants_from_unvisited_state(sdfg, edge.src, arrays, result) - - # Update constants with incoming edge - self._propagate(constants, self._data_independent_assignments(edge.data, arrays)) - - for aname, aval in constants.items(): - # If something was assigned more than once (to a different value), it's not a constant - if aname in assignments and aval != assignments[aname]: - assignments[aname] = _UnknownValue - else: - assignments[aname] = aval - - if state not in result: # Condition may evaluate to False when state is the start-state - result[state] = {} - self._propagate(result[state], assignments) + self._propagate(result[block], self._data_independent_assignments(in_edges[0].data, arrays)) + else: + # More than one incoming edge: may require reversed traversal + for edge in in_edges: + # If source was already visited, use its propagated constants + constants: Dict[str, Any] = {} + if edge.src in result: + constants.update(result[edge.src]) + else: # Otherwise, reverse DFS to find constants until a visited state + constants = self._constants_from_unvisited_state(scope, edge.src, arrays, result) + + # Update constants with incoming edge + self._propagate(constants, self._data_independent_assignments(edge.data, arrays)) + + for aname, aval in constants.items(): + # If something was assigned more than once (to a different value), it's not a constant + if aname in assignments and aval != assignments[aname]: + assignments[aname] = _UnknownValue + else: + assignments[aname] = aval + + if isinstance(block, LoopScopeBlock): # Add the loop variable as unknown assignment + assignments[block.loop_variable] = _UnknownValue + + if block not in result: # Condition may evaluate to False when state is the start-state + result[block] = {} + self._propagate(result[block], assignments) return result - def _find_desc_symbols(self, sdfg: SDFG, constants: Dict[SDFGState, Dict[str, Any]]) -> Tuple[Set[str], Set[str]]: + def _find_desc_symbols(self, sdfg: SDFG, constants: Dict[ControlFlowBlock, Dict[str, Any]]) -> Tuple[Set[str], Set[str]]: """ Finds constant symbols that data descriptors (e.g., arrays) depend on. @@ -309,7 +334,7 @@ def _data_independent_assignments(self, edge: InterstateEdge, arrays: Set[str]) for k, v in edge.assignments.items() } - def _constants_from_unvisited_state(self, sdfg: SDFG, state: SDFGState, arrays: Set[str], + def _constants_from_unvisited_state(self, scope: ScopeBlock, state: SDFGState, arrays: Set[str], existing_constants: Dict[SDFGState, Dict[str, Any]]) -> Dict[str, Any]: """ Collects constants from an unvisited state, traversing backwards until reaching states that do have @@ -317,7 +342,7 @@ def _constants_from_unvisited_state(self, sdfg: SDFG, state: SDFGState, arrays: """ result: Dict[str, Any] = {} - for parent, node in sdutil.dfs_conditional(sdfg, + for parent, node in sdutil.dfs_conditional(scope, sources=[state], reverse=True, condition=lambda p, c: c not in existing_constants, @@ -327,7 +352,7 @@ def _constants_from_unvisited_state(self, sdfg: SDFG, state: SDFGState, arrays: continue # Get connecting edge (reversed) - edge = sdfg.edges_between(node, parent)[0] + edge = scope.edges_between(node, parent)[0] # If node already has propagated constants, update dictionary and stop traversal self._propagate(result, self._data_independent_assignments(edge.data, arrays), True) diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index aeaf1cdbd1..6f677def12 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -41,7 +41,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & (ppl.Modifies.Nodes | ppl.Modifies.Edges | ppl.Modifies.States) def depends_on(self) -> Set[Type[ppl.Pass]]: - return {ap.StateReachability, ap.AccessSets} + return {ap.LegacyStateReachability, ap.AccessSets} def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[SDFGState, Set[str]]]: """ @@ -56,7 +56,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D # Depends on the following analysis passes: # * State reachability # * Read/write access sets per state - reachable: Dict[SDFGState, Set[SDFGState]] = pipeline_results['StateReachability'][sdfg.sdfg_id] + reachable: Dict[SDFGState, Set[SDFGState]] = pipeline_results[ap.LegacyStateReachability.__name__][sdfg.sdfg_id] access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]] = pipeline_results['AccessSets'][sdfg.sdfg_id] result: Dict[SDFGState, Set[str]] = defaultdict(set) @@ -70,6 +70,9 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D # Analysis ############################################# + if not isinstance(state, SDFGState): + continue + # Compute states where memory will no longer be read writes = access_sets[state][1] descendants = reachable[state] @@ -80,9 +83,10 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D dead_nodes: List[nodes.Node] = [] # Propagate deadness backwards within a state - for node in sdutil.dfs_topological_sort(state, reverse=True): - if self._is_node_dead(node, sdfg, state, dead_nodes, no_longer_used, access_sets[state]): - dead_nodes.append(node) + if isinstance(state, SDFGState): + for node in sdutil.dfs_topological_sort(state, reverse=True): + if self._is_node_dead(node, sdfg, state, dead_nodes, no_longer_used, access_sets[state]): + dead_nodes.append(node) # Scope exit nodes are only dead if their corresponding entry nodes are live_nodes = set() diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index a5ff0ba71a..53755e01ef 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -5,6 +5,7 @@ from typing import Optional, Set, Tuple, Union from dace import SDFG, InterstateEdge, SDFGState, symbolic, properties +from dace.sdfg.state import ScopeBlock from dace.properties import CodeBlock from dace.sdfg.graph import Edge from dace.sdfg.validation import InvalidSDFGInterstateEdgeError @@ -12,7 +13,7 @@ @properties.make_properties -class DeadStateElimination(ppl.Pass): +class DeadStateElimination(ppl.ControlFlowScopePass): """ Removes all unreachable states (e.g., due to a branch that will never be taken) from an SDFG. """ @@ -26,23 +27,23 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If connectivity or any edges were changed, some more states might be dead return modified & (ppl.Modifies.InterstateEdges | ppl.Modifies.States) - def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[InterstateEdge]]]]: + def apply(self, scope_block: ScopeBlock, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[InterstateEdge]]]]: """ Removes unreachable states throughout an SDFG. + :param scope_block: The scope block to modify. :param sdfg: The SDFG to modify. :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass results as ``{Pass subclass name: returned object from pass}``. If not run in a pipeline, an empty dictionary is expected. - :param initial_symbols: If not None, sets values of initial symbols. :return: A set of the removed states, or None if nothing was changed. """ # Mark dead states and remove them - dead_states, dead_edges, annotated = self.find_dead_states(sdfg, set_unconditional_edges=True) + dead_states, dead_edges, annotated = self.find_dead_states(scope_block, sdfg, set_unconditional_edges=True) for e in dead_edges: - sdfg.remove_edge(e) - sdfg.remove_nodes_from(dead_states) + scope_block.remove_edge(e) + scope_block.remove_nodes_from(dead_states) result = dead_states | dead_edges @@ -53,6 +54,7 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[Inters def find_dead_states( self, + scope_block: ScopeBlock, sdfg: SDFG, set_unconditional_edges: bool = True) -> Tuple[Set[SDFGState], Set[Edge[InterstateEdge]], bool]: """ @@ -72,7 +74,7 @@ def find_dead_states( # Run a modified BFS where definitely False edges are not traversed, or if there is an # unconditional edge the rest are not. The inverse of the visited states is the dead set. - queue = collections.deque([sdfg.start_state]) + queue = collections.deque([scope_block.start_block]) while len(queue) > 0: node = queue.popleft() if node in visited: @@ -81,13 +83,13 @@ def find_dead_states( # First, check for unconditional edges unconditional = None - for e in sdfg.out_edges(node): + for e in scope_block.out_edges(node): # If an unconditional edge is found, ignore all other outgoing edges if self.is_definitely_taken(e.data, sdfg): # If more than one unconditional outgoing edge exist, fail with Invalid SDFG if unconditional is not None: raise InvalidSDFGInterstateEdgeError('Multiple unconditional edges leave the same state', sdfg, - sdfg.edge_id(e)) + scope_block.edge_id(e)) unconditional = e if set_unconditional_edges and not e.data.is_unconditional(): # Annotate edge as unconditional @@ -100,7 +102,7 @@ def find_dead_states( continue if unconditional is not None: # Unconditional edge exists, skip traversal # Remove other (now never taken) edges from graph - for e in sdfg.out_edges(node): + for e in scope_block.out_edges(node): if e is not unconditional: dead_edges.add(e) @@ -108,7 +110,7 @@ def find_dead_states( # End of unconditional check # Check outgoing edges normally - for e in sdfg.out_edges(node): + for e in scope_block.out_edges(node): next_node = e.dst # Test for edges that definitely evaluate to False @@ -121,7 +123,7 @@ def find_dead_states( queue.append(next_node) # Dead states are states that are not live (i.e., visited) - return set(sdfg.nodes()) - visited, dead_edges, edges_annotated + return set(scope_block.nodes()) - visited, dead_edges, edges_annotated def report(self, pass_retval: Set[Union[SDFGState, Edge[InterstateEdge]]]) -> str: if pass_retval is not None and not pass_retval: diff --git a/dace/transformation/passes/fusion_inline.py b/dace/transformation/passes/fusion_inline.py index 93764670e8..9e900f9443 100644 --- a/dace/transformation/passes/fusion_inline.py +++ b/dace/transformation/passes/fusion_inline.py @@ -8,7 +8,7 @@ from dace import SDFG, properties from dace.sdfg import nodes -from dace.sdfg.utils import fuse_states, inline_sdfgs +from dace.sdfg.utils import fuse_states, inline_sdfgs, inline_loop_blocks from dace.transformation import pass_pipeline as ppl @@ -85,6 +85,35 @@ def report(self, pass_retval: int) -> str: return f'Inlined {pass_retval} SDFGs.' +@dataclass(unsafe_hash=True) +@properties.make_properties +class InlineScopes(ppl.Pass): + """ + Inlines all possible sub-scopes of an SDFG to create a state machine. + """ + + CATEGORY: str = 'Cleanup' + + permissive = properties.Property(dtype=bool, default=False, desc='If True, ignores some checks on inlining.') + progress = properties.Property(dtype=bool, + default=None, + allow_none=True, + desc='Whether to print progress, or None for default (print after 5 seconds).') + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & (ppl.Modifies.States | ppl.Modifies.InterstateEdges) + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.States | ppl.Modifies.InterstateEdges + + def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]: + inlined = inline_loop_blocks(sdfg, self.permissive, self.progress) + return inlined or None + + def report(self, pass_retval: int) -> str: + return f'Inlined {pass_retval} scopes.' + + @dataclass(unsafe_hash=True) @properties.make_properties class FixNestedSDFGReferences(ppl.Pass): @@ -102,19 +131,19 @@ def modifies(self) -> ppl.Modifies: def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]: modified = 0 - for node, state in sdfg.all_nodes_recursive(): + for node, parent in sdfg.all_nodes_recursive(): if not isinstance(node, nodes.NestedSDFG): continue was_modified = False if node.sdfg.parent_nsdfg_node is not node: was_modified = True node.sdfg.parent_nsdfg_node = node - if node.sdfg.parent is not state: + if node.sdfg.parent is not parent: was_modified = True - node.sdfg.parent = state - if node.sdfg.parent_sdfg is not state.parent: + node.sdfg.parent = parent + if node.sdfg.parent_sdfg is not parent.sdfg: was_modified = True - node.sdfg.parent_sdfg = state.parent + node.sdfg.parent_sdfg = parent.sdfg if was_modified: modified += 1 diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index cf55f7a9b2..4e83c25d65 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -4,8 +4,9 @@ from dataclasses import dataclass from typing import Optional, Set, Tuple -from dace import SDFG, dtypes, properties, symbolic +from dace import SDFG, dtypes, properties, symbolic, SDFGState from dace.sdfg import nodes +from dace.sdfg.state import LoopScopeBlock from dace.transformation import pass_pipeline as ppl @@ -57,7 +58,7 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Tuple[int, str]]]: sid = sdfg.sdfg_id result = set((sid, sym) for sym in result) - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.NestedSDFG): old_symbols = self.symbols @@ -83,25 +84,29 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: for desc in sdfg.arrays.values(): result |= set(map(str, desc.free_symbols)) - for state in sdfg.nodes(): - result |= state.free_symbols - # In addition to the standard free symbols, we are conservative with other tasklet languages by - # tokenizing their code. Since this is intersected with `sdfg.symbols`, keywords such as "if" are - # ok to include - for node in state.nodes(): - if isinstance(node, nodes.Tasklet): - if node.code.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code.as_string, sdfg.symbols.keys(), - node.ignored_symbols) - if node.code_global.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code_global.as_string, sdfg.symbols.keys(), - node.ignored_symbols) - if node.code_init.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code_init.as_string, sdfg.symbols.keys(), - node.ignored_symbols) - if node.code_exit.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), - node.ignored_symbols) + for block in sdfg.all_control_flow_blocks_recursive(recurse_into_sdfgs=False): + result |= block.free_symbols + + if isinstance(block, SDFGState): + # In addition to the standard free symbols, we are conservative with other tasklet languages by + # tokenizing their code. Since this is intersected with `sdfg.symbols`, keywords such as "if" are + # ok to include + for node in block.nodes(): + if isinstance(node, nodes.Tasklet): + if node.code.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + if node.code_global.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code_global.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + if node.code_init.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code_init.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + if node.code_exit.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + elif isinstance(block, LoopScopeBlock): + result.add(block.loop_variable) for e in sdfg.edges(): result |= e.data.free_symbols diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 124efdaae1..05dace8166 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -95,7 +95,7 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer # Check all occurrences of candidates in SDFG and filter out candidates_seen: Set[str] = set() - for state in sdfg.nodes(): + for state in sdfg.states(): candidates_in_state: Set[str] = set() for node in state.nodes(): @@ -225,7 +225,7 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer # Filter out non-integral symbols that do not appear in inter-state edges interstate_symbols = set() - for edge in sdfg.edges(): + for edge in sdfg.all_interstate_edges_recursive(recurse_into_sdfgs=False): interstate_symbols |= edge.data.free_symbols for candidate in (candidates - interstate_symbols): if integers_only and sdfg.arrays[candidate].dtype not in dtypes.INTEGER_TYPES: @@ -508,7 +508,7 @@ def remove_scalar_reads(sdfg: sd.SDFG, array_names: Dict[str, str]): replacement symbol name. :note: Operates in-place on the SDFG. """ - for state in sdfg.nodes(): + for state in sdfg.states(): scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in array_names] for node in scalar_nodes: symname = array_names[node.data] @@ -633,8 +633,8 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: if len(to_promote) == 0: return None - for state in sdfg.nodes(): - scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in to_promote] + for state in sdfg.states(): + scalar_nodes = [n for n in state.data_nodes() if n.data in to_promote] # Step 2: Assignment tasklets for node in scalar_nodes: if state.in_degree(node) == 0: @@ -645,8 +645,8 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: # There is only zero or one incoming edges by definition tasklet_inputs = [e.src for e in state.in_edges(input)] # Step 2.1 - new_state = xfh.state_fission(sdfg, gr.SubgraphView(state, set([input, node] + tasklet_inputs))) - new_isedge: sd.InterstateEdge = sdfg.out_edges(new_state)[0] + new_state = xfh.state_fission(gr.SubgraphView(state, set([input, node] + tasklet_inputs))) + new_isedge: sd.InterstateEdge = state.parent.out_edges(new_state)[0] # Step 2.2 node: nodes.AccessNode = new_state.sink_nodes()[0] input = new_state.in_edges(node)[0].src @@ -683,7 +683,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: remove_scalar_reads(sdfg, {k: k for k in to_promote}) # Step 4: Isolated nodes - for state in sdfg.nodes(): + for state in sdfg.states(): scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in to_promote] state.remove_nodes_from([n for n in scalar_nodes if len(state.all_edges(n)) == 0]) @@ -699,7 +699,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: # Step 6: Inter-state edge cleanup cleanup_re = {s: re.compile(fr'\b{re.escape(s)}\[.*?\]') for s in to_promote} promo = TaskletPromoterDict({k: k for k in to_promote}) - for edge in sdfg.edges(): + for edge in sdfg.all_interstate_edges_recursive(recurse_into_sdfgs=False): ise: InterstateEdge = edge.data # Condition if not edge.data.is_unconditional(): diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index 0a2539457a..a8a94ba6c7 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -15,6 +15,7 @@ from dace.transformation.passes.scalar_to_symbol import ScalarToSymbolPromotion from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols +# TODO: Re-enable everything SIMPLIFY_PASSES = [ InlineSDFGs, ScalarToSymbolPromotion, diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index 75e591cb1e..bd07b76c90 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -22,7 +22,7 @@ import copy from dace import dtypes, serialize from dace.dtypes import ScheduleType -from dace.sdfg import SDFG, SDFGState +from dace.sdfg import SDFG, SDFGState, ScopeBlock from dace.sdfg import nodes as nd, graph as gr, utils as sdutil, propagation, infer_types, state as st from dace.properties import make_properties, Property, DictProperty, SetProperty from dace.transformation import pass_pipeline as ppl @@ -108,7 +108,7 @@ def expressions(cls) -> List[gr.SubgraphView]: raise NotImplementedError def can_be_applied(self, - graph: Union[SDFG, SDFGState], + graph: Union[ScopeBlock, SDFGState], expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: @@ -126,10 +126,11 @@ def can_be_applied(self, """ raise NotImplementedError - def apply(self, graph: Union[SDFG, SDFGState], sdfg: SDFG) -> Union[Any, None]: + def apply(self, graph: Union[ScopeBlock, SDFGState], sdfg: SDFG) -> Union[Any, None]: """ Applies this transformation instance on the matched pattern graph. + :param graph: The graph object on which the transformation operates. :param sdfg: The SDFG to apply the transformation to. :return: A transformation-defined return value, which could be used to pass analysis data out, or nothing. @@ -499,7 +500,7 @@ def expressions(cls) -> List[gr.SubgraphView]: pass @abc.abstractmethod - def can_be_applied(self, graph: SDFG, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + def can_be_applied(self, graph: ScopeBlock, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: """ Returns True if this transformation can be applied on the candidate matched subgraph. :param graph: SDFG object in which the match was found. @@ -707,7 +708,7 @@ def setup_match(self, subgraph: Union[Set[int], gr.SubgraphView], sdfg_id: int = self.subgraph = set(subgraph.graph.node_id(n) for n in subgraph.nodes()) if isinstance(subgraph.graph, SDFGState): - sdfg = subgraph.graph.parent + sdfg = subgraph.graph.sdfg self.sdfg_id = sdfg.sdfg_id self.state_id = sdfg.node_id(subgraph.graph) elif isinstance(subgraph.graph, SDFG): diff --git a/dace/viewer/webclient b/dace/viewer/webclient index dd34948875..cfc0443ecf 160000 --- a/dace/viewer/webclient +++ b/dace/viewer/webclient @@ -1 +1 @@ -Subproject commit dd34948875d01f63749faee5dd0fd34a198aaaa6 +Subproject commit cfc0443ecf64e130d13ac86b6c7cf5ae138928e4 diff --git a/samples/optimization/matmul.py b/samples/optimization/matmul.py index 06b6a38939..9c8fbbc383 100644 --- a/samples/optimization/matmul.py +++ b/samples/optimization/matmul.py @@ -113,7 +113,7 @@ def optimize_for_cpu(sdfg: dace.SDFG, m: int, n: int, k: int): # Vectorize microkernel map postamble = n % 4 != 0 entry_inner, inner_state = find_map_and_state_by_param(sdfg, 'k') - Vectorization.apply_to(inner_state.parent, + Vectorization.apply_to(inner_state.sdfg, dict(vector_len=4, preamble=False, postamble=postamble), map_entry=entry_inner) diff --git a/tests/memlet_propagation_test.py b/tests/memlet_propagation_test.py index f90834cbb7..f1196348da 100644 --- a/tests/memlet_propagation_test.py +++ b/tests/memlet_propagation_test.py @@ -73,7 +73,7 @@ def sparse(A: dace.float32[M, N], ind: dace.int32[M, N]): propagate_memlets_sdfg(sdfg) # Verify all memlet subsets and volumes in the main state of the program, i.e. around the NSDFG. - map_state = sdfg.states()[1] + map_state = list(sdfg.states())[1] i = dace.symbol('i') j = dace.symbol('j') diff --git a/tests/passes/state_reach_test.py b/tests/passes/state_reach_test.py new file mode 100644 index 0000000000..0dc12f09ff --- /dev/null +++ b/tests/passes/state_reach_test.py @@ -0,0 +1,37 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import pytest +import dace +from dace.sdfg import SDFG +from dace.transformation import pass_pipeline as ppl +from dace.transformation import passes + + +def test_loop_scope_reach(): + sdfg = SDFG('loop_scope_reach_test') + s1 = sdfg.add_state('s1') + s6 = sdfg.add_state('s6') + (_, loop1, _) = sdfg.add_loop(s1, s6, 'i', 'i=0', 'i<10', 'i+1') + loop1.label = 'loop1' + s2 = loop1.add_state('s2') + s5 = loop1.add_state('s5') + (_, loop2, _) = loop1.add_loop(s2, s5, 'j', 'j=0', 'j<10', 'j+1') + loop2.label = 'loop2' + s3 = loop2.add_state('s3') + s4 = loop2.add_state('s4') + loop2.add_edge(s3, s4, dace.InterstateEdge()) + + res = {} + ppl.Pipeline([passes.analysis.LegacyStateReachability()]).apply_pass(sdfg, res) + + reach = res[passes.analysis.LegacyStateReachability.__name__][0] + assert reach[s1] == {s2, s3, s4, s5, s6} + assert reach[s2] == {s2, s3, s4, s5, s6} + assert reach[s3] == {s2, s3, s4, s5, s6} + assert reach[s4] == {s2, s3, s4, s5, s6} + assert reach[s5] == {s2, s3, s4, s5, s6} + assert len(reach[s6]) == 0 + + +if __name__ == '__main__': + test_loop_scope_reach() diff --git a/tests/sdfg_validate_names_test.py b/tests/sdfg_validate_names_test.py index dad79c8950..1650a4e4b1 100644 --- a/tests/sdfg_validate_names_test.py +++ b/tests/sdfg_validate_names_test.py @@ -28,7 +28,7 @@ def test_state_duplication(self): sdfg = dace.SDFG('ok') s1 = sdfg.add_state('also_ok') s2 = sdfg.add_state('also_ok') - s2.set_label('also_ok') + s2.label = 'also_ok' sdfg.add_edge(s1, s2, dace.InterstateEdge()) sdfg.validate() self.fail('Failed to detect duplicate state') diff --git a/tests/transformations/nest_subgraph_test.py b/tests/transformations/nest_subgraph_test.py index 763bb3327d..ec339b907e 100644 --- a/tests/transformations/nest_subgraph_test.py +++ b/tests/transformations/nest_subgraph_test.py @@ -30,7 +30,7 @@ def test_nest_oneelementmap(): for node, state in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.MapEntry): subgraph = state.scope_subgraph(node, include_entry=False, include_exit=False) - nest_state_subgraph(state.parent, state, subgraph) + nest_state_subgraph(state.sdfg, state, subgraph) sdfg(A=A, B=B) assert np.allclose(A, B) diff --git a/tests/transformations/state_fission_test.py b/tests/transformations/state_fission_test.py index 7c03fbed89..fb813d700c 100644 --- a/tests/transformations/state_fission_test.py +++ b/tests/transformations/state_fission_test.py @@ -120,17 +120,17 @@ def test_state_fission(): sdfg = make_nested_sdfg_cpu() # state fission - state = sdfg.states()[0] + state = list(sdfg.states())[0] node_x = state.nodes()[0] node_y = state.nodes()[1] node_z = state.nodes()[2] vec_add1 = state.nodes()[3] subg = dace.sdfg.graph.SubgraphView(state, [node_x, node_y, vec_add1, node_z]) - helpers.state_fission(sdfg, subg) + helpers.state_fission(subg) sdfg.validate() - assert (len(sdfg.states()) == 2) + assert (len(list(sdfg.states())) == 2) # run the program vec_add = sdfg.compile() diff --git a/tests/transformations/tasklet_fusion_test.py b/tests/transformations/tasklet_fusion_test.py index 743010e8c9..df96228cec 100644 --- a/tests/transformations/tasklet_fusion_test.py +++ b/tests/transformations/tasklet_fusion_test.py @@ -31,7 +31,7 @@ def _make_sdfg(language: str, with_data: bool = False): sdfg.add_array('A', (N, ), datatype) sdfg.add_array('B', (M, ), datatype) sdfg.add_array('C', (M, ), datatype) - state = sdfg.add_state(is_start_state=True) + state = sdfg.add_state(is_start_block=True) A = state.add_read('A') B = state.add_read('B') C = state.add_write('C') diff --git a/tests/transformations/trivial_loop_elimination_test.py b/tests/transformations/trivial_loop_elimination_test.py index 6f2769f921..20514a3331 100644 --- a/tests/transformations/trivial_loop_elimination_test.py +++ b/tests/transformations/trivial_loop_elimination_test.py @@ -3,7 +3,6 @@ import dace from dace.transformation.interstate import TrivialLoopElimination from dace.symbolic import pystr_to_symbolic -import unittest import numpy as np I = dace.symbol("I") @@ -17,21 +16,19 @@ def trivial_loop(data: dace.float64[I, J]): data[i, j] = data[i, j] + data[i - 1, j] -class TrivialLoopEliminationTest(unittest.TestCase): +def test_semantic_eq(): + A1 = np.random.rand(16, 16) + A2 = np.copy(A1) - def test_semantic_eq(self): - A1 = np.random.rand(16, 16) - A2 = np.copy(A1) + sdfg = trivial_loop.to_sdfg(simplify=False) + sdfg(A1, I=A1.shape[0], J=A1.shape[1]) - sdfg = trivial_loop.to_sdfg(simplify=False) - sdfg(A1, I=A1.shape[0], J=A1.shape[1]) + count = sdfg.apply_transformations(TrivialLoopElimination) + assert (count > 0) + sdfg(A2, I=A1.shape[0], J=A1.shape[1]) - count = sdfg.apply_transformations(TrivialLoopElimination) - self.assertGreater(count, 0) - sdfg(A2, I=A1.shape[0], J=A1.shape[1]) - - self.assertTrue(np.allclose(A1, A2)) + assert np.allclose(A1, A2) if __name__ == '__main__': - unittest.main() + test_semantic_eq() diff --git a/tests/wcr_cudatest.py b/tests/wcr_cudatest.py index 03a99fb8e1..018445fae6 100644 --- a/tests/wcr_cudatest.py +++ b/tests/wcr_cudatest.py @@ -8,7 +8,7 @@ def create_zero_initialization(init_state: dace.SDFGState, array_name): - sdfg = init_state.parent + sdfg = init_state.sdfg array_shape = sdfg.arrays[array_name].shape array_access_node = init_state.add_write(array_name)