diff --git a/dace/sdfg/analysis/cutout.py b/dace/sdfg/analysis/cutout.py index 5d2eae7c6f..ec95157989 100644 --- a/dace/sdfg/analysis/cutout.py +++ b/dace/sdfg/analysis/cutout.py @@ -118,7 +118,7 @@ def from_json(cls, json_obj, context=None): def from_transformation( cls, sdfg: SDFG, transformation: Union[PatternTransformation, SubgraphTransformation], make_side_effects_global = True, use_alibi_nodes: bool = True, reduce_input_config = True, - symbols_map: Optional[Dict[str, Any]] = None + symbols_map: Optional[Dict[str, Any]] = None, preserve_guids: bool = False ) -> Union['SDFGCutout', SDFG]: """ Create a cutout from a transformation's set of affected graph elements. @@ -130,6 +130,9 @@ def from_transformation( :param reduce_input_config: Whether to reduce the input configuration where possible in singlestate cutouts. :param symbols_map: A mapping of symbols to values to use for the cutout. Optional, only used when reducing the input configuration. + :param preserve_guids: If True, ensures that the GUIDs of graph elements contained in the cutout remain + identical to the ones in their original graph. If False, new GUIDs will be generated. + False by default. :return: The cutout. """ affected_nodes = _transformation_determine_affected_nodes(sdfg, transformation) @@ -150,11 +153,12 @@ def from_transformation( state = target_sdfg.node(transformation.state_id) cutout = cls.singlestate_cutout(state, *affected_nodes, make_side_effects_global=make_side_effects_global, use_alibi_nodes=use_alibi_nodes, reduce_input_config=reduce_input_config, - symbols_map=symbols_map) + symbols_map=symbols_map, preserve_guids=preserve_guids) cutout.translate_transformation_into(transformation) return cutout elif isinstance(transformation, MultiStateTransformation): - cutout = cls.multistate_cutout(*affected_nodes, make_side_effects_global=make_side_effects_global) + cutout = cls.multistate_cutout(*affected_nodes, make_side_effects_global=make_side_effects_global, + preserve_guids=preserve_guids) # If the cutout is an SDFG, there's no need to translate the transformation. if isinstance(cutout, SDFGCutout): cutout.translate_transformation_into(transformation) @@ -169,14 +173,15 @@ def singlestate_cutout(cls, make_side_effects_global: bool = True, use_alibi_nodes: bool = True, reduce_input_config: bool = False, - symbols_map: Optional[Dict[str, Any]] = None) -> 'SDFGCutout': + symbols_map: Optional[Dict[str, Any]] = None, + preserve_guids: bool = False) -> 'SDFGCutout': """ Cut out a subgraph of a state from an SDFG to run separately for localized testing or optimization. The subgraph defined by the list of nodes will be extended to include access nodes of data containers necessary to run the graph separately. In addition, all transient data containers that may contain data when the cutout is executed are made global, as well as any transient data containers which are written to inside the cutout but may be read after the cutout. - + :param state: The SDFG state in which the subgraph resides. :param nodes: The nodes in the subgraph to cut out. :param make_copy: If True, deep-copies every SDFG element in the copy. Otherwise, original references are kept. @@ -188,17 +193,29 @@ def singlestate_cutout(cls, :param reduce_input_config: Whether to reduce the input configuration where possible in singlestate cutouts. :param symbols_map: A mapping of symbols to values to use for the cutout. Optional, only used when reducing the input configuration. + :param preserve_guids: If True, ensures that the GUIDs of graph elements contained in the cutout remain + identical to the ones in their original graph. If False, new GUIDs will be generated. + False by default - if make_copy is False, this has no effect by extension. :return: The created SDFGCutout. """ if reduce_input_config: nodes = _reduce_in_configuration(state, nodes, use_alibi_nodes, symbols_map) - create_element = copy.deepcopy if make_copy else (lambda x: x) + + def clone_f(x: Union[Memlet, InterstateEdge, nd.Node, ControlFlowBlock]): + ret = copy.deepcopy(x) + if preserve_guids: + ret.guid = x.guid + return ret + + create_element = clone_f if make_copy else (lambda x: x) sdfg = state.parent subgraph: StateSubgraphView = StateSubgraphView(state, nodes) subgraph = _extend_subgraph_with_access_nodes(state, subgraph, use_alibi_nodes) # Make a new SDFG with the included constants, used symbols, and data containers. cutout = SDFGCutout(sdfg.name + '_cutout', sdfg.constants_prop) + if preserve_guids: + cutout.guid = sdfg.guid cutout._base_sdfg = sdfg defined_syms = subgraph.defined_symbols() freesyms = subgraph.free_symbols @@ -213,11 +230,24 @@ def singlestate_cutout(cls, memlet = edge.data if memlet.data in cutout.arrays: continue - new_desc = sdfg.arrays[memlet.data].clone() - cutout.add_datadesc(memlet.data, new_desc) + dataname = memlet.data + if '.' in dataname: + # This is an access to a struct memeber, which typically happens for the memlets between an access node + # pointing to a struct (or view thereof), and a view pointing to the member. Assert that this is indeed + # the case (i.e., only one '.' is found in the name of the data being accessed), and if so, clone the + # struct (or struct view) data descriptor instad. + parts = dataname.split('.') + if len(parts) == 2: + dataname = parts[0] + else: + raise RuntimeError('Attempting to add invalid multi-nested data ' + memlet.data + ' to a cutout') + new_desc = sdfg.arrays[dataname].clone() + cutout.add_datadesc(dataname, new_desc) # Add a single state with the extended subgraph new_state = cutout.add_state(state.label, is_start_state=True) + if preserve_guids: + new_state.guid = state.guid in_translation = dict() out_translation = dict() for e in sg_edges: @@ -322,6 +352,7 @@ def singlestate_cutout(cls, def multistate_cutout(cls, *states: SDFGState, make_side_effects_global: bool = True, + preserve_guids: bool = False, override_start_block: Optional[ControlFlowBlock] = None) -> Union['SDFGCutout', SDFG]: """ Cut out a multi-state subgraph from an SDFG to run separately for localized testing or optimization. @@ -337,12 +368,19 @@ def multistate_cutout(cls, :param make_side_effects_global: If True, all transient data containers which are read inside the cutout but may be written to _before_ the cutout, or any data containers which are written to inside the cutout but may be read _after_ the cutout, are made global. + :param preserve_guids: If True, ensures that the GUIDs of graph elements contained in the cutout remain + identical to the ones in their original graph. If False, new GUIDs will be generated. + False by default - if make_copy is False, this has no effect by extension. :param override_start_block: If set, explicitly force a given control flow block to be the start block. If left None (default), the start block is automatically determined based on domination relationships in the original graph. :return: The created SDFGCutout or the original SDFG where no smaller cutout could be obtained. """ - create_element = copy.deepcopy + def create_element(x: Union[ControlFlowBlock, InterstateEdge]) -> Union[ControlFlowBlock, InterstateEdge]: + ret = copy.deepcopy(x) + if preserve_guids: + ret.guid = x.guid + return ret # Check that all states are inside the same SDFG. sdfg = list(states)[0].parent diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 4ae91d5ea0..d29b1a22e4 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -55,6 +55,16 @@ def __str__(self): else: return type(self).__name__ + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == 'guid': # Skip ID + continue + setattr(result, k, dcpy(v, memo)) + return result + def validate(self, sdfg, state): pass diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index cb8a7d5c2d..19d2a47295 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -205,6 +205,16 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__('_uncond', None) return super().__setattr__(name, value) + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == 'guid': # Skip ID + continue + setattr(result, k, copy.deepcopy(v, memo)) + return result + @staticmethod def _convert_assignment(assignment) -> str: if isinstance(assignment, ast.AST):