diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py index c502a47376..b963da4812 100644 --- a/dace/codegen/codegen.py +++ b/dace/codegen/codegen.py @@ -100,13 +100,13 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator): frame.targets.add(disp.get_scope_dispatcher(node.schedule)) elif isinstance(node, dace.nodes.Node): state: SDFGState = parent - nsdfg = state.parent + nsdfg = state.sdfg frame.targets.add(disp.get_node_dispatcher(nsdfg, state, node)) # Array allocation if isinstance(node, dace.nodes.AccessNode): state: SDFGState = parent - nsdfg = state.parent + nsdfg = state.sdfg desc = node.desc(nsdfg) frame.targets.add(disp.get_array_dispatcher(desc.storage)) @@ -124,13 +124,13 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator): dst_node = leaf_e.dst if leaf_e.data.is_empty(): continue - tgt = disp.get_copy_dispatcher(node, dst_node, leaf_e, state.parent, state) + tgt = disp.get_copy_dispatcher(node, dst_node, leaf_e, state.sdfg, state) if tgt is not None: frame.targets.add(tgt) else: # Rooted at dst_node dst_node = mtree.root().edge.dst - tgt = disp.get_copy_dispatcher(node, dst_node, e, state.parent, state) + tgt = disp.get_copy_dispatcher(node, dst_node, e, state.sdfg, state) if tgt is not None: frame.targets.add(tgt) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index 182604c892..d4803481f7 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -121,7 +121,7 @@ class SingleState(ControlFlow): last_state: bool = False def as_cpp(self, codegen, symbols) -> str: - sdfg = self.state.parent + sdfg = self.state.sdfg expr = '__state_{}_{}:;\n'.format(sdfg.sdfg_id, self.state.label) if self.state.number_of_nodes() > 0: @@ -218,7 +218,7 @@ def as_cpp(self, codegen, symbols) -> str: # In a general block, emit transitions and assignments after each # individual state if isinstance(elem, SingleState): - sdfg = elem.state.parent + sdfg = elem.state.sdfg out_edges = sdfg.out_edges(elem.state) for j, e in enumerate(out_edges): if e not in self.gotos_to_ignore: @@ -361,7 +361,7 @@ def as_cpp(self, codegen, symbols) -> str: init = f'{symbols[self.itervar]} {self.itervar}' init += ' = ' + self.init - sdfg = self.guard.parent + sdfg = self.guard.sdfg preinit = '' if self.init_edges: @@ -403,7 +403,7 @@ class WhileScope(ControlFlow): def as_cpp(self, codegen, symbols) -> str: if self.test is not None: - sdfg = self.guard.parent + sdfg = self.guard.sdfg test = unparse_interstate_edge(self.test.code[0], sdfg, codegen=codegen) else: test = 'true' diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index ee49f04d03..adac3317d2 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -204,9 +204,9 @@ def preprocess(self, sdfg: SDFG) -> None: for state, node, defined_syms in sdutil.traverse_sdfg_with_defined_symbols(sdfg, recursive=True): if (isinstance(node, nodes.MapEntry) and node.map.schedule in (dtypes.ScheduleType.GPU_Device, dtypes.ScheduleType.GPU_Persistent)): - if state.parent not in shared_transients: - shared_transients[state.parent] = state.parent.shared_transients() - self._arglists[node] = state.scope_subgraph(node).arglist(defined_syms, shared_transients[state.parent]) + if state.sdfg not in shared_transients: + shared_transients[state.sdfg] = state.sdfg.shared_transients() + self._arglists[node] = state.scope_subgraph(node).arglist(defined_syms, shared_transients[state.sdfg]) def _compute_pool_release(self, top_sdfg: SDFG): """ @@ -831,7 +831,7 @@ def increment(streams): # Remove CUDA streams from paths of non-gpu copies and CPU tasklets for node, graph in sdfg.all_nodes_recursive(): if isinstance(graph, SDFGState): - cur_sdfg = graph.parent + cur_sdfg = graph.sdfg if (isinstance(node, (nodes.EntryNode, nodes.ExitNode)) and node.schedule in dtypes.GPU_SCHEDULES): # Node must have GPU stream, remove childpath and continue @@ -1421,7 +1421,7 @@ def generate_scope(self, sdfg, dfg_scope, state_id, function_stream, callsite_st visited = set() for node, parent in dfg_scope.all_nodes_recursive(): if isinstance(node, nodes.AccessNode): - nsdfg: SDFG = parent.parent + nsdfg: SDFG = parent.sdfg desc = node.desc(nsdfg) if (nsdfg, node.data) in visited: continue diff --git a/dace/codegen/targets/fpga.py b/dace/codegen/targets/fpga.py index 413cb751d6..e1753dc323 100644 --- a/dace/codegen/targets/fpga.py +++ b/dace/codegen/targets/fpga.py @@ -1332,7 +1332,7 @@ def partition_kernels(self, state: dace.SDFGState, default_kernel: int = 0): """ concurrent_kernels = 0 # Max number of kernels - sdfg = state.parent + sdfg = state.sdfg def increment(kernel_id): if concurrent_kernels > 0: diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 9ee5c2ef17..57767c66b0 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -491,7 +491,7 @@ def _get_schedule(self, scope: Union[nodes.EntryNode, SDFGState, SDFG]) -> dtype elif isinstance(scope, nodes.EntryNode): return scope.schedule elif isinstance(scope, (SDFGState, SDFG)): - sdfg: SDFG = (scope if isinstance(scope, SDFG) else scope.parent) + sdfg: SDFG = (scope if isinstance(scope, SDFG) else scope.sdfg) if sdfg.parent_nsdfg_node is None: return TOP_SCHEDULE @@ -721,7 +721,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): if curscope is None: curscope = curstate elif isinstance(curscope, (SDFGState, SDFG)): - cursdfg: SDFG = (curscope if isinstance(curscope, SDFG) else curscope.parent) + cursdfg: SDFG = (curscope if isinstance(curscope, SDFG) else curscope.sdfg) # Go one SDFG up if cursdfg.parent_nsdfg_node is None: curscope = None diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 801c742979..5cdd10ba98 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1335,7 +1335,7 @@ def _add_block(self, block: ControlFlowBlock): self.current_state = block def _add_state(self, label=None) -> SDFGState: - state = self.cfg_target.add_state(label, False, self.sdfg) + state = self.cfg_target.add_state(label, False) self._add_block(state) return state diff --git a/dace/memlet.py b/dace/memlet.py index 74a1320a3b..5d25b3ae18 100644 --- a/dace/memlet.py +++ b/dace/memlet.py @@ -400,11 +400,11 @@ def try_initialize(self, sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState', self.subset = subsets.Range.from_array(sdfg.arrays[self.data]) def get_src_subset(self, edge: 'dace.sdfg.graph.MultiConnectorEdge', state: 'dace.sdfg.SDFGState'): - self.try_initialize(state.parent, state, edge) + self.try_initialize(state.sdfg, state, edge) return self.src_subset def get_dst_subset(self, edge: 'dace.sdfg.graph.MultiConnectorEdge', state: 'dace.sdfg.SDFGState'): - self.try_initialize(state.parent, state, edge) + self.try_initialize(state.sdfg, state, edge) return self.dst_subset @staticmethod diff --git a/dace/sdfg/analysis/cutout.py b/dace/sdfg/analysis/cutout.py index a72a6d7e54..781073a099 100644 --- a/dace/sdfg/analysis/cutout.py +++ b/dace/sdfg/analysis/cutout.py @@ -193,7 +193,7 @@ def singlestate_cutout(cls, if reduce_input_config: nodes = _reduce_in_configuration(state, nodes, use_alibi_nodes, symbols_map) create_element = copy.deepcopy if make_copy else (lambda x: x) - sdfg = state.parent + sdfg = state.sdfg subgraph: StateSubgraphView = StateSubgraphView(state, nodes) subgraph = _extend_subgraph_with_access_nodes(state, subgraph, use_alibi_nodes) @@ -341,8 +341,8 @@ def multistate_cutout(cls, create_element = copy.deepcopy # Check that all states are inside the same SDFG. - sdfg = list(states)[0].parent - if any(i.parent != sdfg for i in states): + sdfg = list(states)[0].sdfg + if any(i.sdfg != sdfg for i in states): raise Exception('Not all cutout states reside in the same SDFG') cutout_states: Set[SDFGState] = set(states) @@ -423,13 +423,13 @@ def multistate_cutout(cls, in_translation[is_edge.src] = new_el out_translation[new_el] = is_edge.src cutout.add_node(new_el, is_start_state=(is_edge.src == start_state)) - new_el.parent = cutout + new_el.sdfg = cutout if is_edge.dst not in in_translation: new_el: SDFGState = create_element(is_edge.dst) in_translation[is_edge.dst] = new_el out_translation[new_el] = is_edge.dst cutout.add_node(new_el, is_start_state=(is_edge.dst == start_state)) - new_el.parent = cutout + new_el.sdfg = cutout new_isedge: InterstateEdge = create_element(is_edge.data) in_translation[is_edge.data] = new_isedge out_translation[new_isedge] = is_edge.data @@ -442,7 +442,7 @@ def multistate_cutout(cls, in_translation[state] = new_el out_translation[new_el] = state cutout.add_node(new_el, is_start_state=(state == start_state)) - new_el.parent = cutout + new_el.sdfg = cutout in_translation[sdfg.sdfg_id] = cutout.sdfg_id out_translation[cutout.sdfg_id] = sdfg.sdfg_id @@ -574,8 +574,8 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use # For the given state, determine what should count as the input configuration if we were to cut out the entire # state. - state_reachability_dict = StateReachability().apply_pass(state.parent, None) - state_reach = state_reachability_dict[state.parent.sdfg_id] + state_reachability_dict = StateReachability().apply_pass(state.sdfg, None) + state_reach = state_reachability_dict[state.sdfg.sdfg_id] reaching_cutout: Set[SDFGState] = set() for k, v in state_reach.items(): if state in v: @@ -586,7 +586,7 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use if state.out_degree(dn) > 0: # This is read from, add to the system state if it is written anywhere else in the graph. # Except if it is also written to at the same time and is scalar or of size 1. - array = state.parent.arrays[dn.data] + array = state.sdfg.arrays[dn.data] if state.in_degree(dn) > 0 and (array.total_size == 1 or isinstance(array, data.Scalar)): continue elif not array.transient: @@ -608,8 +608,8 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use # about symbol values. Not sure how to do that yet. if symbols_map is None: symbols_map = dict() - consts = state.parent.constants - for s in state.parent.symbols: + consts = state.sdfg.constants + for s in state.sdfg.symbols: if s in consts: symbols_map[s] = consts[s] else: @@ -730,8 +730,8 @@ def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use for node in scope_nodes: if isinstance(node, nd.AccessNode) and node.data in state_input_configuration: - if not proxy_graph.has_edge(source, node) and node.data in state.parent.arrays: - vol = state.parent.arrays[node.data].total_size + if not proxy_graph.has_edge(source, node) and node.data in state.sdfg.arrays: + vol = state.sdfg.arrays[node.data].total_size if isinstance(vol, sp.Expr): vol = vol.subs(symbols_map) proxy_graph.add_edge(source, node, capacity=vol) @@ -767,7 +767,7 @@ def _stateset_predecessor_frontier(states: Set[SDFGState]) -> Tuple[Set[SDFGStat pred_frontier = set() pred_frontier_edges = set() for state in states: - for iedge in state.parent.in_edges(state): + for iedge in state.sdfg.in_edges(state): if iedge.src not in states: if iedge.src not in pred_frontier: pred_frontier.add(iedge.src) @@ -819,7 +819,7 @@ def _create_alibi_access_node_for_edge(target_sdfg: SDFG, target_state: SDFGStat def _extend_subgraph_with_access_nodes(state: SDFGState, subgraph: StateSubgraphView, use_alibi_nodes: bool) -> StateSubgraphView: """ Expands a subgraph view to include necessary input/output access nodes, using memlet paths. """ - sdfg = state.parent + sdfg = state.sdfg result: List[nd.Node] = copy.copy(subgraph.nodes()) queue: Deque[nd.Node] = deque(subgraph.nodes()) @@ -1014,7 +1014,7 @@ def _cutout_determine_output_configuration(ct: SDFG, cutout_reach: Set[SDFGState check_for_read_after.add(dn.data) original_state: SDFGState = out_translation[state] - for edge in original_state.parent.out_edges(original_state): + for edge in original_state.sdfg.out_edges(original_state): if edge.dst in cutout_reach: border_out_edges.add(edge.data) diff --git a/dace/sdfg/infer_types.py b/dace/sdfg/infer_types.py index 105e1d12e9..ed4f5e068f 100644 --- a/dace/sdfg/infer_types.py +++ b/dace/sdfg/infer_types.py @@ -257,7 +257,7 @@ def _determine_schedule_from_storage(state: SDFGState, node: nodes.Node) -> Opti # From memlets, use non-scalar data descriptors for decision constraints: Set[dtypes.ScheduleType] = set() - sdfg = state.parent + sdfg = state.sdfg for dname in memlets: if isinstance(sdfg.arrays[dname], data.Scalar): continue # Skip scalars @@ -276,7 +276,7 @@ def _determine_schedule_from_storage(state: SDFGState, node: nodes.Node) -> Opti raise validation.InvalidSDFGNodeError( f'Cannot determine default schedule for node {node}. ' 'Multiple arrays that point to it say that it should be the following schedules: ' - f'{constraints}', state.parent, state.parent.node_id(state), state.node_id(node)) + f'{constraints}', state.sdfg, state.sdfg.node_id(state), state.node_id(node)) else: child_schedule = next(iter(constraints)) @@ -338,7 +338,7 @@ def _set_default_storage_in_scope(state: SDFGState, parent_node: Optional[nodes. parent_schedules = parent_schedules + [dtypes.ScheduleType.GPU_ThreadBlock] # End of special case - sdfg = state.parent + sdfg = state.sdfg child_storage = _determine_child_storage(parent_schedules) if child_storage is None: child_storage = dtypes.SCOPEDEFAULT_STORAGE[None] diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 28431deeea..d37b27d3de 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -265,7 +265,7 @@ def __label__(self, sdfg, state): def desc(self, sdfg): from dace.sdfg import SDFGState, ScopeSubgraphView if isinstance(sdfg, (SDFGState, ScopeSubgraphView)): - sdfg = sdfg.parent + sdfg = sdfg.sdfg return sdfg.arrays[self.data] def validate(self, sdfg, state): @@ -588,7 +588,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: # Filter out unused internal symbols from symbol mapping if not all_symbols: - internally_used_symbols = self.sdfg.used_symbols(all_symbols=False) + internally_used_symbols = self.sdfg.used_symbols(all_symbols=False)[0] free_syms &= internally_used_symbols return free_syms diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 5e42830a75..5ce2b7a45f 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -168,17 +168,18 @@ def replace_datadesc_names(sdfg, repl: Dict[str, str]): sdfg.constants_prop[repl[aname]] = sdfg.constants_prop[aname] del sdfg.constants_prop[aname] - # Replace in interstate edges - for e in sdfg.edges(): - e.data.replace_dict(repl, replace_keys=False) - - for state in sdfg.nodes(): - # Replace in access nodes - for node in state.data_nodes(): - if node.data in repl: - node.data = repl[node.data] - - # Replace in memlets - for edge in state.edges(): - if edge.data.data in repl: - edge.data.data = repl[edge.data.data] + for cf in sdfg.all_cfgs_recursive(recurse_into_sdfgs=False): + # Replace in interstate edges + for e in cf.edges(): + e.data.replace_dict(repl, replace_keys=False) + + for block in cf.nodes(): + if isinstance(block, dace.SDFGState): + # Replace in access nodes + for node in block.data_nodes(): + if node.data in repl: + node.data = repl[node.data] + # Replace in memlets + for edge in block.edges(): + if edge.data.data in repl: + edge.data.data = repl[edge.data.data] diff --git a/dace/sdfg/scope.py b/dace/sdfg/scope.py index 930d65be2e..75f8eacf75 100644 --- a/dace/sdfg/scope.py +++ b/dace/sdfg/scope.py @@ -122,7 +122,10 @@ def node_id_or_none(node): if node is None: return -1 return state.node_id(node) - return {node_id_or_none(k): [node_id_or_none(vi) for vi in v] for k, v in scope_dict.items()} + res = {} + for k, v in scope_dict.items(): + res[node_id_or_none(k)] = [node_id_or_none(vi) for vi in v] if v is not None else [] + return res def scope_contains_scope(sdict: ScopeDictType, node: NodeType, other_node: NodeType) -> bool: @@ -246,7 +249,7 @@ def is_devicelevel_gpu_kernel(sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGStat if is_parent_nested: return is_devicelevel_gpu(sdfg.parent.parent, sdfg.parent, sdfg.parent_nsdfg_node, with_gpu_default=True) else: - return is_devicelevel_gpu(state.parent, state, node, with_gpu_default=True) + return is_devicelevel_gpu(state.sdfg, state, node, with_gpu_default=True) def is_devicelevel_fpga(sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState', node: NodeType) -> bool: @@ -294,7 +297,7 @@ def devicelevel_block_size(sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState', # Traverse up nested SDFGs if sdfg.parent is not None: if isinstance(sdfg.parent, SDFGState): - parent = sdfg.parent.parent + parent = sdfg.parent.sdfg else: parent = sdfg.parent state, node = next((s, n) for s in parent.nodes() for n in s.nodes() diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 086b5354aa..f2ece7e042 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -501,8 +501,8 @@ def __init__(self, self.add_constant(cstname, cstval, cst_dtype) self._propagate = propagate - self._parent = parent self.symbols = {} + self._parent_sdfg = None self._parent_nsdfg_node = None self._sdfg_list = [self] self._arrays = NestedDict() # type: Dict[str, dt.Array] @@ -1199,24 +1199,24 @@ def propagate(self): def propagate(self, propagate: bool): self._propagate = propagate - @property - def parent(self) -> SDFGState: - """ Returns the parent SDFG state of this SDFG, if exists. """ - return self._parent - @property def parent_nsdfg_node(self) -> nd.NestedSDFG: """ Returns the parent NestedSDFG node of this SDFG, if exists. """ return self._parent_nsdfg_node - @parent.setter - def parent(self, value): - self._parent = value - @parent_nsdfg_node.setter def parent_nsdfg_node(self, value): self._parent_nsdfg_node = value + @property + def parent_sdfg(self) -> 'SDFG': + """ Returns the parent SDFG of this SDFG, if exists. """ + return self._parent_sdfg + + @parent_sdfg.setter + def parent_sdfg(self, value): + self._parent_sdfg = value + def remove_node(self, node: SDFGState): if node is self._cached_start_block: self._cached_start_block = None @@ -1236,15 +1236,7 @@ def arrays_recursive(self): if isinstance(node, nd.NestedSDFG): yield from node.sdfg.arrays_recursive() - def used_symbols(self, all_symbols: bool) -> Set[str]: - """ - Returns a set of symbol names that are used by the SDFG, but not - defined within it. This property is used to determine the symbolic - parameters of the SDFG. - - :param all_symbols: If False, only returns the set of symbols that will be used - in the generated code and are needed as arguments. - """ + def used_symbols(self, all_symbols: bool) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() free_syms = set() @@ -1269,29 +1261,17 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: # Add free state symbols used_before_assignment = set() - try: - ordered_states = self.topological_sort(self.start_state) - except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) - ordered_states = self.nodes() - - for state in ordered_states: - free_syms |= state.used_symbols(all_symbols) - - # Add free inter-state symbols - for e in self.out_edges(state): - # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols by - # subracting the (true) free symbols from the edge's assignment keys. This way we can correctly - # compute the symbols that are used before being assigned. - efsyms = e.data.used_symbols(all_symbols) - defined_syms |= set(e.data.assignments.keys()) - efsyms - used_before_assignment.update(efsyms - defined_syms) - free_syms |= efsyms + b_free_syms, b_defined_syms, b_used_before_syms = super().used_symbols(all_symbols) + free_syms |= b_free_syms + defined_syms |= b_defined_syms + used_before_assignment |= b_used_before_syms # Remove symbols that were used before they were assigned defined_syms -= used_before_assignment # Subtract symbols defined in inter-state edges and constants - return free_syms - defined_syms + free_syms -= defined_syms + return free_syms, defined_syms, used_before_assignment def arglist(self, scalars_only=False, free_symbols=None) -> Dict[str, dt.Data]: """ @@ -1322,7 +1302,7 @@ def arglist(self, scalars_only=False, free_symbols=None) -> Dict[str, dt.Data]: } # Add global free symbols used in the generated code to scalar arguments - free_symbols = free_symbols if free_symbols is not None else self.used_symbols(all_symbols=False) + free_symbols = free_symbols if free_symbols is not None else self.used_symbols(all_symbols=False)[0] scalar_args.update({k: dt.Scalar(self.symbols[k]) for k in free_symbols if not k.startswith('__dace')}) # Fill up ordered dictionary @@ -1528,8 +1508,8 @@ def from_file(filename: str) -> 'SDFG': # Dynamic SDFG creation API ############################## - def add_state(self, label=None, is_start_block=False, parent_sdfg=None) -> 'SDFGState': - return super().add_state(label, is_start_block, self if parent_sdfg is None else parent_sdfg) + def add_state(self, label=None, is_start_block=False) -> 'SDFGState': + return super().add_state(label, is_start_block) def add_state_before(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState': """ Adds a new SDFG state before an existing state, reconnecting diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 7ab6e66e83..aef522bcf1 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -76,9 +76,6 @@ class BlockGraphView(abc.ABC): methods. """ - def __init__(self, *args, **kwargs): - self._clear_scopedict_cache() - ################################################################### # Typing overrides @@ -189,86 +186,11 @@ def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[Multi """ raise NotImplementedError() - ################################################################### - # Scope-related methods - - def _clear_scopedict_cache(self): - """ - Clears the cached results for the scope_dict function. - For use when the graph mutates (e.g., new edges/nodes, deletions). - """ - self._scope_dict_toparent_cached = None - self._scope_dict_tochildren_cached = None - self._scope_tree_cached = None - self._scope_leaves_cached = None - # TODO: needs to be bubbled up to parents or not cached in control graph views. - - def scope_tree(self) -> 'dace.sdfg.scope.ScopeTree': - from dace.sdfg.scope import ScopeTree - - if (hasattr(self, '_scope_tree_cached') and self._scope_tree_cached is not None): - return copy.copy(self._scope_tree_cached) - - sdp = self.scope_dict() - sdc = self.scope_children() - - result = {} - - # Get scopes - for node, scopenodes in sdc.items(): - if node is None: - exit_node = None - else: - exit_node = next(v for v in scopenodes if isinstance(v, nd.ExitNode)) - scope = ScopeTree(node, exit_node) - result[node] = scope - - # Scope parents and children - for node, scope in result.items(): - if node is not None: - scope.parent = result[sdp[node]] - scope.children = [result[n] for n in sdc[node] if isinstance(n, nd.EntryNode)] - - self._scope_tree_cached = result - - return copy.copy(self._scope_tree_cached) - - def scope_leaves(self) -> List['dace.sdfg.scope.ScopeTree']: - if (hasattr(self, '_scope_leaves_cached') and self._scope_leaves_cached is not None): - return copy.copy(self._scope_leaves_cached) - st = self.scope_tree() - self._scope_leaves_cached = [scope for scope in st.values() if len(scope.children) == 0] - return copy.copy(self._scope_leaves_cached) - - @abc.abstractmethod - def scope_dict(self, validate: bool = True) -> Dict[nd.Node, Union[nd.Node, 'SDFGState']]: - """ - Returns a dictionary that maps each SDFG node to its parent entry node, or to their parent state if not in any - scope. - - :param return_ids: Return node ID numbers instead of node objects. - :param validate: Ensure that the graph is not malformed when computing dictionary. - :return: The mapping from a node to its parent scope entry node. - """ - raise NotImplementedError() - - @abc.abstractmethod - def scope_children(self, - validate: bool = True) -> Dict[Union['SDFGState', nd.EntryNode], List[nd.Node]]: - """ - Returns a dictionary that maps each SDFG entry node to its children, not including the children of children - entry nodes. - A list of top-level nodes (i.e., not in any scope) are keyed by their parent state in the dictionary. - - :param validate: Ensure that the graph is not malformed when computing dictionary. - :return: The mapping from a node to a list of children nodes. - """ - raise NotImplementedError() - ################################################################### # Query, subgraph, and replacement methods - def used_symbols(self, all_symbols: bool) -> Set[str]: + @abc.abstractmethod + def used_symbols(self, all_symbols: bool) -> Tuple[Set[str], Set[str], Set[str]]: """ Returns a set of symbol names that are used in the graph. @@ -285,7 +207,7 @@ def free_symbols(self) -> Set[str]: :note: Assumes that the graph is valid (i.e., without undefined or overlapping symbols). """ - return self.used_symbols(all_symbols=True) + return self.used_symbols(all_symbols=True)[0] @abc.abstractmethod def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: @@ -384,7 +306,7 @@ def replace_dict(self, class DataflowGraphView(BlockGraphView, abc.ABC): def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + self._clear_scopedict_cache() ################################################################### # Typing overrides @@ -555,6 +477,53 @@ def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[Multi ################################################################### # Scope-related methods + def _clear_scopedict_cache(self): + """ + Clears the cached results for the scope_dict function. + For use when the graph mutates (e.g., new edges/nodes, deletions). + """ + self._scope_dict_toparent_cached = None + self._scope_dict_tochildren_cached = None + self._scope_tree_cached = None + self._scope_leaves_cached = None + + def scope_tree(self) -> 'dace.sdfg.scope.ScopeTree': + from dace.sdfg.scope import ScopeTree + + if (hasattr(self, '_scope_tree_cached') and self._scope_tree_cached is not None): + return copy.copy(self._scope_tree_cached) + + sdp = self.scope_dict() + sdc = self.scope_children() + + result = {} + + # Get scopes + for node, scopenodes in sdc.items(): + if node is None: + exit_node = None + else: + exit_node = next(v for v in scopenodes if isinstance(v, nd.ExitNode)) + scope = ScopeTree(node, exit_node) + result[node] = scope + + # Scope parents and children + for node, scope in result.items(): + if node is not None: + scope.parent = result[sdp[node]] + scope.children = [result[n] for n in sdc[node] if isinstance(n, nd.EntryNode)] + + self._scope_tree_cached = result + + return copy.copy(self._scope_tree_cached) + + def scope_leaves(self) -> List['dace.sdfg.scope.ScopeTree']: + if (hasattr(self, '_scope_leaves_cached') and self._scope_leaves_cached is not None): + return copy.copy(self._scope_leaves_cached) + st = self.scope_tree() + self._scope_leaves_cached = [scope for scope in st.values() if len(scope.children) == 0] + return copy.copy(self._scope_leaves_cached) + def scope_dict(self, validate: bool = True) -> Dict[nd.Node, Union['SDFGState', nd.Node]]: from dace.sdfg.scope import _scope_dict_inner if self._scope_dict_toparent_cached is not None: @@ -562,7 +531,7 @@ def scope_dict(self, validate: bool = True) -> Dict[nd.Node, Union['SDFGState', else: result = {} node_queue = collections.deque(self.source_nodes()) - eq = _scope_dict_inner(self, node_queue, self, False, result) + eq = _scope_dict_inner(self, node_queue, None, False, result) # Sanity checks if validate and len(eq) != 0: @@ -589,7 +558,7 @@ def scope_children(self, validate: bool = True) -> Dict[Union[nd.Node, 'SDFGStat else: result = {} node_queue = collections.deque(self.source_nodes()) - eq = _scope_dict_inner(self, node_queue, self, True, result) + eq = _scope_dict_inner(self, node_queue, None, True, result) # Sanity checks if validate and len(eq) != 0: @@ -612,9 +581,9 @@ def scope_children(self, validate: bool = True) -> Dict[Union[nd.Node, 'SDFGStat ################################################################### # Query, subgraph, and replacement methods - def used_symbols(self, all_symbols: bool) -> Set[str]: - state = self.graph if isinstance(self, SubgraphView) else self - sdfg = state.parent + def used_symbols(self, all_symbols: bool) -> Tuple[Set[str], Set[str], Set[str]]: + state: dace.SDFGState = self.graph if isinstance(self, SubgraphView) else self + sdfg = state.sdfg new_symbols = set() freesyms = set() @@ -655,7 +624,7 @@ def _is_leaf_memlet(e): # Do not consider SDFG constants as symbols new_symbols.update(set(sdfg.constants.keys())) - return freesyms - new_symbols + return freesyms - new_symbols, new_symbols, set() def defined_symbols(self) -> Dict[str, dt.Data]: state = self.graph if isinstance(self, SubgraphView) else self @@ -860,9 +829,6 @@ def replace_dict(self, @make_properties class ControlGraphView(BlockGraphView, abc.ABC): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - ################################################################### # Typing overrides @@ -938,35 +904,6 @@ def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[Multi if node in block.nodes(): return block.edges_by_connector(node, connector) - ################################################################### - # Scope-related methods - - def scope_dict(self, validate: bool = True) -> Dict[nd.Node, Union[nd.Node, 'SDFGState']]: - if self._scope_dict_toparent_cached is not None: - return copy.copy(self._scope_dict_toparent_cached) - else: - result = {} - - for block in self.nodes(): - result.update(block.scope_dict(validate)) - - # Cache result. - self._scope_dict_toparent_cached = result - return copy.copy(result) - - def scope_children(self, validate: bool = True) -> Dict[Union[nd.Node, 'SDFGState'], List[nd.Node]]: - if self._scope_dict_tochildren_cached is not None: - return copy.copy(self._scope_dict_tochildren_cached) - else: - result = {} - - for block in self.nodes(): - result.update(block.scope_children(validate)) - - # Cache result - self._scope_dict_tochildren_cached = result - return copy.copy(result) - ################################################################### # Query, subgraph, and replacement methods @@ -1025,13 +962,18 @@ class ControlFlowBlock(BlockGraphView, abc.ABC): is_collapsed = Property(dtype=bool, desc='Show this block as collapsed', default=False) - _parent_cfg: Optional['ControlFlowGraph'] = None + _sdfg: Optional['dace.SDFG'] = None + _parent: Optional['ControlFlowBlock'] = None _label: str - def __init__(self, label: str='', parent: Optional['ControlFlowGraph']=None): + def __init__(self, + label: str='', + parent: Optional['ControlFlowBlock']=None, + sdfg: Optional['dace.SDFG'] = None): super(ControlFlowBlock, self).__init__() self._label = label - self._parent_cfg = parent + self._parent = parent + self._sdfg = sdfg self._default_lineinfo = None self.is_collapsed = False @@ -1070,13 +1012,21 @@ def name(self) -> str: return self._label @property - def parent_cfg(self): - """ Returns the parent graph of this block. """ - return self._parent_cfg + def parent(self) -> Optional['ControlFlowBlock']: + """ Returns the parent block of this block. """ + return self._parent + + @parent.setter + def parent(self, block: Optional['ControlFlowBlock']): + self._parent = block + + @property + def sdfg(self) -> Optional['dace.SDFG']: + return self._sdfg - @parent_cfg.setter - def parent_cfg(self, value): - self._parent_cfg = value + @sdfg.setter + def sdfg(self, sdfg: Optional['dace.SDFG']): + self._sdfg = sdfg @make_properties @@ -1114,7 +1064,7 @@ class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], ControlFlowBlo def __repr__(self) -> str: return f"SDFGState ({self.label})" - def __init__(self, label=None, sdfg=None, debuginfo=None, location=None, cfg=None): + def __init__(self, label=None, sdfg=None, debuginfo=None, location=None, parent=None): """ Constructs an SDFG state. :param label: Name for the state (optional). @@ -1123,8 +1073,7 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None, cfg=Non """ from dace.sdfg.sdfg import SDFG # Avoid import loop OrderedMultiDiConnectorGraph.__init__(self) - ControlFlowBlock.__init__(self, label, sdfg) - self._parent: Optional[SDFG] = sdfg + ControlFlowBlock.__init__(self, label, parent, sdfg) self._graph = self # Allowing MemletTrackingView mixin to work self._clear_scopedict_cache() self._debuginfo = debuginfo @@ -1147,15 +1096,6 @@ def __deepcopy__(self, memo): pass return result - @property - def parent(self): - """ Returns the parent SDFG of this state. """ - return self._parent - - @parent.setter - def parent(self, value): - self._parent = value - def is_empty(self): return self.number_of_nodes() == 0 @@ -1182,7 +1122,7 @@ def add_node(self, node): # Correct nested SDFG's parent attributes if isinstance(node, nd.NestedSDFG): node.sdfg.parent = self - node.sdfg.parent_sdfg = self.parent + node.sdfg.parent_sdfg = self.sdfg node.sdfg.parent_nsdfg_node = node self._clear_scopedict_cache() return super(SDFGState, self).add_node(node) @@ -1210,7 +1150,7 @@ def add_edge(self, u, u_connector, v, v_connector, memlet): self._clear_scopedict_cache() result = super(SDFGState, self).add_edge(u, u_connector, v, v_connector, memlet) - memlet.try_initialize(self.parent, self, result) + memlet.try_initialize(self.sdfg, self, result) return result def remove_edge(self, edge): @@ -1237,7 +1177,7 @@ def to_json(self, parent=None): # Try to initialize edges before serialization for edge in self.edges(): - edge.data.try_initialize(self.parent, self, edge) + edge.data.try_initialize(self.sdfg, self, edge) ret = { 'type': type(self).__name__, @@ -1309,7 +1249,7 @@ def _repr_html_(self): from dace.sdfg import SDFG arrays = set(n.data for n in self.data_nodes()) sdfg = SDFG(self.label) - sdfg._arrays = {k: self._parent.arrays[k] for k in arrays} + sdfg._arrays = {k: self.sdfg.arrays[k] for k in arrays} sdfg.add_node(self) return sdfg._repr_html_() @@ -1329,7 +1269,7 @@ def symbols_defined_at(self, node: nd.Node) -> Dict[str, dtypes.typeclass]: if node is None: return collections.OrderedDict() - sdfg: SDFG = self.parent + sdfg: SDFG = self.sdfg # Start with global symbols symbols = collections.OrderedDict(sdfg.symbols) @@ -1471,7 +1411,7 @@ def add_nested_sdfg( debuginfo = _getdebuginfo(debuginfo or self._default_lineinfo) sdfg.parent = self - sdfg.parent_sdfg = self.parent + sdfg.parent_sdfg = self.sdfg sdfg.update_sdfg_list([]) @@ -1737,7 +1677,7 @@ def add_mapped_tasklet(self, # Try to initialize memlets for edge in edges: - edge.data.try_initialize(self.parent, self, edge) + edge.data.try_initialize(self.sdfg, self, edge) return tasklet, map_entry, map_exit @@ -1919,8 +1859,8 @@ def add_edge_pair( ) # Try to initialize memlets - iedge.data.try_initialize(self.parent, self, iedge) - eedge.data.try_initialize(self.parent, self, eedge) + iedge.data.try_initialize(self.sdfg, self, iedge) + eedge.data.try_initialize(self.sdfg, self, eedge) return (iedge, eedge) @@ -2023,7 +1963,7 @@ def add_memlet_path(self, *path_nodes, memlet=None, src_conn=None, dst_conn=None cur_memlet = propagate_memlet(self, cur_memlet, snode, True) # Try to initialize memlets for edge in edges: - edge.data.try_initialize(self.parent, self, edge) + edge.data.try_initialize(self.sdfg, self, edge) def remove_memlet_path(self, edge: MultiConnectorEdge, remove_orphans: bool = True) -> None: """ Removes all memlets and associated connectors along a path formed @@ -2120,9 +2060,9 @@ def add_array(self, 'The "SDFGState.add_array" API is deprecated, please ' 'use "SDFG.add_array" and "SDFGState.add_access"', DeprecationWarning) # Workaround to allow this legacy API - if name in self.parent._arrays: - del self.parent._arrays[name] - self.parent.add_array(name, + if name in self.sdfg._arrays: + del self.sdfg._arrays[name] + self.sdfg.add_array(name, shape, dtype, storage=storage, @@ -2153,9 +2093,9 @@ def add_stream( 'The "SDFGState.add_stream" API is deprecated, please ' 'use "SDFG.add_stream" and "SDFGState.add_access"', DeprecationWarning) # Workaround to allow this legacy API - if name in self.parent._arrays: - del self.parent._arrays[name] - self.parent.add_stream( + if name in self.sdfg._arrays: + del self.sdfg._arrays[name] + self.sdfg.add_stream( name, dtype, buffer_size, @@ -2182,9 +2122,9 @@ def add_scalar( 'The "SDFGState.add_scalar" API is deprecated, please ' 'use "SDFG.add_scalar" and "SDFGState.add_access"', DeprecationWarning) # Workaround to allow this legacy API - if name in self.parent._arrays: - del self.parent._arrays[name] - self.parent.add_scalar(name, dtype, storage, transient, lifetime, debuginfo) + if name in self.sdfg._arrays: + del self.sdfg._arrays[name] + self.sdfg.add_scalar(name, dtype, storage, transient, lifetime, debuginfo) return self.add_access(name, debuginfo) def add_transient(self, @@ -2297,44 +2237,6 @@ def __init__(self): self._start_block: Optional[int] = None self._cached_start_block: Optional[ControlFlowBlock] = None - def add_edge(self, src: ControlFlowBlock, dst: ControlFlowBlock, data: 'dace.sdfg.InterstateEdge'): - """ Adds a new edge to the graph. Must be an InterstateEdge or a subclass thereof. - - :param u: Source node. - :param v: Destination node. - :param edge: The edge to add. - """ - if not isinstance(src, ControlFlowBlock): - raise TypeError('Expected ControlFlowBlock, got ' + str(type(src))) - if not isinstance(dst, ControlFlowBlock): - raise TypeError('Expected ControlFlowBlock, got ' + str(type(dst))) - if not isinstance(data, dace.sdfg.InterstateEdge): - raise TypeError('Expected InterstateEdge, got ' + str(type(data))) - if dst is self._cached_start_block: - self._cached_start_block = None - return super(ControlFlowGraph, self).add_edge(src, dst, data) - - def add_node(self, node, is_start_block=False): - if not isinstance(node, ControlFlowBlock): - raise TypeError('Expected ControlFlowBlock, got ' + str(type(node))) - super().add_node(node) - node.parent_cfg = self - self._cached_start_block = None - if is_start_block is True: - self.start_block = len(self.nodes()) - 1 - self._cached_start_block = node - - def add_state(self, label=None, is_start_block=False, parent_sdfg=None) -> SDFGState: - if self._labels is None or len(self._labels) != self.number_of_nodes(): - self._labels = set(s.label for s in self.nodes()) - label = label or 'state' - existing_labels = self._labels - label = dt.find_new_name(label, existing_labels) - state = SDFGState(label, parent_sdfg) - self._labels.add(label) - self.add_node(state, is_start_block=is_start_block) - return state - ################################################################### # Traversal methods @@ -2405,35 +2307,56 @@ def start_block(self, block_id): @make_properties class ScopeBlock(ControlFlowGraph, ControlFlowBlock): - # TODO: instrumentation - _parent_sdfg: Optional['dace.SDFG'] = None - _parent_cfg: Optional[ControlFlowGraph] = None - - def __init__(self, label: str='', parent: Optional[ControlFlowGraph]=None, sdfg: Optional['dace.SDFG']=None): + def __init__(self, + label: str='', + parent: Optional['ControlFlowBlock']=None, + sdfg: Optional['dace.SDFG'] = None): ControlFlowGraph.__init__(self) - ControlFlowBlock.__init__(self, label, parent) - self._parent_cfg = parent - self._parent_sdfg = sdfg + ControlFlowBlock.__init__(self, label, parent, sdfg) - ''' - def used_symbols(self, all_symbols: bool) -> Set[str]: - defined_syms = set() - free_syms = set() + def add_edge(self, src: ControlFlowBlock, dst: ControlFlowBlock, data: 'dace.sdfg.InterstateEdge'): + """ Adds a new edge to the graph. Must be an InterstateEdge or a subclass thereof. - # Exclude data descriptor names, constants, and shapes of global data descriptors - not_strictly_necessary_global_symbols = set() - sdfg: dace.SDFG = self._parent_sdfg if self._parent_sdfg is not None else self - for name, desc in sdfg.arrays.items(): - defined_syms.add(name) + :param u: Source node. + :param v: Destination node. + :param edge: The edge to add. + """ + if not isinstance(src, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(src))) + if not isinstance(dst, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(dst))) + if not isinstance(data, dace.sdfg.InterstateEdge): + raise TypeError('Expected InterstateEdge, got ' + str(type(data))) + if dst is self._cached_start_block: + self._cached_start_block = None + return super().add_edge(src, dst, data) - if not all_symbols: - used_desc_symbols = desc.used_symbols(all_symbols) - not_strictly_necessary = (desc.used_symbols(all_symbols=True) - used_desc_symbols) - not_strictly_necessary_global_symbols |= set(map(str, not_strictly_necessary)) + def add_node(self, node, is_start_block=False): + if not isinstance(node, ControlFlowBlock): + raise TypeError('Expected ControlFlowBlock, got ' + str(type(node))) + super().add_node(node) + node.parent = self + node.sdfg = self if isinstance(self, dace.SDFG) else self.sdfg + self._cached_start_block = None + if is_start_block is True: + self.start_block = len(self.nodes()) - 1 + self._cached_start_block = node - defined_syms |= set(sdfg.constants_prop.keys()) + def add_state(self, label=None, is_start_block=False) -> SDFGState: + if self._labels is None or len(self._labels) != self.number_of_nodes(): + self._labels = set(s.label for s in self.nodes()) + label = label or 'state' + existing_labels = self._labels + label = dt.find_new_name(label, existing_labels) + state = SDFGState(label) + self._labels.add(label) + self.add_node(state, is_start_block=is_start_block) + return state - # Add free state symbols + @abc.abstractmethod + def used_symbols(self, all_symbols: bool) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms = set() + free_syms = set() used_before_assignment = set() try: @@ -2442,7 +2365,10 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: ordered_blocks = self.nodes() for block in ordered_blocks: - free_syms |= block.used_symbols(all_symbols) + b_free_syms, b_defined_syms, b_used_before_syms = block.used_symbols(all_symbols) + free_syms |= b_free_syms + defined_syms |= b_defined_syms + used_before_assignment |= b_used_before_syms # Add free inter-state symbols for e in self.out_edges(block): @@ -2454,12 +2380,12 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: used_before_assignment.update(efsyms - defined_syms) free_syms |= efsyms - # Remove symbols that were used before they were assigned + # Remove symbols that were used before they were assigned. defined_syms -= used_before_assignment - # Subtract symbols defined in inter-state edges and constants - return free_syms - defined_syms - ''' + # Subtract symbols defined in inter-state edges and constants from the list of free symbols. + free_syms -= defined_syms + return free_syms, defined_syms, used_before_assignment def to_json(self, parent=None): graph_json = ControlFlowGraph.to_json(self) @@ -2470,23 +2396,6 @@ def to_json(self, parent=None): ################################################################### # Getters & setters, overrides - @property - def parent_cfg(self) -> Optional['ControlFlowGraph']: - return self._parent_cfg - - @parent_cfg.setter - def parent_cfg(self, parent_cfg: 'ControlFlowGraph') -> None: - self._parent_cfg = parent_cfg - - @property - def parent_sdfg(self) -> 'dace.SDFG': - """ Returns the parent SDFG of this control flow graph, if exists. """ - return self._parent_sdfg - - @parent_sdfg.setter - def parent_sdfg(self, value): - self._parent_sdfg = value - def __str__(self): return ControlFlowBlock.__str__(self) @@ -2501,6 +2410,7 @@ class LoopScopeBlock(ScopeBlock): init_statement = CodeProperty(optional=True, allow_none=True, default=None) scope_condition = CodeProperty(allow_none=True, default=None) inverted = Property(dtype=bool, default=False) + loop_variable = Property(dtype=str, default='') def __init__(self, loop_var: str, @@ -2509,8 +2419,9 @@ def __init__(self, update_expr: str, label: str = '', parent: Optional[ControlFlowGraph] = None, + sdfg: Optional['dace.SDFG'] = None, inverted: bool = False): - super(LoopScopeBlock, self).__init__(label, parent) + super(LoopScopeBlock, self).__init__(label, parent, sdfg) if initialize_expr is not None: self.init_statement = CodeBlock('%s = %s' % (loop_var, initialize_expr)) @@ -2527,13 +2438,28 @@ def __init__(self, else: self.update_statement = None + self.loop_variable = loop_var self.inverted = inverted - def used_symbols(self, all_symbols: bool) -> Set[str]: - symbols = set() + def used_symbols(self, all_symbols: bool) -> Tuple[Set[str], Set[str], Set[str]]: + free_symbols = set() + defined_symbols = set(self.loop_variable) + used_before_assignment = set() if self.init_statement is not None: - symbols |= self.init_statement.used_symbols(all_symbols) - return symbols() + free_symbols |= self.init_statement.get_free_symbols() + if self.update_statement is not None: + free_symbols |= self.update_statement.get_free_symbols() + free_symbols |= self.scope_condition.get_free_symbols() + + b_free_symbols, b_defined_symbols, b_used_before_assignment = super().used_symbols(all_symbols) + free_symbols |= b_free_symbols + defined_symbols |= b_defined_symbols + used_before_assignment |= b_used_before_assignment + + defined_symbols -= used_before_assignment + free_symbols -= defined_symbols + + return free_symbols, defined_symbols, used_before_assignment def to_json(self, parent=None): return super().to_json(parent) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 1a01bce92c..f6081efbcc 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -765,7 +765,7 @@ def get_last_view_node(state: SDFGState, view: nd.AccessNode) -> nd.AccessNode: Given a view access node, returns the last viewed access node if existent, else None """ - sdfg = state.parent + sdfg = state.sdfg node = view desc = sdfg.arrays[node.data] while isinstance(desc, dt.View): @@ -781,7 +781,7 @@ def get_all_view_nodes(state: SDFGState, view: nd.AccessNode) -> List[nd.AccessN Given a view access node, returns a list of viewed access nodes if existent, else None """ - sdfg = state.parent + sdfg = state.sdfg node = view desc = sdfg.arrays[node.data] result = [node] @@ -910,10 +910,10 @@ def is_parallel(state: SDFGState, node: Optional[nd.Node] = None) -> bool: curnode = sdict[curnode] if curnode.schedule != dtypes.ScheduleType.Sequential: return True - if state.parent.parent is not None: + if state.sdfg.parent is not None: # Find nested SDFG node and continue recursion - nsdfg_node = next(n for n in state.parent.parent if isinstance(n, nd.NestedSDFG) and n.sdfg == state.parent) - return is_parallel(state.parent.parent, nsdfg_node) + nsdfg_node = next(n for n in state.sdfg.parent if isinstance(n, nd.NestedSDFG) and n.sdfg == state.sdfg) + return is_parallel(state.sdfg.parent, nsdfg_node) return False @@ -1269,18 +1269,18 @@ def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, mu for node, state in optional_progressbar(reversed(nsdfgs), title='Inlining SDFGs', n=len(nsdfgs), progress=progress): id = node.sdfg.sdfg_id - sd = state.parent + graph = state.parent # We have to reevaluate every time due to changing IDs - state_id = sd.node_id(state) + state_id = graph.node_id(state) if multistate: candidate = { InlineMultistateSDFG.nested_sdfg: node, } inliner = InlineMultistateSDFG() - inliner.setup_match(sd, id, state_id, candidate, 0, override=True) - if inliner.can_be_applied(state, 0, sd, permissive=permissive): - inliner.apply(state, sd) + inliner.setup_match(graph, id, state_id, candidate, 0, override=True) + if inliner.can_be_applied(state, 0, graph, permissive=permissive): + inliner.apply(state, state.sdfg) counter += 1 continue @@ -1288,9 +1288,9 @@ def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, mu InlineSDFG.nested_sdfg: node, } inliner = InlineSDFG() - inliner.setup_match(sd, id, state_id, candidate, 0, override=True) - if inliner.can_be_applied(state, 0, sd, permissive=permissive): - inliner.apply(state, sd) + inliner.setup_match(graph, id, state_id, candidate, 0, override=True) + if inliner.can_be_applied(state, 0, graph, permissive=permissive): + inliner.apply(state, state.sdfg) counter += 1 return counter @@ -1379,7 +1379,7 @@ def unique_node_repr(graph: Union[SDFGState, ScopeSubgraphView], node: Node) -> """ # Build a unique representation - sdfg = graph.parent + sdfg = graph.sdfg state = graph if isinstance(graph, SDFGState) else graph._graph return str(sdfg.sdfg_id) + "_" + str(sdfg.node_id(state)) + "_" + str(state.node_id(node)) @@ -1413,7 +1413,7 @@ def is_nonfree_sym_dependent(node: nd.AccessNode, desc: dt.Data, state: SDFGStat # is the View. n = get_view_node(state, node) if n and isinstance(n, nd.AccessNode): - d = state.parent.arrays[n.data] + d = state.sdfg.arrays[n.data] return is_nonfree_sym_dependent(n, d, state, fsymbols) elif isinstance(desc, dt.Array): if any(str(s) not in fsymbols for s in desc.free_symbols): diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index ab6f7410b1..a2976c760a 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -357,14 +357,12 @@ def validate_state(state: 'dace.sdfg.SDFGState', if not dtypes.validate_name(state._label): raise InvalidSDFGError("Invalid state name", sdfg, state_id) - # TODO: set state SDFG and validate here. Parent won't point to the same thing. - #if state._parent != sdfg: - # raise InvalidSDFGError("State does not point to the correct " - # "parent", sdfg, state_id) + if state.sdfg != sdfg: + raise InvalidSDFGError("State does not point to the correct sdfg", sdfg, state_id) # Unreachable ######################################## - parent = state.parent + parent = state.sdfg if (parent.number_of_nodes() > 1 and parent.in_degree(state) == 0 and parent.out_degree(state) == 0): raise InvalidSDFGError("Unreachable state", sdfg, state_id) diff --git a/dace/sdfg/work_depth_analysis/helpers.py b/dace/sdfg/work_depth_analysis/helpers.py index a80e769f64..9e1c5ba265 100644 --- a/dace/sdfg/work_depth_analysis/helpers.py +++ b/dace/sdfg/work_depth_analysis/helpers.py @@ -34,7 +34,7 @@ def get_uuid(element, state=None): if isinstance(element, SDFG): return ids_to_string(element.sdfg_id) elif isinstance(element, SDFGState): - return ids_to_string(element.parent.sdfg_id, element.parent.node_id(element)) + return ids_to_string(element.sdfg.sdfg_id, element.sdfg.node_id(element)) elif isinstance(element, nodes.Node): return ids_to_string(state.parent.sdfg_id, state.parent.node_id(state), state.node_id(element)) else: diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 54dbc8d4ac..50e8e021e5 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -73,7 +73,7 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, # we are in graph or subgraph sdfg, graph, subgraph = None, None, None if isinstance(graph_or_subgraph, SDFGState): - sdfg = graph_or_subgraph.parent + sdfg = graph_or_subgraph.sdfg sdfg.apply_transformations_repeated(MapFusion, validate_all=validate_all) graph = graph_or_subgraph subgraph = SubgraphView(graph, graph.nodes()) @@ -196,7 +196,7 @@ def tile_wcrs(graph_or_subgraph: GraphViewType, validate_all: bool, prefer_parti return if not isinstance(graph, SDFGState): raise TypeError('Graph must be a state, an SDFG, or a subgraph of either') - sdfg = graph.parent + sdfg = graph.sdfg edges_to_consider: Set[Tuple[gr.MultiConnectorEdge[Memlet], nodes.MapEntry]] = set() for edge in graph_or_subgraph.edges(): diff --git a/dace/transformation/dataflow/hbm_transform.py b/dace/transformation/dataflow/hbm_transform.py index 15076a0ca6..e4e2aa277c 100644 --- a/dace/transformation/dataflow/hbm_transform.py +++ b/dace/transformation/dataflow/hbm_transform.py @@ -83,7 +83,7 @@ def _update_memlet_hbm(state: SDFGState, inner_edge: graph.MultiConnectorEdge, # If the memlet already contains the distributed subset, ignore it # That's helpful because of inconsistencies when nesting and because # one can 'hint' the correct bank assignment when using this function - if len(mem.subset) == len(state.parent.arrays[this_node.data].shape): + if len(mem.subset) == len(state.sdfg.arrays[this_node.data].shape): return new_subset = subsets.Range([[inner_subset_index, inner_subset_index, 1]] + [x for x in mem.subset]) @@ -100,7 +100,7 @@ def _update_memlet_hbm(state: SDFGState, inner_edge: graph.MultiConnectorEdge, other_node, nd.NestedSDFG): # Ignore those and update them via propagation new_subset = subsets.Range.from_array( - state.parent.arrays[this_node.data]) + state.sdfg.arrays[this_node.data]) if isinstance(other_node, nd.AccessNode): fwtasklet = state.add_tasklet("fwtasklet", set(["_in"]), set(["_out"]), diff --git a/dace/transformation/dataflow/streaming_memory.py b/dace/transformation/dataflow/streaming_memory.py index 4cf40b30bf..93c86a0d0e 100644 --- a/dace/transformation/dataflow/streaming_memory.py +++ b/dace/transformation/dataflow/streaming_memory.py @@ -160,10 +160,10 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi while curstate is not None: if curstate.entry_node(node) is not None: return False - if curstate.parent.parent_nsdfg_node is None: + if curstate.sdfg.parent_nsdfg_node is None: break - node = curstate.parent.parent_nsdfg_node - curstate = curstate.parent.parent + node = curstate.sdfg.parent_nsdfg_node + curstate = curstate.sdfg.parent # Only one memlet path is allowed per outgoing/incoming edge edges = (graph.out_edges(access) if expr_index == 0 else graph.in_edges(access)) @@ -628,10 +628,10 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi while curstate is not None: if curstate.entry_node(node) is not None: return False - if curstate.parent.parent_nsdfg_node is None: + if curstate.sdfg.parent_nsdfg_node is None: break - node = curstate.parent.parent_nsdfg_node - curstate = curstate.parent.parent + node = curstate.sdfg.parent_nsdfg_node + curstate = curstate.sdfg.parent # Array must not be used anywhere else in the state if any(n is not access and n.data == access.data for n in graph.data_nodes()): diff --git a/dace/transformation/dataflow/warp_tiling.py b/dace/transformation/dataflow/warp_tiling.py index 211910eebf..7639794943 100644 --- a/dace/transformation/dataflow/warp_tiling.py +++ b/dace/transformation/dataflow/warp_tiling.py @@ -55,7 +55,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG) -> nodes.MapEntry: # Stride and offset all internal maps maps_to_stride = xfh.get_internal_scopes(graph, new_me, immediate=True) for nstate, nmap in maps_to_stride: - nsdfg = nstate.parent + nsdfg = nstate.sdfg nsdfg_node = nsdfg.parent_nsdfg_node # Map cannot be partitioned across a warp diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 73da318e94..16517de9e2 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -376,7 +376,7 @@ def nest_state_subgraph(sdfg: SDFG, SDFG. :raise ValueError: The subgraph is contained in more than one scope. """ - if state.parent != sdfg: + if state.sdfg != sdfg: raise KeyError('State does not belong to given SDFG') if subgraph is not state and subgraph.graph is not state: raise KeyError('Subgraph does not belong to given state') @@ -1002,7 +1002,7 @@ def simplify_state(state: SDFGState, remove_views: bool = False) -> MultiDiGraph :return: The MultiDiGraph object. """ - sdfg = state.parent + sdfg = state.sdfg # Copy the whole state G = MultiDiGraph() @@ -1226,7 +1226,7 @@ def contained_in(state: SDFGState, node: nodes.Node, scope: nodes.EntryNode) -> # A node is contained within itself if node is scope: return True - cursdfg = state.parent + cursdfg = state.sdfg curstate = state curscope = state.entry_node(node) while cursdfg is not None: @@ -1249,7 +1249,7 @@ def get_parent_map(state: SDFGState, node: Optional[nodes.Node] = None) -> Optio :param node: The node to test (optional). :return: A tuple of (entry node, state) or None. """ - cursdfg = state.parent + cursdfg = state.sdfg curstate = state curscope = node while cursdfg is not None: @@ -1318,8 +1318,8 @@ def can_run_state_on_fpga(state: SDFGState): return False # Streams have strict conditions due to code generator limitations - if (isinstance(node, nodes.AccessNode) and isinstance(graph.parent.arrays[node.data], data.Stream)): - nodedesc = graph.parent.arrays[node.data] + if (isinstance(node, nodes.AccessNode) and isinstance(graph.sdfg.arrays[node.data], data.Stream)): + nodedesc = graph.sdfg.arrays[node.data] sdict = graph.scope_dict() if nodedesc.storage in [ dtypes.StorageType.CPU_Heap, dtypes.StorageType.CPU_Pinned, dtypes.StorageType.CPU_ThreadLocal @@ -1331,7 +1331,7 @@ def can_run_state_on_fpga(state: SDFGState): return False # Arrays of streams cannot have symbolic size on FPGA - if symbolic.issymbolic(nodedesc.total_size, graph.parent.constants): + if symbolic.issymbolic(nodedesc.total_size, graph.sdfg.constants): return False # Streams cannot be unbounded on FPGA diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 8fb6600b76..3ed2bd4bea 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -423,7 +423,7 @@ def apply(self, _, sdfg: sd.SDFG): new_body = sdfg.add_state('single_state_body') nsdfg = SDFG("loop_body", constants=sdfg.constants_prop, parent=new_body) nsdfg.add_node(body, is_start_state=True) - body.parent = nsdfg + body.sdfg = nsdfg exit_state = nsdfg.add_state('exit') nsymbols = dict() for state in states: diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index fc3ebfbdca..df75bb9911 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -565,7 +565,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): # Fission state if necessary cc = utils.weakly_connected_component(state, node) if not any(n in cc for n in subgraph.nodes()): - helpers.state_fission(state.parent, cc) + helpers.state_fission(state.sdfg, cc) for edge in removed_out_edges: # Find last access node that refers to this edge try: @@ -604,7 +604,7 @@ def _modify_access_to_access(self, """ Deals with access->access edges where both sides are non-transient. """ - nsdfg_node = nstate.parent.parent_nsdfg_node + nsdfg_node = nstate.sdfg.parent_nsdfg_node edges_to_ignore = edges_to_ignore or set() result = set() edges = input_edges diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index 775b7da8cd..84748d11eb 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -30,7 +30,7 @@ def __init__(self, first_input_nodes: Set[nodes.AccessNode], first_output_nodes: def top_level_nodes(state: SDFGState): - return state.scope_children()[state] + return state.scope_children()[None] class StateFusion(transformation.MultiStateTransformation): diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis.py index a7532b2761..e0a562f143 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis.py @@ -223,7 +223,7 @@ def _find_dominating_write( ) -> Optional[Edge[InterstateEdge]]: last_state: SDFGState = read if isinstance(read, SDFGState) else read.src - in_edges = last_state.parent.in_edges(last_state) + in_edges = last_state.sdfg.in_edges(last_state) deg = len(in_edges) if deg == 0: return None @@ -233,7 +233,7 @@ def _find_dominating_write( write_isedge = None n_state = state_idom[last_state] if state_idom[last_state] != last_state else None while n_state is not None and write_isedge is None: - oedges = n_state.parent.out_edges(n_state) + oedges = n_state.sdfg.out_edges(n_state) odeg = len(oedges) if odeg == 1: if any([sym == k for k in oedges[0].data.assignments.keys()]): @@ -241,7 +241,7 @@ def _find_dominating_write( else: dom_edge = None for cand in oedges: - if nxsp.has_path(n_state.parent.nx, cand.dst, last_state): + if nxsp.has_path(n_state.sdfg.nx, cand.dst, last_state): if dom_edge is not None: dom_edge = None break @@ -279,7 +279,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, dominators = all_doms[write.dst] reach = state_reach[write.dst] for dom in dominators: - iedges = dom.parent.in_edges(dom) + iedges = dom.sdfg.in_edges(dom) if len(iedges) == 1 and iedges[0] in result[sym]: other_accesses = result[sym][iedges[0]] coarsen = False diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index b128af3fb2..e6f416cc4d 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -16,7 +16,7 @@ from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols SIMPLIFY_PASSES = [ - #InlineSDFGs, + InlineSDFGs, #ScalarToSymbolPromotion, FuseStates, #OptionalArrayInference, diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index e5e3dc925a..99f48373c1 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -708,7 +708,7 @@ def setup_match(self, subgraph: Union[Set[int], gr.SubgraphView], sdfg_id: int = self.subgraph = set(subgraph.graph.node_id(n) for n in subgraph.nodes()) if isinstance(subgraph.graph, SDFGState): - sdfg = subgraph.graph.parent + sdfg = subgraph.graph.sdfg self.sdfg_id = sdfg.sdfg_id self.state_id = sdfg.node_id(subgraph.graph) elif isinstance(subgraph.graph, SDFG): diff --git a/samples/optimization/matmul.py b/samples/optimization/matmul.py index 06b6a38939..9c8fbbc383 100644 --- a/samples/optimization/matmul.py +++ b/samples/optimization/matmul.py @@ -113,7 +113,7 @@ def optimize_for_cpu(sdfg: dace.SDFG, m: int, n: int, k: int): # Vectorize microkernel map postamble = n % 4 != 0 entry_inner, inner_state = find_map_and_state_by_param(sdfg, 'k') - Vectorization.apply_to(inner_state.parent, + Vectorization.apply_to(inner_state.sdfg, dict(vector_len=4, preamble=False, postamble=postamble), map_entry=entry_inner) diff --git a/tests/transformations/nest_subgraph_test.py b/tests/transformations/nest_subgraph_test.py index 763bb3327d..ec339b907e 100644 --- a/tests/transformations/nest_subgraph_test.py +++ b/tests/transformations/nest_subgraph_test.py @@ -30,7 +30,7 @@ def test_nest_oneelementmap(): for node, state in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.MapEntry): subgraph = state.scope_subgraph(node, include_entry=False, include_exit=False) - nest_state_subgraph(state.parent, state, subgraph) + nest_state_subgraph(state.sdfg, state, subgraph) sdfg(A=A, B=B) assert np.allclose(A, B) diff --git a/tests/wcr_cudatest.py b/tests/wcr_cudatest.py index 03a99fb8e1..018445fae6 100644 --- a/tests/wcr_cudatest.py +++ b/tests/wcr_cudatest.py @@ -8,7 +8,7 @@ def create_zero_initialization(init_state: dace.SDFGState, array_name): - sdfg = init_state.parent + sdfg = init_state.sdfg array_shape = sdfg.arrays[array_name].shape array_access_node = init_state.add_write(array_name)