Skip to content

Commit

Permalink
Various Cutout Fixes (#1662)
Browse files Browse the repository at this point in the history
- [x] Fix cutouts w.r.t. the use of UIDs, allowing them to be preserved
or re-generated depending on an input parameter
- [x] Fix singlestate cutout extraction when memlets access struct
members.
  • Loading branch information
phschaad authored Nov 4, 2024
1 parent 636811d commit b27024b
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 9 deletions.
56 changes: 47 additions & 9 deletions dace/sdfg/analysis/cutout.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def from_json(cls, json_obj, context=None):
def from_transformation(
cls, sdfg: SDFG, transformation: Union[PatternTransformation, SubgraphTransformation],
make_side_effects_global = True, use_alibi_nodes: bool = True, reduce_input_config = True,
symbols_map: Optional[Dict[str, Any]] = None
symbols_map: Optional[Dict[str, Any]] = None, preserve_guids: bool = False
) -> Union['SDFGCutout', SDFG]:
"""
Create a cutout from a transformation's set of affected graph elements.
Expand All @@ -130,6 +130,9 @@ def from_transformation(
:param reduce_input_config: Whether to reduce the input configuration where possible in singlestate cutouts.
:param symbols_map: A mapping of symbols to values to use for the cutout. Optional, only used when reducing the
input configuration.
:param preserve_guids: If True, ensures that the GUIDs of graph elements contained in the cutout remain
identical to the ones in their original graph. If False, new GUIDs will be generated.
False by default.
:return: The cutout.
"""
affected_nodes = _transformation_determine_affected_nodes(sdfg, transformation)
Expand All @@ -150,11 +153,12 @@ def from_transformation(
state = target_sdfg.node(transformation.state_id)
cutout = cls.singlestate_cutout(state, *affected_nodes, make_side_effects_global=make_side_effects_global,
use_alibi_nodes=use_alibi_nodes, reduce_input_config=reduce_input_config,
symbols_map=symbols_map)
symbols_map=symbols_map, preserve_guids=preserve_guids)
cutout.translate_transformation_into(transformation)
return cutout
elif isinstance(transformation, MultiStateTransformation):
cutout = cls.multistate_cutout(*affected_nodes, make_side_effects_global=make_side_effects_global)
cutout = cls.multistate_cutout(*affected_nodes, make_side_effects_global=make_side_effects_global,
preserve_guids=preserve_guids)
# If the cutout is an SDFG, there's no need to translate the transformation.
if isinstance(cutout, SDFGCutout):
cutout.translate_transformation_into(transformation)
Expand All @@ -169,14 +173,15 @@ def singlestate_cutout(cls,
make_side_effects_global: bool = True,
use_alibi_nodes: bool = True,
reduce_input_config: bool = False,
symbols_map: Optional[Dict[str, Any]] = None) -> 'SDFGCutout':
symbols_map: Optional[Dict[str, Any]] = None,
preserve_guids: bool = False) -> 'SDFGCutout':
"""
Cut out a subgraph of a state from an SDFG to run separately for localized testing or optimization.
The subgraph defined by the list of nodes will be extended to include access nodes of data containers necessary
to run the graph separately. In addition, all transient data containers that may contain data when the cutout is
executed are made global, as well as any transient data containers which are written to inside the cutout but
may be read after the cutout.
:param state: The SDFG state in which the subgraph resides.
:param nodes: The nodes in the subgraph to cut out.
:param make_copy: If True, deep-copies every SDFG element in the copy. Otherwise, original references are kept.
Expand All @@ -188,17 +193,29 @@ def singlestate_cutout(cls,
:param reduce_input_config: Whether to reduce the input configuration where possible in singlestate cutouts.
:param symbols_map: A mapping of symbols to values to use for the cutout. Optional, only used when reducing the
input configuration.
:param preserve_guids: If True, ensures that the GUIDs of graph elements contained in the cutout remain
identical to the ones in their original graph. If False, new GUIDs will be generated.
False by default - if make_copy is False, this has no effect by extension.
:return: The created SDFGCutout.
"""
if reduce_input_config:
nodes = _reduce_in_configuration(state, nodes, use_alibi_nodes, symbols_map)
create_element = copy.deepcopy if make_copy else (lambda x: x)

def clone_f(x: Union[Memlet, InterstateEdge, nd.Node, ControlFlowBlock]):
ret = copy.deepcopy(x)
if preserve_guids:
ret.guid = x.guid
return ret

create_element = clone_f if make_copy else (lambda x: x)
sdfg = state.parent
subgraph: StateSubgraphView = StateSubgraphView(state, nodes)
subgraph = _extend_subgraph_with_access_nodes(state, subgraph, use_alibi_nodes)

# Make a new SDFG with the included constants, used symbols, and data containers.
cutout = SDFGCutout(sdfg.name + '_cutout', sdfg.constants_prop)
if preserve_guids:
cutout.guid = sdfg.guid
cutout._base_sdfg = sdfg
defined_syms = subgraph.defined_symbols()
freesyms = subgraph.free_symbols
Expand All @@ -213,11 +230,24 @@ def singlestate_cutout(cls,
memlet = edge.data
if memlet.data in cutout.arrays:
continue
new_desc = sdfg.arrays[memlet.data].clone()
cutout.add_datadesc(memlet.data, new_desc)
dataname = memlet.data
if '.' in dataname:
# This is an access to a struct memeber, which typically happens for the memlets between an access node
# pointing to a struct (or view thereof), and a view pointing to the member. Assert that this is indeed
# the case (i.e., only one '.' is found in the name of the data being accessed), and if so, clone the
# struct (or struct view) data descriptor instad.
parts = dataname.split('.')
if len(parts) == 2:
dataname = parts[0]
else:
raise RuntimeError('Attempting to add invalid multi-nested data ' + memlet.data + ' to a cutout')
new_desc = sdfg.arrays[dataname].clone()
cutout.add_datadesc(dataname, new_desc)

# Add a single state with the extended subgraph
new_state = cutout.add_state(state.label, is_start_state=True)
if preserve_guids:
new_state.guid = state.guid
in_translation = dict()
out_translation = dict()
for e in sg_edges:
Expand Down Expand Up @@ -322,6 +352,7 @@ def singlestate_cutout(cls,
def multistate_cutout(cls,
*states: SDFGState,
make_side_effects_global: bool = True,
preserve_guids: bool = False,
override_start_block: Optional[ControlFlowBlock] = None) -> Union['SDFGCutout', SDFG]:
"""
Cut out a multi-state subgraph from an SDFG to run separately for localized testing or optimization.
Expand All @@ -337,12 +368,19 @@ def multistate_cutout(cls,
:param make_side_effects_global: If True, all transient data containers which are read inside the cutout but may
be written to _before_ the cutout, or any data containers which are written to
inside the cutout but may be read _after_ the cutout, are made global.
:param preserve_guids: If True, ensures that the GUIDs of graph elements contained in the cutout remain
identical to the ones in their original graph. If False, new GUIDs will be generated.
False by default - if make_copy is False, this has no effect by extension.
:param override_start_block: If set, explicitly force a given control flow block to be the start block. If left
None (default), the start block is automatically determined based on domination
relationships in the original graph.
:return: The created SDFGCutout or the original SDFG where no smaller cutout could be obtained.
"""
create_element = copy.deepcopy
def create_element(x: Union[ControlFlowBlock, InterstateEdge]) -> Union[ControlFlowBlock, InterstateEdge]:
ret = copy.deepcopy(x)
if preserve_guids:
ret.guid = x.guid
return ret

# Check that all states are inside the same SDFG.
sdfg = list(states)[0].parent
Expand Down
10 changes: 10 additions & 0 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def __str__(self):
else:
return type(self).__name__

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == 'guid': # Skip ID
continue
setattr(result, k, dcpy(v, memo))
return result

def validate(self, sdfg, state):
pass

Expand Down
10 changes: 10 additions & 0 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ def __setattr__(self, name: str, value: Any) -> None:
super().__setattr__('_uncond', None)
return super().__setattr__(name, value)

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == 'guid': # Skip ID
continue
setattr(result, k, copy.deepcopy(v, memo))
return result

@staticmethod
def _convert_assignment(assignment) -> str:
if isinstance(assignment, ast.AST):
Expand Down

0 comments on commit b27024b

Please sign in to comment.