diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py
index dfdbbb392b..b1eb42fe60 100644
--- a/dace/codegen/targets/framecode.py
+++ b/dace/codegen/targets/framecode.py
@@ -886,8 +886,8 @@ def generate_code(self,
             # NOTE: NestedSDFGs frequently contain tautologies in their symbol mapping, e.g., `'i': i`. Do not
             # redefine the symbols in such cases.
-            if (not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping.keys()
-                    and str(sdfg.parent_nsdfg_node.symbol_mapping[isvarName] == isvarName)):
+            if (not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping
+                    and str(sdfg.parent_nsdfg_node.symbol_mapping[isvarName]) == str(isvarName)):
                 continue
             isvar = data.Scalar(isvarType)
             callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg)
diff --git a/dace/data.py b/dace/data.py
index 3b571e6537..0a9858458b 100644
--- a/dace/data.py
+++ b/dace/data.py
@@ -243,6 +243,10 @@ def __hash__(self):
     def as_arg(self, with_types=True, for_call=False, name=None):
         """Returns a string for a C++ function signature (e.g., `int *A`). """
         raise NotImplementedError
+
+    def as_python_arg(self, with_types=True, for_call=False, name=None):
+        """Returns a string for a Data-Centric Python function signature (e.g., `A: dace.int32[M]`). """
+        raise NotImplementedError
 
     def used_symbols(self, all_symbols: bool) -> Set[symbolic.SymbolicType]:
         """
@@ -583,6 +587,13 @@ def as_arg(self, with_types=True, for_call=False, name=None):
         if not with_types or for_call:
             return name
         return self.dtype.as_arg(name)
+
+    def as_python_arg(self, with_types=True, for_call=False, name=None):
+        if self.storage is dtypes.StorageType.GPU_Global:
+            return Array(self.dtype, [1]).as_python_arg(with_types, for_call, name)
+        if not with_types or for_call:
+            return name
+        return f"{name}: {dtypes.TYPECLASS_TO_STRING[self.dtype].replace('::', '.')}"
 
     def sizes(self):
         return None
@@ -849,6 +860,13 @@ def as_arg(self, with_types=True, for_call=False, name=None):
         if self.may_alias:
             return str(self.dtype.ctype) + ' *' + arrname
         return str(self.dtype.ctype) + ' * __restrict__ ' + arrname
+
+    def as_python_arg(self, with_types=True, for_call=False, name=None):
+        arrname = name
+
+        if not with_types or for_call:
+            return arrname
+        return f"{arrname}: {dtypes.TYPECLASS_TO_STRING[self.dtype].replace('::', '.')}{list(self.shape)}"
 
     def sizes(self):
         return [d.name if isinstance(d, symbolic.symbol) else str(d) for d in self.shape]
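
Note: a minimal sketch of what the new `as_python_arg` methods produce (descriptor values and printed output are illustrative, assuming this patch is applied):

    import dace
    from dace import data

    M = dace.symbol('M')
    arr = data.Array(dace.float32, [M, 32])
    print(arr.as_python_arg(name='A'))   # e.g., "A: dace.float32[M, 32]"

    scal = data.Scalar(dace.int32)
    print(scal.as_python_arg(name='b'))  # e.g., "b: dace.int32"
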
""" # Get memlet range ndslice = [(0, s - 1, 1) for s in array.shape] diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index b5d27e14f4..0329e31641 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3177,6 +3177,12 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): if (not is_return and isinstance(target, ast.Name) and true_name and not op and not isinstance(true_array, data.Scalar) and not (true_array.shape == (1, ))): + if true_name in self.views: + if result in self.sdfg.arrays and self.views[true_name] == ( + result, Memlet.from_array(result, self.sdfg.arrays[result])): + continue + else: + raise DaceSyntaxError(self, target, 'Cannot reassign View "{}"'.format(name)) if (isinstance(result, str) and result in self.sdfg.arrays and self.sdfg.arrays[result].is_equivalent(true_array)): # Skip error if the arrays are defined exactly in the same way. diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 9643d51c1f..eace0c8336 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -617,9 +617,10 @@ def _elementwise(pv: 'ProgramVisitor', def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: dace.typeclass = None): """ Implements a simple call of the form `out = func(inp)`. """ + create_input = True if isinstance(inpname, (list, tuple)): # TODO investigate this inpname = inpname[0] - if not isinstance(inpname, str): + if not isinstance(inpname, str) and not symbolic.issymbolic(inpname): # Constant parameter cst = inpname inparr = data.create_datadescriptor(cst) @@ -627,6 +628,10 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: inparr.transient = True sdfg.add_constant(inpname, cst, inparr) sdfg.add_datadesc(inpname, inparr) + elif symbolic.issymbolic(inpname): + dtype = symbolic.symtype(inpname) + inparr = data.Scalar(dtype) + create_input = False else: inparr = sdfg.arrays[inpname] @@ -636,10 +641,17 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: outarr.dtype = restype num_elements = data._prod(inparr.shape) if num_elements == 1: - inp = state.add_read(inpname) + if create_input: + inp = state.add_read(inpname) + inconn_name = '__inp' + else: + inconn_name = symbolic.symstr(inpname) + out = state.add_write(outname) - tasklet = state.add_tasklet(func, {'__inp'}, {'__out'}, '__out = {f}(__inp)'.format(f=func)) - state.add_edge(inp, None, tasklet, '__inp', Memlet.from_array(inpname, inparr)) + tasklet = state.add_tasklet(func, {'__inp'} if create_input else {}, {'__out'}, + f'__out = {func}({inconn_name})') + if create_input: + state.add_edge(inp, None, tasklet, '__inp', Memlet.from_array(inpname, inparr)) state.add_edge(tasklet, '__out', out, None, Memlet.from_array(outname, outarr)) else: state.add_mapped_tasklet( @@ -2158,8 +2170,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op res = symbolic.equal(arr1.shape[-1], arr2.shape[-2]) if res is None: - warnings.warn(f'Last mode of first tesnsor/matrix {arr1.shape[-1]} and second-last mode of ' - f'second tensor/matrix {arr2.shape[-2]} may not match', UserWarning) + warnings.warn( + f'Last mode of first tesnsor/matrix {arr1.shape[-1]} and second-last mode of ' + f'second tensor/matrix {arr2.shape[-2]} may not match', UserWarning) elif not res: raise SyntaxError('Matrix dimension mismatch %s != %s' % (arr1.shape[-1], arr2.shape[-2])) @@ -2176,8 
@@ -2176,8 +2189,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op
 
         res = symbolic.equal(arr1.shape[-1], arr2.shape[0])
         if res is None:
-            warnings.warn(f'Number of matrix columns {arr1.shape[-1]} and length of vector {arr2.shape[0]} '
-                          f'may not match', UserWarning)
+            warnings.warn(
+                f'Number of matrix columns {arr1.shape[-1]} and length of vector {arr2.shape[0]} '
+                f'may not match', UserWarning)
         elif not res:
             raise SyntaxError("Number of matrix columns {} must match"
                               "size of vector {}.".format(arr1.shape[1], arr2.shape[0]))
@@ -2188,8 +2202,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op
 
         res = symbolic.equal(arr1.shape[0], arr2.shape[0])
         if res is None:
-            warnings.warn(f'Length of vector {arr1.shape[0]} and number of matrix rows {arr2.shape[0]} '
-                          f'may not match', UserWarning)
+            warnings.warn(
+                f'Length of vector {arr1.shape[0]} and number of matrix rows {arr2.shape[0]} '
+                f'may not match', UserWarning)
         elif not res:
             raise SyntaxError("Size of vector {} must match number of matrix "
                               "rows {} must match".format(arr1.shape[0], arr2.shape[0]))
@@ -2200,8 +2215,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op
 
         res = symbolic.equal(arr1.shape[0], arr2.shape[0])
        if res is None:
-            warnings.warn(f'Length of first vector {arr1.shape[0]} and length of second vector {arr2.shape[0]} '
-                          f'may not match', UserWarning)
+            warnings.warn(
+                f'Length of first vector {arr1.shape[0]} and length of second vector {arr2.shape[0]} '
+                f'may not match', UserWarning)
         elif not res:
             raise SyntaxError("Vectors in vector product must have same size: "
                               "{} vs. {}".format(arr1.shape[0], arr2.shape[0]))
@@ -4401,10 +4417,13 @@ def _datatype_converter(sdfg: SDFG, state: SDFGState, arg: UfuncInput, dtype: dt
 
     # Set tasklet parameters
     impl = {
-        'name': "_convert_to_{}_".format(dtype.to_string()),
+        'name':
+        "_convert_to_{}_".format(dtype.to_string()),
         'inputs': ['__inp'],
         'outputs': ['__out'],
-        'code': "__out = dace.{}(__inp)".format(dtype.to_string())
+        'code':
+        "__out = {}(__inp)".format(f"dace.{dtype.to_string()}" if dtype not in (dace.bool,
+                                                                                dace.bool_) else dtype.to_string())
     }
     if dtype in (dace.bool, dace.bool_):
         impl['code'] = "__out = dace.bool_(__inp)"
diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py
index f0767a0473..83d07ded29 100644
--- a/dace/libraries/blas/nodes/matmul.py
+++ b/dace/libraries/blas/nodes/matmul.py
@@ -217,5 +217,7 @@ class MatMul(dace.sdfg.nodes.LibraryNode):
                          default=0,
                          desc="A scalar which will be multiplied with C before adding C")
 
-    def __init__(self, name, location=None):
+    def __init__(self, name, location=None, alpha=1, beta=0):
+        self.alpha = alpha
+        self.beta = beta
         super().__init__(name, location=location, inputs={"_a", "_b"}, outputs={"_c"})
diff --git a/dace/libraries/standard/nodes/reduce.py b/dace/libraries/standard/nodes/reduce.py
index 0f76c7e252..dd026ea62c 100644
--- a/dace/libraries/standard/nodes/reduce.py
+++ b/dace/libraries/standard/nodes/reduce.py
@@ -1562,13 +1562,14 @@ class Reduce(dace.sdfg.nodes.LibraryNode):
     identity = Property(allow_none=True)
 
     def __init__(self,
+                 name,
                  wcr='lambda a, b: a',
                  axes=None,
                  identity=None,
                  schedule=dtypes.ScheduleType.Default,
                  debuginfo=None,
                  **kwargs):
-        super().__init__(name='Reduce', **kwargs)
+        super().__init__(name=name, **kwargs)
         self.wcr = wcr
         self.axes = axes
         self.identity = identity
@@ -1577,7 +1578,7 @@ def __init__(self,
 
     @staticmethod
     def from_json(json_obj, context=None):
-        ret = Reduce("lambda a, b: a", None)
+        ret = Reduce('reduce', 'lambda a, b: a', None)
         dace.serialize.set_properties_from_json(ret, json_obj, context=context)
         return ret
diff --git a/dace/properties.py b/dace/properties.py
index 61e569341f..44f8b4fbcc 100644
--- a/dace/properties.py
+++ b/dace/properties.py
@@ -1001,8 +1001,11 @@ def get_free_symbols(self, defined_syms: Set[str] = None) -> Set[str]:
         if self.language == dace.dtypes.Language.Python:
             visitor = TaskletFreeSymbolVisitor(defined_syms)
             if self.code:
-                for stmt in self.code:
-                    visitor.visit(stmt)
+                if isinstance(self.code, list):
+                    for stmt in self.code:
+                        visitor.visit(stmt)
+                else:
+                    visitor.visit(self.code)
             return visitor.free_symbols
 
         return set()
diff --git a/dace/sdfg/analysis/schedule_tree/__init__.py b/dace/sdfg/analysis/schedule_tree/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/dace/sdfg/analysis/schedule_tree/passes.py b/dace/sdfg/analysis/schedule_tree/passes.py
new file mode 100644
index 0000000000..cc33245875
--- /dev/null
+++ b/dace/sdfg/analysis/schedule_tree/passes.py
@@ -0,0 +1,60 @@
+# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
+"""
+Assortment of passes for schedule trees.
+"""
+
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
+from typing import Set
+
+
+def remove_unused_and_duplicate_labels(stree: tn.ScheduleTreeScope):
+    """
+    Removes unused and duplicate labels from the schedule tree.
+
+    :param stree: The schedule tree to remove labels from.
+    """
+
+    class FindGotos(tn.ScheduleNodeVisitor):
+
+        def __init__(self):
+            self.gotos: Set[str] = set()
+
+        def visit_GotoNode(self, node: tn.GotoNode):
+            if node.target is not None:
+                self.gotos.add(node.target)
+
+    class RemoveLabels(tn.ScheduleNodeTransformer):
+
+        def __init__(self, labels_to_keep: Set[str]) -> None:
+            self.labels_to_keep = labels_to_keep
+            self.labels_seen = set()
+
+        def visit_StateLabel(self, node: tn.StateLabel):
+            if node.state.name not in self.labels_to_keep:
+                return None
+            if node.state.name in self.labels_seen:
+                return None
+            self.labels_seen.add(node.state.name)
+            return node
+
+    fg = FindGotos()
+    fg.visit(stree)
+    return RemoveLabels(fg.gotos).visit(stree)
+
+
+def remove_empty_scopes(stree: tn.ScheduleTreeScope):
+    """
+    Removes empty scopes from the schedule tree.
+
+    :warning: This pass is not safe to use for for-loops, as it will remove indices that may be used after the loop.
+    """
+
+    class RemoveEmptyScopes(tn.ScheduleNodeTransformer):
+
+        def visit_scope(self, node: tn.ScheduleTreeScope):
+            if len(node.children) == 0:
+                return None
+
+            return self.generic_visit(node)
+
+    return RemoveEmptyScopes().visit(stree)
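
Note: a minimal sketch of the label-removal pass above (the SDFG and state names are made up for illustration):

    import dace
    from dace.sdfg.analysis.schedule_tree import passes as stpasses
    from dace.sdfg.analysis.schedule_tree import treenodes as tn

    sdfg = dace.SDFG('example')
    s0 = sdfg.add_state('s0')

    # 's0' has no matching goto target, so its label is removed by the pass
    stree = tn.ScheduleTreeScope(children=[tn.StateLabel(state=s0), tn.GotoNode(target=None)])
    stree = stpasses.remove_unused_and_duplicate_labels(stree)
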
diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py
new file mode 100644
index 0000000000..917f748cb8
--- /dev/null
+++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py
@@ -0,0 +1,743 @@
+# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
+from collections import defaultdict
+import copy
+from typing import Dict, List, Set, Tuple
+import dace
+from dace import data, subsets, symbolic
+from dace.codegen import control_flow as cf
+from dace.sdfg.sdfg import InterstateEdge, SDFG
+from dace.sdfg.state import SDFGState
+from dace.sdfg import utils as sdutil, graph as gr, nodes as nd
+from dace.sdfg.replace import replace_datadesc_names
+from dace.frontend.python.astutils import negate_expr
+from dace.sdfg.analysis.schedule_tree import treenodes as tn, passes as stpasses
+from dace.transformation.passes.analysis import StateReachability
+from dace.transformation.helpers import unsqueeze_memlet
+from dace.properties import CodeBlock
+from dace.memlet import Memlet
+
+import networkx as nx
+import time
+import sys
+
+NODE_TO_SCOPE_TYPE = {
+    dace.nodes.MapEntry: tn.MapScope,
+    dace.nodes.ConsumeEntry: tn.ConsumeScope,
+    dace.nodes.PipelineEntry: tn.PipelineScope,
+}
+
+
+def dealias_sdfg(sdfg: SDFG):
+    """
+    Renames all data containers in an SDFG tree (i.e., nested SDFGs) to use the same data descriptors
+    as the top-level SDFG. This function takes care of offsetting memlets and internal
+    uses of arrays such that there is one naming system, and no aliasing of managed memory.
+
+    This function operates in-place.
+
+    :param sdfg: The SDFG to operate on.
+    """
+    for nsdfg in sdfg.all_sdfgs_recursive():
+
+        if not nsdfg.parent:
+            continue
+
+        replacements: Dict[str, str] = {}
+        inv_replacements: Dict[str, List[str]] = {}
+        parent_edges: Dict[str, Memlet] = {}
+        to_unsqueeze: Set[str] = set()
+
+        parent_sdfg = nsdfg.parent_sdfg
+        parent_state = nsdfg.parent
+        parent_node = nsdfg.parent_nsdfg_node
+
+        for name, desc in nsdfg.arrays.items():
+            if desc.transient:
+                continue
+            for edge in parent_state.edges_by_connector(parent_node, name):
+                parent_name = edge.data.data
+                assert parent_name in parent_sdfg.arrays
+                if name != parent_name:
+                    replacements[name] = parent_name
+                    parent_edges[name] = edge
+                    if parent_name in inv_replacements:
+                        inv_replacements[parent_name].append(name)
+                        to_unsqueeze.add(parent_name)
+                    else:
+                        inv_replacements[parent_name] = [name]
+                break
+
+        if to_unsqueeze:
+            for parent_name in to_unsqueeze:
+                parent_arr = parent_sdfg.arrays[parent_name]
+                if isinstance(parent_arr, data.View):
+                    parent_arr = data.Array(parent_arr.dtype, parent_arr.shape, parent_arr.transient,
+                                            parent_arr.allow_conflicts, parent_arr.storage, parent_arr.location,
+                                            parent_arr.strides, parent_arr.offset, parent_arr.may_alias,
+                                            parent_arr.lifetime, parent_arr.alignment, parent_arr.debuginfo,
+                                            parent_arr.total_size, parent_arr.start_offset, parent_arr.optional,
+                                            parent_arr.pool)
+                elif isinstance(parent_arr, data.StructureView):
+                    parent_arr = data.Structure(parent_arr.members, parent_arr.name, parent_arr.transient,
+                                                parent_arr.storage, parent_arr.location, parent_arr.lifetime,
+                                                parent_arr.debuginfo)
+                child_names = inv_replacements[parent_name]
+                for name in child_names:
+                    child_arr = copy.deepcopy(parent_arr)
+                    child_arr.transient = False
+                    nsdfg.arrays[name] = child_arr
+                for state in nsdfg.states():
+                    for e in state.edges():
+                        if not state.is_leaf_memlet(e):
+                            continue
+
+                        mpath = state.memlet_path(e)
+                        src, dst = mpath[0].src, mpath[-1].dst
+
+                        # We need to take directionality of the memlet into account and unsqueeze either to source or
+                        # destination subset
+                        if isinstance(src, nd.AccessNode) and src.data in child_names:
+                            src_data = src.data
+                            new_src_memlet = unsqueeze_memlet(e.data, parent_edges[src.data].data, use_src_subset=True)
+                        else:
+                            src_data = None
+                            new_src_memlet = None
+                        # We need to take directionality of the memlet into account
+                        if isinstance(dst, nd.AccessNode) and dst.data in child_names:
+                            dst_data = dst.data
+                            new_dst_memlet = unsqueeze_memlet(e.data, parent_edges[dst.data].data, use_dst_subset=True)
+                        else:
+                            dst_data = None
+                            new_dst_memlet = None
+
+                        if new_src_memlet is not None:
+                            e.data.src_subset = new_src_memlet.subset
+                        if new_dst_memlet is not None:
+                            e.data.dst_subset = new_dst_memlet.subset
+                        if e.data.data == src_data:
+                            e.data.data = new_src_memlet.data
+                        elif e.data.data == dst_data:
+                            e.data.data = new_dst_memlet.data
+
+                for e in nsdfg.edges():
+                    repl_dict = dict()
+                    syms = e.data.read_symbols()
+                    for memlet in e.data.get_read_memlets(nsdfg.arrays):
+                        if memlet.data in child_names:
+                            repl_dict[str(memlet)] = unsqueeze_memlet(memlet, parent_edges[memlet.data].data)
+                        if memlet.data in syms:
+                            syms.remove(memlet.data)
+                    for s in syms:
+                        if s in parent_edges:
+                            repl_dict[s] = str(parent_edges[s].data)
+                    e.data.replace_dict(repl_dict)
+
+                for name in child_names:
+                    edge = parent_edges[name]
+                    for e in parent_state.memlet_tree(edge):
+                        if e.data.data == parent_name:
+                            e.data.subset = subsets.Range.from_array(parent_arr)
+                        else:
+                            e.data.other_subset = subsets.Range.from_array(parent_arr)
+
+        if replacements:
+            symbolic.safe_replace(replacements, lambda d: replace_datadesc_names(nsdfg, d), value_as_string=True)
+            parent_node.in_connectors = {
+                replacements[c] if c in replacements else c: t
+                for c, t in parent_node.in_connectors.items()
+            }
+            parent_node.out_connectors = {
+                replacements[c] if c in replacements else c: t
+                for c, t in parent_node.out_connectors.items()
+            }
+            for e in parent_state.all_edges(parent_node):
+                if e.src_conn in replacements:
+                    e._src_conn = replacements[e.src_conn]
+                elif e.dst_conn in replacements:
+                    e._dst_conn = replacements[e.dst_conn]
+
+
+def normalize_memlet(sdfg: SDFG, state: SDFGState, original: gr.MultiConnectorEdge[Memlet], data: str) -> Memlet:
+    """
+    Normalizes a memlet to a given data descriptor.
+
+    :param sdfg: The SDFG.
+    :param state: The state.
+    :param original: The original memlet.
+    :param data: The data descriptor.
+    :return: A new memlet.
+    """
+    # Shallow copy edge
+    edge = gr.MultiConnectorEdge(original.src, original.src_conn, original.dst, original.dst_conn,
+                                 copy.deepcopy(original.data), original.key)
+    edge.data.try_initialize(sdfg, state, edge)
+
+    if '.' in edge.data.data and edge.data.data.startswith(data + '.'):
+        return edge.data
+    if edge.data.data == data:
+        return edge.data
+
+    memlet = edge.data
+    if memlet._is_data_src:
+        new_subset, new_osubset = memlet.get_dst_subset(edge, state), memlet.get_src_subset(edge, state)
+    else:
+        new_subset, new_osubset = memlet.get_src_subset(edge, state), memlet.get_dst_subset(edge, state)
+
+    memlet.data = data
+    memlet.subset = new_subset
+    memlet.other_subset = new_osubset
+    memlet._is_data_src = True
+    return memlet
+
+
+def replace_memlets(sdfg: SDFG, input_mapping: Dict[str, Memlet], output_mapping: Dict[str, Memlet]):
+    """
+    Replaces all uses of data containers in memlets and interstate edges in an SDFG.
+
+    :param sdfg: The SDFG.
+    :param input_mapping: A mapping from internal data descriptor names to external input memlets.
+    :param output_mapping: A mapping from internal data descriptor names to external output memlets.
+ """ + for state in sdfg.states(): + for e in state.edges(): + mpath = state.memlet_path(e) + src = mpath[0].src + dst = mpath[-1].dst + memlet = e.data + if isinstance(src, dace.nodes.AccessNode) and src.data in input_mapping: + src_data = src.data + src_memlet = unsqueeze_memlet(memlet, input_mapping[src.data], use_src_subset=True) + else: + src_data = None + src_memlet = None + if isinstance(dst, dace.nodes.AccessNode) and dst.data in output_mapping: + dst_data = dst.data + dst_memlet = unsqueeze_memlet(memlet, output_mapping[dst.data], use_dst_subset=True) + else: + dst_data = None + dst_memlet = None + + # Other cases (code->code) + if src_data is None and dst_data is None: + if e.data.data in input_mapping: + memlet = unsqueeze_memlet(memlet, input_mapping[e.data.data]) + elif e.data.data in output_mapping: + memlet = unsqueeze_memlet(memlet, output_mapping[e.data.data]) + e.data = memlet + else: + if src_memlet is not None: + memlet.src_subset = src_memlet.subset + if dst_memlet is not None: + memlet.dst_subset = dst_memlet.subset + if memlet.data == src_data: + memlet.data = src_memlet.data + elif memlet.data == dst_data: + memlet.data = dst_memlet.data + + for e in sdfg.edges(): + repl_dict = dict() + syms = e.data.read_symbols() + for memlet in e.data.get_read_memlets(sdfg.arrays): + if memlet.data in input_mapping or memlet.data in output_mapping: + # If array name is both in the input connectors and output connectors with different + # memlets, this is undefined behavior. Prefer output + if memlet.data in input_mapping: + mapping = input_mapping + if memlet.data in output_mapping: + mapping = output_mapping + + repl_dict[str(memlet)] = str(unsqueeze_memlet(memlet, mapping[memlet.data])) + if memlet.data in syms: + syms.remove(memlet.data) + for s in syms: + if s in input_mapping: + repl_dict[s] = str(input_mapping[s]) + + # Manual replacement with strings + # TODO(later): Would be MUCH better to use MemletReplacer / e.data.replace_dict(repl_dict, replace_keys=False) + for find, replace in repl_dict.items(): + for k, v in e.data.assignments.items(): + if find in v: + e.data.assignments[k] = v.replace(find, replace) + condstr = e.data.condition.as_string + if find in condstr: + e.data.condition.as_string = condstr.replace(find, replace) + + +def remove_name_collisions(sdfg: SDFG): + """ + Removes name collisions in nested SDFGs by renaming states, data containers, and symbols. + + :param sdfg: The SDFG. 
+ """ + state_names_seen = set() + identifiers_seen = set() + + for nsdfg in sdfg.all_sdfgs_recursive(): + # Rename duplicate states + for state in nsdfg.nodes(): + if state.label in state_names_seen: + state.set_label(data.find_new_name(state.label, state_names_seen)) + state_names_seen.add(state.label) + + replacements: Dict[str, str] = {} + parent_node = nsdfg.parent_nsdfg_node + + # Preserve top-level SDFG names + do_not_replace = False + if not parent_node: + do_not_replace = True + + # Rename duplicate data containers + for name, desc in nsdfg.arrays.items(): + if name in identifiers_seen: + if not desc.transient or do_not_replace: + continue + + new_name = data.find_new_name(name, identifiers_seen) + replacements[name] = new_name + name = new_name + identifiers_seen.add(name) + + # Rename duplicate top-level symbols + for name in nsdfg.get_all_toplevel_symbols(): + # Will already be renamed during conversion + if parent_node is not None and name in parent_node.symbol_mapping: + continue + + if name in identifiers_seen and not do_not_replace: + new_name = data.find_new_name(name, identifiers_seen) + replacements[name] = new_name + name = new_name + identifiers_seen.add(name) + + # Rename duplicate constants + for name in nsdfg.constants_prop.keys(): + if name in identifiers_seen and not do_not_replace: + new_name = data.find_new_name(name, identifiers_seen) + replacements[name] = new_name + name = new_name + identifiers_seen.add(name) + + # If there is a name collision, replace all uses of the old names with the new names + if replacements: + nsdfg.replace_dict(replacements) + + +def _make_view_node(state: SDFGState, edge: gr.MultiConnectorEdge[Memlet], view_name: str, + viewed_name: str) -> tn.ViewNode: + """ + Helper function to create a view schedule tree node from a memlet edge. + """ + sdfg = state.parent + normalized = normalize_memlet(sdfg, state, edge, viewed_name) + return tn.ViewNode(target=view_name, + source=viewed_name, + memlet=normalized, + src_desc=sdfg.arrays[viewed_name], + view_desc=sdfg.arrays[view_name]) + + +def replace_symbols_until_set(nsdfg: dace.nodes.NestedSDFG): + """ + Replaces symbol values in a nested SDFG until their value has been reset. This is used for matching symbol + namespaces between an SDFG and a nested SDFG. + """ + mapping = nsdfg.symbol_mapping + sdfg = nsdfg.sdfg + reachable_states = StateReachability().apply_pass(sdfg, {})[sdfg.sdfg_id] + redefined_symbols: Dict[SDFGState, Set[str]] = defaultdict(set) + + # Collect redefined symbols + for e in sdfg.edges(): + redefined = e.data.assignments.keys() + redefined_symbols[e.dst] |= redefined + for reachable in reachable_states[e.dst]: + redefined_symbols[reachable] |= redefined + + # Replace everything but the redefined symbols + for state in sdfg.nodes(): + per_state_mapping = {k: v for k, v in mapping.items() if k not in redefined_symbols[state]} + symbolic.safe_replace(per_state_mapping, state.replace_dict) + for e in sdfg.out_edges(state): + symbolic.safe_replace(per_state_mapping, lambda d: e.data.replace_dict(d, replace_keys=False)) + + +def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode]: + """ + Creates a dictionary mapping edges to their corresponding schedule tree nodes, if relevant. + This handles view edges, reference sets, and dynamic map inputs. + + :param state: The state. 
+ """ + result: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode] = {} + scope_to_edges: Dict[nd.EntryNode, List[gr.MultiConnectorEdge[Memlet]]] = defaultdict(list) + edges_to_ignore = set() + sdfg = state.parent + + for edge in state.edges(): + if edge in edges_to_ignore or edge in result: + continue + if edge.data.is_empty(): # Ignore empty memlets + edges_to_ignore.add(edge) + continue + + # Part of a memlet path - only consider innermost memlets + mtree = state.memlet_tree(edge) + all_edges = set(e for e in mtree) + leaves = set(mtree.leaves()) + edges_to_ignore.update(all_edges - leaves) + + # For every tree leaf, create a copy/view/reference set node as necessary + for e in leaves: + if e in edges_to_ignore or e in result: + continue + + # 1. Check for views + if isinstance(e.src, dace.nodes.AccessNode): + desc = e.src.desc(sdfg) + if isinstance(desc, (dace.data.View, dace.data.StructureView)): + vedge = sdutil.get_view_edge(state, e.src) + if e is vedge: + viewed_node = sdutil.get_view_node(state, e.src) + result[e] = _make_view_node(state, e, e.src.data, viewed_node.data) + scope = state.entry_node(e.dst if mtree.downwards else e.src) + scope_to_edges[scope].append(e) + continue + if isinstance(e.dst, dace.nodes.AccessNode): + desc = e.dst.desc(sdfg) + if isinstance(desc, (dace.data.View, dace.data.StructureView)): + vedge = sdutil.get_view_edge(state, e.dst) + if e is vedge: + viewed_node = sdutil.get_view_node(state, e.dst) + result[e] = _make_view_node(state, e, e.dst.data, viewed_node.data) + scope = state.entry_node(e.dst if mtree.downwards else e.src) + scope_to_edges[scope].append(e) + continue + + # 2. Check for reference sets + if isinstance(e.dst, dace.nodes.AccessNode) and e.dst_conn == 'set': + assert isinstance(e.dst.desc(sdfg), dace.data.Reference) + result[e] = tn.RefSetNode(target=e.dst.data, + memlet=e.data, + src_desc=sdfg.arrays[e.data.data], + ref_desc=sdfg.arrays[e.dst.data]) + scope = state.entry_node(e.dst if mtree.downwards else e.src) + scope_to_edges[scope].append(e) + continue + + # 3. Check for copies + # Get both ends of the memlet path + mpath = state.memlet_path(e) + src = mpath[0].src + dst = mpath[-1].dst + if not isinstance(src, dace.nodes.AccessNode): + continue + if not isinstance(dst, (dace.nodes.AccessNode, dace.nodes.EntryNode)): + continue + + # If the edge destination is the innermost node, it is a downward-pointing path + is_target_dst = e.dst is dst + + innermost_node = dst if is_target_dst else src + outermost_node = src if is_target_dst else dst + + # Normalize memlets to their innermost node, or source->destination if it is a same-scope edge + if e.src is src and e.dst is dst: + outermost_node = src + innermost_node = dst + + if isinstance(dst, dace.nodes.EntryNode): + # Special case: dynamic map range has no data + result[e] = tn.DynScopeCopyNode(target=e.dst_conn, memlet=e.data) + else: + target_name = innermost_node.data + new_memlet = normalize_memlet(sdfg, state, e, outermost_node.data) + result[e] = tn.CopyNode(target=target_name, memlet=new_memlet) + + scope = state.entry_node(e.dst if mtree.downwards else e.src) + scope_to_edges[scope].append(e) + + return result, scope_to_edges + + +def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: + """ + Use scope-aware topological sort to get nodes by scope and return the schedule tree of this state. + + :param state: The state. 
+    :return: A list of schedule tree nodes representing the state.
+    """
+    result: List[tn.ScheduleTreeNode] = []
+    sdfg = state.parent
+
+    edge_to_stree: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode]
+    scope_to_edges: Dict[nd.EntryNode, List[gr.MultiConnectorEdge[Memlet]]]
+    edge_to_stree, scope_to_edges = prepare_schedule_tree_edges(state)
+    edges_to_ignore = set()
+
+    # Handle all unscoped edges to generate output views
+    views = _generate_views_in_scope(scope_to_edges[None], edge_to_stree, sdfg, state)
+    result.extend(views)
+
+    scopes: List[List[tn.ScheduleTreeNode]] = []
+    for node in sdutil.scope_aware_topological_sort(state):
+        if isinstance(node, dace.nodes.EntryNode):
+            # Handle dynamic scope inputs
+            for e in state.in_edges(node):
+                if e in edges_to_ignore:
+                    continue
+                if e in edge_to_stree:
+                    result.append(edge_to_stree[e])
+                    edges_to_ignore.add(e)
+
+            # Handle all scoped edges to generate (views)
+            views = _generate_views_in_scope(scope_to_edges[node], edge_to_stree, sdfg, state)
+            result.extend(views)
+
+            # Create scope node and add to stack
+            scopes.append(result)
+            subnodes = []
+            result.append(NODE_TO_SCOPE_TYPE[type(node)](node=node, children=subnodes))
+            result = subnodes
+        elif isinstance(node, dace.nodes.ExitNode):
+            result = scopes.pop()
+        elif isinstance(node, dace.nodes.NestedSDFG):
+            nested_array_mapping_input = {}
+            nested_array_mapping_output = {}
+            generated_nviews = set()
+
+            # Replace symbols and memlets in nested SDFGs to match the namespace of the parent SDFG
+            replace_symbols_until_set(node)
+
+            # Create memlets for nested SDFG mapping, or nview schedule nodes if slice cannot be determined
+            for e in state.all_edges(node):
+                conn = e.dst_conn if e.dst is node else e.src_conn
+                if e.data.is_empty() or not conn:
+                    continue
+                res = sdutil.map_view_to_array(node.sdfg.arrays[conn], sdfg.arrays[e.data.data], e.data.subset)
+                no_mapping = False
+                if res is None:
+                    no_mapping = True
+                else:
+                    mapping, expanded, squeezed = res
+                    if expanded:  # "newaxis" slices will be seen as views (for now)
+                        no_mapping = True
+                    else:
+                        if e.dst is node:
+                            nested_array_mapping_input[conn] = e.data
+                        else:
+                            nested_array_mapping_output[conn] = e.data
+
+                if no_mapping:  # Must use view (nview = nested SDFG view)
+                    if conn not in generated_nviews:
+                        result.append(
+                            tn.NView(target=conn,
+                                     source=e.data.data,
+                                     memlet=e.data,
+                                     src_desc=sdfg.arrays[e.data.data],
+                                     view_desc=node.sdfg.arrays[conn]))
+                        generated_nviews.add(conn)
+
+            replace_memlets(node.sdfg, nested_array_mapping_input, nested_array_mapping_output)
+
+            # Insert the nested SDFG flattened
+            nested_stree = as_schedule_tree(node.sdfg, in_place=True, toplevel=False)
+            result.extend(nested_stree.children)
+        elif isinstance(node, dace.nodes.Tasklet):
+            in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn}
+            out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn}
+            result.append(tn.TaskletNode(node=node, in_memlets=in_memlets, out_memlets=out_memlets))
+        elif isinstance(node, dace.nodes.LibraryNode):
+            # NOTE: LibraryNodes do not necessarily have connectors
+            if node.in_connectors:
+                in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn}
+            else:
+                in_memlets = set([e.data for e in state.in_edges(node)])
+            if node.out_connectors:
+                out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn}
+            else:
+                out_memlets = set([e.data for e in state.out_edges(node)])
+            result.append(tn.LibraryCall(node=node, in_memlets=in_memlets, out_memlets=out_memlets))
+        elif isinstance(node, dace.nodes.AccessNode):
+            # If one of the neighboring edges has a schedule tree node attached to it, use that
+            # (except for views, which were generated above)
+            for e in state.all_edges(node):
+                if e in edges_to_ignore:
+                    continue
+                if e in edge_to_stree:
+                    if isinstance(edge_to_stree[e], tn.ViewNode):
+                        continue
+                    result.append(edge_to_stree[e])
+                    edges_to_ignore.add(e)
+
+    assert len(scopes) == 0
+
+    return result
+
+
+def _generate_views_in_scope(edges: List[gr.MultiConnectorEdge[Memlet]],
+                             edge_to_stree: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode], sdfg: SDFG,
+                             state: SDFGState) -> List[tn.ScheduleTreeNode]:
+    """
+    Generates all view and reference set edges in the correct order. This function is intended to be used
+    at the beginning of a scope.
+    """
+    result: List[tn.ScheduleTreeNode] = []
+
+    # Make a dependency graph of all the views
+    g = nx.DiGraph()
+    node_to_stree = {}
+    for e in edges:
+        if e not in edge_to_stree:
+            continue
+        st = edge_to_stree[e]
+        if not isinstance(st, tn.ViewNode):
+            continue
+        g.add_edge(st.source, st.target)
+        node_to_stree[st.target] = st
+
+    # Traverse in order and deduplicate
+    already_generated = set()
+    for n in nx.topological_sort(g):
+        if n in node_to_stree and n not in already_generated:
+            result.append(node_to_stree[n])
+            already_generated.add(n)
+
+    return result
+
+
+def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeScope:
+    """
+    Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of
+    the SDFG.
+    Each node in the tree can either represent a single statement (symbol assignment, tasklet, copy, library node,
+    etc.) or a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes.
+
+    It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG, such as
+    erasing an empty if branch or merging two consecutive for-loops. The SDFG can then be reconstructed via the
+    ``from_schedule_tree`` function.
+
+    :param sdfg: The SDFG to convert.
+    :param in_place: If True, the SDFG is modified in-place. Otherwise, a copy is made. Note that the SDFG might not
+                     be usable after the conversion if ``in_place`` is True!
+    :param toplevel: Internal use only; True if and only if this is the top-level SDFG (the default).
+    :return: A schedule tree representing the given SDFG.
+ """ + from dace.transformation import helpers as xfh # Avoid import loop + + if not in_place: + sdfg = copy.deepcopy(sdfg) + + # Prepare SDFG for conversion + ############################# + + # Split edges with assignments and conditions + xfh.split_interstate_edges(sdfg) + + # Replace code->code edges with data<->code edges + xfh.replace_code_to_code_edges(sdfg) + + if toplevel: # Top-level SDFG preparation (only perform once) + dealias_sdfg(sdfg) + # Handle name collisions (in arrays, state labels, symbols) + remove_name_collisions(sdfg) + + ############################# + + # Create initial tree from CFG + cfg: cf.ControlFlow = cf.structured_control_flow_tree(sdfg, lambda _: '') + + # Traverse said tree (also into states) to create the schedule tree + def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.ScheduleTreeNode]: + result: List[tn.ScheduleTreeNode] = [] + if isinstance(node, cf.GeneralBlock): + subnodes: List[tn.ScheduleTreeNode] = [] + for n in node.elements: + subnodes.extend(totree(n, node)) + if not node.sequential: + # Nest in general block + result = [tn.GBlock(children=subnodes)] + else: + # Use the sub-nodes directly + result = subnodes + + elif isinstance(node, cf.SingleState): + result = state_schedule_tree(node.state) + + # Add interstate assignments unrelated to structured control flow + if parent is not None: + for e in sdfg.out_edges(node.state): + edge_body = [] + + if e not in parent.assignments_to_ignore: + for aname, aval in e.data.assignments.items(): + edge_body.append( + tn.AssignNode(name=aname, + value=CodeBlock(aval), + edge=InterstateEdge(assignments={aname: aval}))) + + if not parent.sequential: + if e not in parent.gotos_to_ignore: + edge_body.append(tn.GotoNode(target=e.dst.label)) + else: + if e in parent.gotos_to_break: + edge_body.append(tn.BreakNode()) + elif e in parent.gotos_to_continue: + edge_body.append(tn.ContinueNode()) + + if e not in parent.gotos_to_ignore and not e.data.is_unconditional(): + if sdfg.out_degree(node.state) == 1 and parent.sequential: + # Conditional state in sequential block! 
+                            # Conditional state in sequential block! Add "if not condition goto exit"
+                            result.append(
+                                tn.StateIfScope(condition=CodeBlock(negate_expr(e.data.condition)),
+                                                children=[tn.GotoNode(target=None)]))
+                            result.extend(edge_body)
+                        else:
+                            # Add "if condition" with the body above
+                            result.append(tn.StateIfScope(condition=e.data.condition, children=edge_body))
+                    else:
+                        result.extend(edge_body)
+
+        elif isinstance(node, cf.ForScope):
+            result.append(tn.ForScope(header=node, children=totree(node.body)))
+        elif isinstance(node, cf.IfScope):
+            result.append(tn.IfScope(condition=node.condition, children=totree(node.body)))
+            if node.orelse is not None:
+                result.append(tn.ElseScope(children=totree(node.orelse)))
+        elif isinstance(node, cf.IfElseChain):
+            # Add "if" for the first condition, "elif"s for the rest
+            result.append(tn.IfScope(condition=node.body[0][0], children=totree(node.body[0][1])))
+            for cond, body in node.body[1:]:
+                result.append(tn.ElifScope(condition=cond, children=totree(body)))
+            # "else goto exit"
+            result.append(tn.ElseScope(children=[tn.GotoNode(target=None)]))
+        elif isinstance(node, cf.WhileScope):
+            result.append(tn.WhileScope(header=node, children=totree(node.body)))
+        elif isinstance(node, cf.DoWhileScope):
+            result.append(tn.DoWhileScope(header=node, children=totree(node.body)))
+        else:
+            # e.g., "SwitchCaseScope"
+            raise tn.UnsupportedScopeException(type(node).__name__)
+
+        if node.first_state is not None:
+            result = [tn.StateLabel(state=node.first_state)] + result
+
+        return result
+
+    # Recursive traversal of the control flow tree
+    result = tn.ScheduleTreeScope(children=totree(cfg))
+
+    # Clean up tree
+    stpasses.remove_unused_and_duplicate_labels(result)
+
+    return result
+
+
+if __name__ == '__main__':
+    s = time.time()
+    sdfg = SDFG.from_file(sys.argv[1])
+    print('Loaded SDFG in', time.time() - s, 'seconds')
+    s = time.time()
+    stree = as_schedule_tree(sdfg, in_place=True)
+    print('Created schedule tree in', time.time() - s, 'seconds')
+
+    with open('output_stree.txt', 'w') as fp:
+        fp.write(stree.as_string(-1) + '\n')
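
Note: a minimal usage sketch of the conversion entry point (the program below is illustrative):

    import dace
    from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree

    @dace.program
    def saxpy(a: dace.float64, x: dace.float64[64], y: dace.float64[64]):
        y[:] = a * x + y

    # Convert the SDFG and print the resulting schedule tree
    stree = as_schedule_tree(saxpy.to_sdfg())
    print(stree.as_string())
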
+ """ + yield self + + +@dataclass +class ScheduleTreeScope(ScheduleTreeNode): + children: List['ScheduleTreeNode'] + containers: Optional[Dict[str, data.Data]] = field(default_factory=dict, init=False) + symbols: Optional[Dict[str, symbol]] = field(default_factory=dict, init=False) + + def __init__(self, + children: Optional[List['ScheduleTreeNode']] = None): + self.children = children or [] + if self.children: + for child in children: + child.parent = self + + def as_string(self, indent: int = 0): + if not self.children: + return (indent + 1) * INDENTATION + 'pass' + return '\n'.join([child.as_string(indent + 1) for child in self.children]) + + def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: + """ + Traverse tree nodes in a pre-order manner. + """ + yield from super().preorder_traversal() + for child in self.children: + yield from child.preorder_traversal() + + # TODO: Helper function that gets input/output memlets of the scope + + +@dataclass +class ControlFlowScope(ScheduleTreeScope): + pass + + +@dataclass +class DataflowScope(ScheduleTreeScope): + node: nodes.EntryNode + + +@dataclass +class GBlock(ControlFlowScope): + """ + General control flow block. Contains a list of states + that can run in arbitrary order based on edges (gotos). + Normally contains irreducible control flow. + """ + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + 'gblock:\n' + return result + super().as_string(indent) + + +@dataclass +class StateLabel(ScheduleTreeNode): + state: SDFGState + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'label {self.state.name}:' + + +@dataclass +class GotoNode(ScheduleTreeNode): + target: Optional[str] = None #: If None, equivalent to "goto exit" or "return" + + def as_string(self, indent: int = 0): + name = self.target or 'exit' + return indent * INDENTATION + f'goto {name}' + + +@dataclass +class AssignNode(ScheduleTreeNode): + """ + Represents a symbol assignment that is not part of a structured control flow block. + """ + name: str + value: CodeBlock + edge: InterstateEdge + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'assign {self.name} = {self.value.as_string}' + + +@dataclass +class ForScope(ControlFlowScope): + """ + For loop scope. + """ + header: cf.ForScope + + def as_string(self, indent: int = 0): + node = self.header + + result = (indent * INDENTATION + f'for {node.itervar} = {node.init}; {node.condition.as_string}; ' + f'{node.itervar} = {node.update}:\n') + return result + super().as_string(indent) + + +@dataclass +class WhileScope(ControlFlowScope): + """ + While loop scope. + """ + header: cf.WhileScope + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'while {self.header.test.as_string}:\n' + return result + super().as_string(indent) + + +@dataclass +class DoWhileScope(ControlFlowScope): + """ + Do/While loop scope. + """ + header: cf.DoWhileScope + + def as_string(self, indent: int = 0): + header = indent * INDENTATION + 'do:\n' + footer = indent * INDENTATION + f'while {self.header.test.as_string}\n' + return header + super().as_string(indent) + footer + + +@dataclass +class IfScope(ControlFlowScope): + """ + If branch scope. 
+ """ + condition: CodeBlock + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'if {self.condition.as_string}:\n' + return result + super().as_string(indent) + + +@dataclass +class StateIfScope(IfScope): + """ + A special class of an if scope in general blocks for if statements that are part of a state transition. + """ + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'stateif {self.condition.as_string}:\n' + return result + super(IfScope, self).as_string(indent) + + +@dataclass +class BreakNode(ScheduleTreeNode): + """ + Represents a break statement. + """ + + def as_string(self, indent: int = 0): + return indent * INDENTATION + 'break' + + +@dataclass +class ContinueNode(ScheduleTreeNode): + """ + Represents a continue statement. + """ + + def as_string(self, indent: int = 0): + return indent * INDENTATION + 'continue' + + +@dataclass +class ElifScope(ControlFlowScope): + """ + Else-if branch scope. + """ + condition: CodeBlock + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'elif {self.condition.as_string}:\n' + return result + super().as_string(indent) + + +@dataclass +class ElseScope(ControlFlowScope): + """ + Else branch scope. + """ + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + 'else:\n' + return result + super().as_string(indent) + + +@dataclass +class MapScope(DataflowScope): + """ + Map scope. + """ + + def as_string(self, indent: int = 0): + rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) + result = indent * INDENTATION + f'map {", ".join(self.node.map.params)} in [{rangestr}]:\n' + return result + super().as_string(indent) + + +@dataclass +class ConsumeScope(DataflowScope): + """ + Consume scope. + """ + + def as_string(self, indent: int = 0): + node: nodes.ConsumeEntry = self.node + cond = 'stream not empty' if node.consume.condition is None else node.consume.condition.as_string + result = indent * INDENTATION + f'consume (PE {node.consume.pe_index} out of {node.consume.num_pes}) while {cond}:\n' + return result + super().as_string(indent) + + +@dataclass +class PipelineScope(DataflowScope): + """ + Pipeline scope. 
+ """ + + def as_string(self, indent: int = 0): + rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) + result = indent * INDENTATION + f'pipeline {", ".join(self.node.map.params)} in [{rangestr}]:\n' + return result + super().as_string(indent) + + +@dataclass +class TaskletNode(ScheduleTreeNode): + node: nodes.Tasklet + in_memlets: Dict[str, Memlet] + out_memlets: Dict[str, Memlet] + + def as_string(self, indent: int = 0): + in_memlets = ', '.join(f'{v}' for v in self.in_memlets.values()) + out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) + if not out_memlets: + return indent * INDENTATION + f'tasklet({in_memlets})' + return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' + + +@dataclass +class LibraryCall(ScheduleTreeNode): + node: nodes.LibraryNode + in_memlets: Union[Dict[str, Memlet], Set[Memlet]] + out_memlets: Union[Dict[str, Memlet], Set[Memlet]] + + def as_string(self, indent: int = 0): + if isinstance(self.in_memlets, set): + in_memlets = ', '.join(f'{v}' for v in self.in_memlets) + else: + in_memlets = ', '.join(f'{v}' for v in self.in_memlets.values()) + if isinstance(self.out_memlets, set): + out_memlets = ', '.join(f'{v}' for v in self.out_memlets) + else: + out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) + libname = type(self.node).__name__ + # Get the properties of the library node without its superclasses + own_properties = ', '.join(f'{k}={getattr(self.node, k)}' for k, v in self.node.__properties__.items() + if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) + return indent * INDENTATION + f'{out_memlets} = library {libname}[{own_properties}]({in_memlets})' + + +@dataclass +class CopyNode(ScheduleTreeNode): + target: str + memlet: Memlet + + def as_string(self, indent: int = 0): + if self.memlet.other_subset is not None and any(s != 0 for s in self.memlet.other_subset.min_element()): + offset = f'[{self.memlet.other_subset}]' + else: + offset = '' + if self.memlet.wcr is not None: + wcr = f' with {self.memlet.wcr}' + else: + wcr = '' + + return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' + + +@dataclass +class DynScopeCopyNode(ScheduleTreeNode): + """ + A special case of a copy node that is used in dynamic scope inputs (e.g., dynamic map ranges). + """ + target: str + memlet: Memlet + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = dscopy {self.memlet.data}[{self.memlet.subset}]' + + +@dataclass +class ViewNode(ScheduleTreeNode): + target: str #: View name + source: str #: Viewed container name + memlet: Memlet + src_desc: data.Data + view_desc: data.Data + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = view {self.memlet} as {self.view_desc.shape}' + + +@dataclass +class NView(ViewNode): + """ + Nested SDFG view node. Subclass of a view that specializes in nested SDFG boundaries. + """ + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = nview {self.memlet} as {self.view_desc.shape}' + + +@dataclass +class RefSetNode(ScheduleTreeNode): + """ + Reference set node. Sets a reference to a data container. 
+ """ + target: str + memlet: Memlet + src_desc: data.Data + ref_desc: data.Data + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = refset to {self.memlet}' + + +# Classes based on Python's AST NodeVisitor/NodeTransformer for schedule tree nodes +class ScheduleNodeVisitor: + + def visit(self, node: ScheduleTreeNode): + """Visit a node.""" + if isinstance(node, list): + return [self.visit(snode) for snode in node] + if isinstance(node, ScheduleTreeScope) and hasattr(self, 'visit_scope'): + return self.visit_scope(node) + + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + return visitor(node) + + def generic_visit(self, node: ScheduleTreeNode): + if isinstance(node, ScheduleTreeScope): + for child in node.children: + self.visit(child) + + +class ScheduleNodeTransformer(ScheduleNodeVisitor): + + def visit(self, node: ScheduleTreeNode): + if isinstance(node, list): + result = [] + for snode in node: + new_node = self.visit(snode) + if new_node is not None: + result.append(new_node) + return result + + return super().visit(node) + + def generic_visit(self, node: ScheduleTreeNode): + new_values = [] + if isinstance(node, ScheduleTreeScope): + for value in node.children: + if isinstance(value, ScheduleTreeNode): + value = self.visit(value) + if value is None: + continue + elif not isinstance(value, ScheduleTreeNode): + new_values.extend(value) + continue + new_values.append(value) + for val in new_values: + val.parent = node + node.children[:] = new_values + return node diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py new file mode 100644 index 0000000000..59a2c178d2 --- /dev/null +++ b/dace/sdfg/memlet_utils.py @@ -0,0 +1,79 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import ast +from dace.frontend.python import memlet_parser +from dace import data, Memlet +from typing import Callable, Dict, Optional, Set, Union + + +class MemletReplacer(ast.NodeTransformer): + """ + Iterates over all memlet expressions (name or subscript with matching array in SDFG) in a code block. + The callable can also return another memlet to replace the current one. + """ + + def __init__(self, + arrays: Dict[str, data.Data], + process: Callable[[Memlet], Union[Memlet, None]], + array_filter: Optional[Set[str]] = None) -> None: + """ + Create a new memlet replacer. + + :param arrays: A mapping from array names to data descriptors. + :param process: A callable that takes a memlet and returns a memlet or None. + :param array_filter: An optional subset of array names to process. + """ + self.process = process + self.arrays = arrays + self.array_filter = array_filter or self.arrays.keys() + + def _parse_memlet(self, node: Union[ast.Name, ast.Subscript]) -> Memlet: + """ + Parses a memlet from a subscript or name node. + + :param node: The node to parse. + :return: The parsed memlet. + """ + # Get array name + if isinstance(node, ast.Name): + data = node.id + elif isinstance(node, ast.Subscript): + data = node.value.id + else: + raise TypeError('Expected Name or Subscript') + + # Parse memlet subset + array = self.arrays[data] + subset, newaxes, _ = memlet_parser.parse_memlet_subset(array, node, self.arrays) + if newaxes: + raise NotImplementedError('Adding new axes to memlets is not supported') + + return Memlet(data=data, subset=subset) + + def _memlet_to_ast(self, memlet: Memlet) -> ast.Subscript: + """ + Converts a memlet to a subscript node. 
diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py
new file mode 100644
index 0000000000..59a2c178d2
--- /dev/null
+++ b/dace/sdfg/memlet_utils.py
@@ -0,0 +1,79 @@
+# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
+
+import ast
+from dace.frontend.python import memlet_parser
+from dace import data, Memlet
+from typing import Callable, Dict, Optional, Set, Union
+
+
+class MemletReplacer(ast.NodeTransformer):
+    """
+    Iterates over all memlet expressions (name or subscript with matching array in SDFG) in a code block and
+    calls a user-provided callable on each. The callable can also return another memlet to replace the current one.
+    """
+
+    def __init__(self,
+                 arrays: Dict[str, data.Data],
+                 process: Callable[[Memlet], Union[Memlet, None]],
+                 array_filter: Optional[Set[str]] = None) -> None:
+        """
+        Create a new memlet replacer.
+
+        :param arrays: A mapping from array names to data descriptors.
+        :param process: A callable that takes a memlet and returns a memlet or None.
+        :param array_filter: An optional subset of array names to process.
+        """
+        self.process = process
+        self.arrays = arrays
+        self.array_filter = array_filter or self.arrays.keys()
+
+    def _parse_memlet(self, node: Union[ast.Name, ast.Subscript]) -> Memlet:
+        """
+        Parses a memlet from a subscript or name node.
+
+        :param node: The node to parse.
+        :return: The parsed memlet.
+        """
+        # Get array name
+        if isinstance(node, ast.Name):
+            data = node.id
+        elif isinstance(node, ast.Subscript):
+            data = node.value.id
+        else:
+            raise TypeError('Expected Name or Subscript')
+
+        # Parse memlet subset
+        array = self.arrays[data]
+        subset, newaxes, _ = memlet_parser.parse_memlet_subset(array, node, self.arrays)
+        if newaxes:
+            raise NotImplementedError('Adding new axes to memlets is not supported')
+
+        return Memlet(data=data, subset=subset)
+
+    def _memlet_to_ast(self, memlet: Memlet) -> ast.Subscript:
+        """
+        Converts a memlet to a subscript node.
+
+        :param memlet: The memlet to convert.
+        :return: The converted node.
+        """
+        return ast.parse(f'{memlet.data}[{memlet.subset}]').body[0].value
+
+    def _replace(self, node: Union[ast.Name, ast.Subscript]) -> ast.Subscript:
+        cur_memlet = self._parse_memlet(node)
+        new_memlet = self.process(cur_memlet)
+        if new_memlet is None:
+            return node
+
+        new_node = self._memlet_to_ast(new_memlet)
+        return ast.copy_location(new_node, node)
+
+    def visit_Name(self, node: ast.Name):
+        if node.id in self.array_filter:
+            return self._replace(node)
+        return self.generic_visit(node)
+
+    def visit_Subscript(self, node: ast.Subscript):
+        if isinstance(node.value, ast.Name) and node.value.id in self.array_filter:
+            return self._replace(node)
+        return self.generic_visit(node)
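
Note: a usage sketch for MemletReplacer (the array, code, and offset logic are made up for illustration; `ast.unparse` requires Python 3.9+):

    import ast
    import dace
    from dace import Memlet
    from dace.sdfg.memlet_utils import MemletReplacer

    arrays = {'A': dace.data.Array(dace.float64, [20])}

    def shift(memlet: Memlet):
        # Hypothetical processing: shift every access to A by one element
        return Memlet(data=memlet.data, subset=f'{memlet.subset.min_element()[0] + 1}')

    tree = MemletReplacer(arrays, shift).visit(ast.parse('b = A[5] + 1'))
    print(ast.unparse(tree))  # e.g., "b = A[6] + 1"
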
diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py
index 28431deeea..32369a19a3 100644
--- a/dace/sdfg/nodes.py
+++ b/dace/sdfg/nodes.py
@@ -342,6 +342,10 @@ class Tasklet(CodeNode):
                               'additional side effects on the system state (e.g., callback). '
                               'Defaults to None, which lets the framework make assumptions based on '
                               'the tasklet contents')
+    ignored_symbols = SetProperty(element_type=str, desc='A set of symbols to ignore when computing '
+                                  'the symbols used by this tasklet. Used to skip certain symbols in non-Python '
+                                  'tasklets, where only string analysis is possible; and to skip globals in Python '
+                                  'tasklets that should not be given as parameters to the SDFG.')
 
     def __init__(self,
                  label,
@@ -355,6 +359,7 @@ def __init__(self,
                  code_exit="",
                  location=None,
                  side_effects=None,
+                 ignored_symbols=None,
                  debuginfo=None):
         super(Tasklet, self).__init__(label, location, inputs, outputs)
 
@@ -365,6 +370,7 @@ def __init__(self,
         self.code_init = CodeBlock(code_init, dtypes.Language.CPP)
         self.code_exit = CodeBlock(code_exit, dtypes.Language.CPP)
         self.side_effects = side_effects
+        self.ignored_symbols = ignored_symbols or set()
         self.debuginfo = debuginfo
 
     @property
@@ -393,7 +399,11 @@ def validate(self, sdfg, state):
 
     @property
     def free_symbols(self) -> Set[str]:
-        return self.code.get_free_symbols(self.in_connectors.keys() | self.out_connectors.keys())
+        symbols_to_ignore = self.in_connectors.keys() | self.out_connectors.keys()
+        symbols_to_ignore |= self.ignored_symbols
+
+        return self.code.get_free_symbols(symbols_to_ignore)
+
 
     def has_side_effects(self, sdfg) -> bool:
         """
@@ -581,16 +591,19 @@ def from_json(json_obj, context=None):
         return ret
 
     def used_symbols(self, all_symbols: bool) -> Set[str]:
-        free_syms = set().union(*(map(str,
-                                      pystr_to_symbolic(v).free_symbols) for v in self.symbol_mapping.values()),
-                                *(map(str,
-                                      pystr_to_symbolic(v).free_symbols) for v in self.location.values()))
+        free_syms = set().union(*(map(str, pystr_to_symbolic(v).free_symbols) for v in self.location.values()))
+
+        keys_to_use = set(self.symbol_mapping.keys())
 
         # Filter out unused internal symbols from symbol mapping
         if not all_symbols:
             internally_used_symbols = self.sdfg.used_symbols(all_symbols=False)
-            free_syms &= internally_used_symbols
-
+            keys_to_use &= internally_used_symbols
+
+        free_syms |= set().union(*(map(str,
+                                       pystr_to_symbolic(v).free_symbols) for k, v in self.symbol_mapping.items()
+                                   if k in keys_to_use))
+
         return free_syms
 
     @property
@@ -640,7 +653,7 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context
                 raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname)
             if dname in connectors and desc.transient:
                 raise NameError('"%s" is a connector but its corresponding array is transient' % dname)
-
+
         # Validate inout connectors
         from dace.sdfg import utils  # Avoids circular import
         inout_connectors = self.in_connectors.keys() & self.out_connectors.keys()
diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py
index 5e42830a75..4b36fad4fe 100644
--- a/dace/sdfg/replace.py
+++ b/dace/sdfg/replace.py
@@ -124,6 +124,7 @@ def replace_properties_dict(node: Any,
             if lang is dtypes.Language.CPP:  # Replace in C++ code
                 prefix = ''
                 tokenized = tokenize_cpp.findall(code)
+                active_replacements = set()
                 for name, new_name in reduced_repl.items():
                     if name not in tokenized:
                         continue
@@ -131,8 +132,14 @@ def replace_properties_dict(node: Any,
                     # Use local variables and shadowing to replace
                     replacement = f'auto {name} = {cppunparse.pyexpr2cpp(new_name)};\n'
                     prefix = replacement + prefix
+                    active_replacements.add(name)
                 if prefix:
                     propval.code = prefix + code
+
+                    # Ignore replaced symbols since they no longer exist as reads
+                    if isinstance(node, dace.nodes.Tasklet):
+                        node._ignored_symbols.update(active_replacements)
+
             else:
                 warnings.warn('Replacement of %s with %s was not made '
                               'for string tasklet code of language %s' % (name, new_name, lang))
diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py
index a23d2616f9..a7b5d90b2b 100644
--- a/dace/sdfg/sdfg.py
+++ b/dace/sdfg/sdfg.py
@@ -62,7 +62,7 @@ def __getitem__(self, key):
             token = tokens.pop(0)
             result = result.members[token]
         return result
-
+
     def __setitem__(self, key, val):
         if isinstance(key, str) and '.' in key:
             raise KeyError('NestedDict does not support setting nested keys')
@@ -273,7 +273,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:
         rhs_symbols = set()
         for lhs, rhs in self.assignments.items():
             # Always add LHS symbols to the set of candidate free symbols
-            rhs_symbols |= symbolic.free_symbols_and_functions(rhs)
+            rhs_symbols |= set(map(str, dace.symbolic.symbols_in_ast(ast.parse(rhs))))
             # Add the RHS to the set of candidate defined symbols ONLY if it has not been read yet
             # This also solves the ordering issue that may arise in cases like the 3rd example above
             if lhs not in cond_symbols and lhs not in rhs_symbols:
@@ -756,7 +756,7 @@ def replace_dict(self,
         if replace_in_graph:
             # Replace in inter-state edges
             for edge in self.edges():
-                edge.data.replace_dict(repldict)
+                edge.data.replace_dict(repldict, replace_keys=replace_keys)
 
             # Replace in states
             for state in self.nodes():
@@ -1335,23 +1335,17 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:
         defined_syms = set()
         free_syms = set()
 
-        # Exclude data descriptor names, constants, and shapes of global data descriptors
-        not_strictly_necessary_global_symbols = set()
-        for name, desc in self.arrays.items():
+        # Exclude data descriptor names and constants
+        for name in self.arrays.keys():
             defined_syms.add(name)
-            if not all_symbols:
-                used_desc_symbols = desc.used_symbols(all_symbols)
-                not_strictly_necessary = (desc.used_symbols(all_symbols=True) - used_desc_symbols)
-                not_strictly_necessary_global_symbols |= set(map(str, not_strictly_necessary))
-
         defined_syms |= set(self.constants_prop.keys())
 
-        # Start with the set of SDFG free symbols
-        if all_symbols:
-            free_syms |= set(self.symbols.keys())
-        else:
-            free_syms |= set(s for s in self.symbols.keys() if s not in not_strictly_necessary_global_symbols)
+        # Add used symbols from init and exit code
+        for code in self.init_code.values():
+            free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys())
+        for code in self.exit_code.values():
+            free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys())
 
state symbols used_before_assignment = set() @@ -1362,7 +1356,8 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: ordered_states = self.nodes() for state in ordered_states: - free_syms |= state.used_symbols(all_symbols) + state_fsyms = state.used_symbols(all_symbols) + free_syms |= state_fsyms # Add free inter-state symbols for e in self.out_edges(state): @@ -1370,13 +1365,18 @@ # subtracting the (true) free symbols from the edge's assignment keys. This way we can correctly # compute the symbols that are used before being assigned. efsyms = e.data.used_symbols(all_symbols) - defined_syms |= set(e.data.assignments.keys()) - efsyms + defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_fsyms) used_before_assignment.update(efsyms - defined_syms) free_syms |= efsyms # Remove symbols that were used before they were assigned defined_syms -= used_before_assignment + # Add the set of SDFG symbol parameters + # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets + if all_symbols: + free_syms |= set(self.symbols.keys()) + # Subtract symbols defined in inter-state edges and constants return free_syms - defined_syms @@ -1392,6 +1392,29 @@ def free_symbols(self) -> Set[str]: """ return self.used_symbols(all_symbols=True) + def get_all_toplevel_symbols(self) -> Set[str]: + """ + Returns a set of all symbol names that are used by the SDFG's state machine. + This includes all symbols in the descriptor repository and interstate edges, + whether free or defined. Used to identify duplicates when, e.g., inlining or + dealiasing a set of nested SDFGs. + """ + # Exclude constants and data descriptor names + exclude = set(self.arrays.keys()) | set(self.constants_prop.keys()) + + syms = set() + + # Start with the set of SDFG free symbols + syms |= set(self.symbols.keys()) + + # Add inter-state symbols + for e in self.edges(): + syms |= set(e.data.assignments.keys()) + syms |= e.data.free_symbols + + # Subtract excluded symbols + return syms - exclude + def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: """ Determines what data containers are read and written in this SDFG. Does @@ -1458,7 +1481,7 @@ def init_signature(self, for_call=False, free_symbols=None) -> str: :param for_call: If True, returns arguments that can be used when calling the SDFG. """ # Get global free symbols scalar arguments - free_symbols = free_symbols or self.free_symbols + free_symbols = free_symbols if free_symbols is not None else self.used_symbols(all_symbols=False) return ", ".join( dt.Scalar(self.symbols[k]).as_arg(name=k, with_types=not for_call, for_call=for_call) for k in sorted(free_symbols) if not k.startswith('__dace')) @@ -1478,6 +1501,21 @@ def signature_arglist(self, with_types=True, for_call=False, with_arrays=True, a arglist = arglist or self.arglist(scalars_only=not with_arrays) return [v.as_arg(name=k, with_types=with_types, for_call=for_call) for k, v in arglist.items()] + def python_signature_arglist(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> List[str]: + """ Returns a list of arguments necessary to call this SDFG, + formatted as a list of Data-Centric Python definitions. + + :param with_types: If True, includes argument types in the result. + :param for_call: If True, returns arguments that can be used when + calling the SDFG. + :param with_arrays: If True, includes arrays, otherwise, + only symbols and scalars are included. 
+ :param arglist: An optional cached argument list. + :return: A list of strings. For example: `['A: dace.float32[M]', 'b: dace.int32']`. + """ + arglist = arglist or self.arglist(scalars_only=not with_arrays, free_symbols=[]) + return [v.as_python_arg(name=k, with_types=with_types, for_call=for_call) for k, v in arglist.items()] + def signature(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> str: """ Returns a C/C++ signature of this SDFG, used when generating code. @@ -1493,6 +1531,21 @@ def signature(self, with_types=True, for_call=False, with_arrays=True, arglist=N """ return ", ".join(self.signature_arglist(with_types, for_call, with_arrays, arglist)) + def python_signature(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> str: + """ Returns a Data-Centric Python signature of this SDFG, used when generating code. + + :param with_types: If True, includes argument types (can be used + for a function prototype). If False, only + include argument names (can be used for function + calls). + :param for_call: If True, returns arguments that can be used when + calling the SDFG. + :param with_arrays: If True, includes arrays, otherwise, + only symbols and scalars are included. + :param arglist: An optional cached argument list. + """ + return ", ".join(self.python_signature_arglist(with_types, for_call, with_arrays, arglist)) + def _repr_html_(self): """ HTML representation of the SDFG, used mainly for Jupyter notebooks. """ diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index a4a6648401..8ad0c67bb8 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -7,7 +7,7 @@ import inspect import itertools import warnings -from typing import Any, AnyStr, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, overload +from typing import TYPE_CHECKING, Any, AnyStr, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, overload import dace from dace import data as dt @@ -24,6 +24,9 @@ from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset +if TYPE_CHECKING: + import dace.sdfg.scope + def _getdebuginfo(old_dinfo=None) -> dtypes.DebugInfo: """ Returns a DebugInfo object for the position that called this function. @@ -409,6 +412,13 @@ def scope_children(self, ################################################################### # Query, subgraph, and replacement methods + def is_leaf_memlet(self, e): + if isinstance(e.src, nd.ExitNode) and e.src_conn and e.src_conn.startswith('OUT_'): + return False + if isinstance(e.dst, nd.EntryNode) and e.dst_conn and e.dst_conn.startswith('IN_'): + return False + return True + def used_symbols(self, all_symbols: bool) -> Set[str]: """ Returns a set of symbol names that are used in the state. 
@@ -428,13 +438,23 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: elif isinstance(n, nd.AccessNode): # Add data descriptor symbols freesyms |= set(map(str, n.desc(sdfg).used_symbols(all_symbols))) - elif (isinstance(n, nd.Tasklet) and n.language == dtypes.Language.Python): - # Consider callbacks defined as symbols as free - for stmt in n.code.code: - for astnode in ast.walk(stmt): - if (isinstance(astnode, ast.Call) and isinstance(astnode.func, ast.Name) - and astnode.func.id in sdfg.symbols): - freesyms.add(astnode.func.id) + elif isinstance(n, nd.Tasklet): + if n.language == dtypes.Language.Python: + # Consider callbacks defined as symbols as free + for stmt in n.code.code: + for astnode in ast.walk(stmt): + if (isinstance(astnode, ast.Call) and isinstance(astnode.func, ast.Name) + and astnode.func.id in sdfg.symbols): + freesyms.add(astnode.func.id) + else: + # Find all string tokens and filter them to sdfg.symbols, while ignoring connectors + codesyms = symbolic.symbols_in_code( + n.code.as_string, + potential_symbols=sdfg.symbols.keys(), + symbols_to_ignore=(n.in_connectors.keys() | n.out_connectors.keys() | n.ignored_symbols), + ) + freesyms |= codesyms + continue if hasattr(n, 'used_symbols'): freesyms |= n.used_symbols(all_symbols) @@ -442,16 +462,9 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: freesyms |= n.free_symbols # Free symbols from memlets - def _is_leaf_memlet(e): - if isinstance(e.src, nd.ExitNode) and e.src_conn and e.src_conn.startswith('OUT_'): - return False - if isinstance(e.dst, nd.EntryNode) and e.dst_conn and e.dst_conn.startswith('IN_'): - return False - return True - for e in self.edges(): # If used for code generation, only consider memlet tree leaves - if not all_symbols and not _is_leaf_memlet(e): + if not all_symbols and not self.is_leaf_memlet(e): continue freesyms |= e.data.used_symbols(all_symbols) @@ -459,7 +472,7 @@ def _is_leaf_memlet(e): # Do not consider SDFG constants as symbols new_symbols.update(set(sdfg.constants.keys())) return freesyms - new_symbols - + @property def free_symbols(self) -> Set[str]: """ @@ -471,7 +484,6 @@ def free_symbols(self) -> Set[str]: """ return self.used_symbols(all_symbols=True) - def defined_symbols(self) -> Dict[str, dt.Data]: """ Returns a dictionary that maps currently-defined symbols in this SDFG @@ -532,8 +544,8 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # Filter out memlets which go out but the same data is written to the AccessNode by another memlet for out_edge in list(out_edges): for in_edge in list(in_edges): - if (in_edge.data.data == out_edge.data.data and - in_edge.data.dst_subset.covers(out_edge.data.src_subset)): + if (in_edge.data.data == out_edge.data.data + and in_edge.data.dst_subset.covers(out_edge.data.src_subset)): out_edges.remove(out_edge) break @@ -800,7 +812,7 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): self.nosync = False self.location = location if location is not None else {} self._default_lineinfo = None - + def __deepcopy__(self, memo): cls = self.__class__ result = cls.__new__(cls) @@ -1450,7 +1462,7 @@ def add_reduce( """ import dace.libraries.standard as stdlib # Avoid import loop debuginfo = _getdebuginfo(debuginfo or self._default_lineinfo) - result = stdlib.Reduce(wcr, axes, identity, schedule=schedule, debuginfo=debuginfo) + result = stdlib.Reduce('Reduce', wcr, axes, identity, schedule=schedule, debuginfo=debuginfo) self.add_node(result) return result diff --git a/dace/sdfg/utils.py 
b/dace/sdfg/utils.py index 3396335ece..1078414161 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -810,7 +810,7 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg out_edges = state.out_edges(view) # Invalid case: No data to view - if len(in_edges) == 0 or len(out_edges) == 0: + if len(in_edges) == 0 and len(out_edges) == 0: return None # If there is one edge (in/out) that leads (via memlet path) to an access diff --git a/dace/symbolic.py b/dace/symbolic.py index 0ab6e3f6ff..e9249218f9 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -14,6 +14,7 @@ from dace import dtypes DEFAULT_SYMBOL_TYPE = dtypes.int32 +_NAME_TOKENS = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*') # NOTE: Up to (including) version 1.8, sympy.abc._clash is a dictionary of the # form {'N': sympy.abc.N, 'I': sympy.abc.I, 'pi': sympy.abc.pi} @@ -1377,6 +1378,29 @@ def equal(a: SymbolicType, b: SymbolicType, is_length: bool = True) -> Union[boo if is_length: for arg in args: facts += [sympy.Q.integer(arg), sympy.Q.positive(arg)] - + with sympy.assuming(*facts): return sympy.ask(sympy.Q.is_true(sympy.Eq(*args))) + + +def symbols_in_code(code: str, potential_symbols: Set[str] = None, + symbols_to_ignore: Set[str] = None) -> Set[str]: + """ + Tokenizes a code string for symbols and returns a set thereof. + + :param code: The code to tokenize. + :param potential_symbols: If not None, filters symbols to this given set. + :param symbols_to_ignore: If not None, filters out symbols from this set. + """ + if not code: + return set() + if potential_symbols is not None and len(potential_symbols) == 0: + # Don't bother tokenizing for an empty set of potential symbols + return set() + + tokens = set(re.findall(_NAME_TOKENS, code)) + if potential_symbols is not None: + tokens &= potential_symbols + if symbols_to_ignore is None: + return tokens + return tokens - symbols_to_ignore diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 73da318e94..8986c4e37f 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -1307,6 +1307,23 @@ def redirect_edge(state: SDFGState, return new_edge +def replace_code_to_code_edges(sdfg: SDFG): + """ + Adds access nodes between all code->code edges in each state. + + :param sdfg: The SDFG to process. + """ + for state in sdfg.nodes(): + for edge in state.edges(): + if not isinstance(edge.src, nodes.CodeNode) or not isinstance(edge.dst, nodes.CodeNode): + continue + # Add access nodes + aname = state.add_access(edge.data.data) + state.add_edge(edge.src, edge.src_conn, aname, None, edge.data) + state.add_edge(aname, None, edge.dst, edge.dst_conn, copy.deepcopy(edge.data)) + state.remove_edge(edge) + + def can_run_state_on_fpga(state: SDFGState): """ Checks if state can be executed on FPGA. 
Used by FPGATransformState diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index c197adf827..9cec6d11af 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -102,12 +102,8 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = for e in sdfg.out_edges(state): e.data.replace_dict(mapping, replace_keys=False) - # If symbols are never unknown any longer, remove from SDFG + # Gather initial propagated symbols result = {k: v for k, v in symbols_replaced.items() if k not in remaining_unknowns} - # Remove from symbol repository - for sym in result: - if sym in sdfg.symbols: - sdfg.remove_symbol(sym) # Remove single-valued symbols from data descriptors (e.g., symbolic array size) sdfg.replace_dict({k: v @@ -121,6 +117,14 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = for sym in intersection: del edge.data.assignments[sym] + # If symbols are never unknown any longer, remove from SDFG + fsyms = sdfg.used_symbols(all_symbols=False) + result = {k: v for k, v in result.items() if k not in fsyms} + for sym in result: + if sym in sdfg.symbols: + # Remove from symbol repository and nested SDFG symbol mapping + sdfg.remove_symbol(sym) + result = set(result.keys()) if self.recursive: @@ -188,7 +192,7 @@ def collect_constants(self, if len(in_edges) == 1: # Special case, propagate as-is if state not in result: # Condition evaluates to False when state is the start-state result[state] = {} - + # First the prior state if in_edges[0].src in result: # Condition evaluates to False when state is the start-state self._propagate(result[state], result[in_edges[0].src]) diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index 94fcbdbc58..cf55f7a9b2 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -1,16 +1,13 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. 
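Note: the prune-symbols pass below now delegates its string tokenization to the new symbolic.symbols_in_code helper added above. A standalone sketch of that helper's behavior follows; the regex and filtering mirror the diff, but this copy and the example inputs are illustrative only:

import re

# Mirrors _NAME_TOKENS and symbols_in_code from dace/symbolic.py above
_NAME_TOKENS = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*')

def symbols_in_code(code, potential_symbols=None, symbols_to_ignore=None):
    if not code:
        return set()
    if potential_symbols is not None and len(potential_symbols) == 0:
        # Don't bother tokenizing for an empty set of potential symbols
        return set()
    tokens = set(re.findall(_NAME_TOKENS, code))
    if potential_symbols is not None:
        tokens &= set(potential_symbols)
    if symbols_to_ignore is None:
        return tokens
    return tokens - set(symbols_to_ignore)

# Only names that are known SDFG symbols survive, and tasklet connectors
# can be masked out via symbols_to_ignore:
print(symbols_in_code('out = inp * N + M', {'N', 'M'}, {'M'}))  # {'N'}
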
import itertools -import re from dataclasses import dataclass from typing import Optional, Set, Tuple -from dace import SDFG, dtypes, properties +from dace import SDFG, dtypes, properties, symbolic from dace.sdfg import nodes from dace.transformation import pass_pipeline as ppl -_NAME_TOKENS = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*') - @dataclass(unsafe_hash=True) @properties.make_properties @@ -81,7 +78,7 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: # Add symbols in global/init/exit code for code in itertools.chain(sdfg.global_code.values(), sdfg.init_code.values(), sdfg.exit_code.values()): - result |= _symbols_in_code(code.as_string) + result |= symbolic.symbols_in_code(code.as_string) for desc in sdfg.arrays.values(): result |= set(map(str, desc.free_symbols)) @@ -94,21 +91,19 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: for node in state.nodes(): if isinstance(node, nodes.Tasklet): if node.code.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code.as_string) + result |= symbolic.symbols_in_code(node.code.as_string, sdfg.symbols.keys(), + node.ignored_symbols) if node.code_global.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_global.as_string) + result |= symbolic.symbols_in_code(node.code_global.as_string, sdfg.symbols.keys(), + node.ignored_symbols) if node.code_init.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_init.as_string) + result |= symbolic.symbols_in_code(node.code_init.as_string, sdfg.symbols.keys(), + node.ignored_symbols) if node.code_exit.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_exit.as_string) - + result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), + node.ignored_symbols) for e in sdfg.edges(): result |= e.data.free_symbols return result - -def _symbols_in_code(code: str) -> Set[str]: - if not code: - return set() - return set(re.findall(_NAME_TOKENS, code)) diff --git a/tests/schedule_tree/naming_test.py b/tests/schedule_tree/naming_test.py new file mode 100644 index 0000000000..0811682870 --- /dev/null +++ b/tests/schedule_tree/naming_test.py @@ -0,0 +1,204 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
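The three new schedule-tree test files below (naming, nesting, schedule) all go through the same entry point. For orientation, a minimal usage sketch with the import paths the tests use; the program itself is made up:

import dace
from dace.sdfg.analysis.schedule_tree import treenodes as tn
from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree

@dace.program
def example(A: dace.float64[10]):
    for i in dace.map[0:10]:
        A[i] = A[i] + 1

# Convert the SDFG into a schedule tree and walk it in schedule order
stree = as_schedule_tree(example.to_sdfg())
for node in stree.preorder_traversal():
    print(type(node).__name__)  # e.g., MapScope, TaskletNode, ...
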
+import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree +from dace.transformation.passes.constant_propagation import ConstantPropagation + +import pytest +from typing import List + + +def _irreducible_loop_to_loop(): + sdfg = dace.SDFG('irreducible') + # Add a simple chain of two for loops with goto from second to first's body + s1 = sdfg.add_state_after(sdfg.add_state_after(sdfg.add_state())) + s2 = sdfg.add_state() + e = sdfg.add_state() + + # Add a loop + l1 = sdfg.add_state() + l2 = sdfg.add_state_after(l1) + sdfg.add_loop(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2) + + l3 = sdfg.add_state() + l4 = sdfg.add_state_after(l3) + sdfg.add_loop(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4) + + # Irreducible part + sdfg.add_edge(l3, l1, dace.InterstateEdge('i < 5')) + + # Avoiding undefined behavior + sdfg.edges_between(l3, l4)[0].data.condition.as_string = 'i >= 5' + + return sdfg + + +def _nested_irreducible_loops(): + sdfg = _irreducible_loop_to_loop() + nsdfg = _irreducible_loop_to_loop() + + l1 = sdfg.node(5) + l1.add_nested_sdfg(nsdfg, None, {}, {}) + return sdfg + + +def test_clash_states(): + """ + Same test as test_irreducible_in_loops, but all states in the nested SDFG share names with the top SDFG + """ + sdfg = _nested_irreducible_loops() + + stree = as_schedule_tree(sdfg) + unique_names = set() + for node in stree.preorder_traversal(): + if isinstance(node, tn.StateLabel): + if node.state.name in unique_names: + raise NameError('Name clash') + unique_names.add(node.state.name) + + +@pytest.mark.parametrize('constprop', (False, True)) +def test_clash_symbol_mapping(constprop): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [200], dace.float64) + sdfg.add_symbol('M', dace.int64) + sdfg.add_symbol('N', dace.int64) + sdfg.add_symbol('k', dace.int64) + + state = sdfg.add_state() + state2 = sdfg.add_state() + sdfg.add_edge(state, state2, dace.InterstateEdge(assignments={'k': 'M + 1'})) + + nsdfg = dace.SDFG('nester') + nsdfg.add_symbol('M', dace.int64) + nsdfg.add_symbol('N', dace.int64) + nsdfg.add_symbol('k', dace.int64) + nsdfg.add_array('out', [100], dace.float64) + nsdfg.add_transient('tmp', [100], dace.float64) + nstate = nsdfg.add_state() + nstate2 = nsdfg.add_state() + nsdfg.add_edge(nstate, nstate2, dace.InterstateEdge(assignments={'k': 'M + 1'})) + + # Copy + # The code should end up as `tmp[N:N+2] <- out[M+1:M+3]` + # In the outer SDFG: `tmp[N:N+2] <- A[M+101:M+103]` + r = nstate.add_access('out') + w = nstate.add_access('tmp') + nstate.add_edge(r, None, w, None, dace.Memlet(data='out', subset='k:k+2', other_subset='M:M+2')) + + # Tasklet + # The code should end up as `tmp[M] -> Tasklet -> out[N + 1]` + # In the outer SDFG: `tmp[M] -> Tasklet -> A[N + 101]` + r = nstate2.add_access('tmp') + w = nstate2.add_access('out') + t = nstate2.add_tasklet('dosomething', {'a'}, {'b'}, 'b = a + 1') + nstate2.add_edge(r, None, t, 'a', dace.Memlet('tmp[N]')) + nstate2.add_edge(t, 'b', w, None, dace.Memlet('out[k]')) + + # Connect nested SDFG to parent SDFG with an offset memlet + nsdfg_node = state2.add_nested_sdfg(nsdfg, None, {}, {'out'}, {'N': 'M', 'M': 'N', 'k': 'k'}) + w = state2.add_write('A') + state2.add_edge(nsdfg_node, 'out', w, None, dace.Memlet('A[100:200]')) + + # Get rid of k + if constprop: + ConstantPropagation().apply_pass(sdfg, {}) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) in (2, 4) # Either with assignments or without + + 
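+    # Two outcomes are possible here: with constant propagation, k has been
+    # folded into the memlets before tree construction, so only the copy and
+    # the tasklet remain; without it, the two interstate assignments to k
+    # surface as AssignNodes around them.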
# With assignments + if len(stree.children) == 4: + assert constprop is False + assert isinstance(stree.children[0], tn.AssignNode) + assert isinstance(stree.children[1], tn.CopyNode) + assert isinstance(stree.children[2], tn.AssignNode) + assert isinstance(stree.children[3], tn.TaskletNode) + assert stree.children[1].memlet.data == 'A' + assert str(stree.children[1].memlet.src_subset) == 'k + 100:k + 102' + assert str(stree.children[1].memlet.dst_subset) == 'N:N + 2' + assert stree.children[3].in_memlets['a'].data == 'tmp' + assert str(stree.children[3].in_memlets['a'].src_subset) == 'M' + assert stree.children[3].out_memlets['b'].data == 'A' + assert str(stree.children[3].out_memlets['b'].dst_subset) == 'k + 100' + else: + assert constprop is True + assert isinstance(stree.children[0], tn.CopyNode) + assert isinstance(stree.children[1], tn.TaskletNode) + assert stree.children[0].memlet.data == 'A' + assert str(stree.children[0].memlet.src_subset) == 'M + 101:M + 103' + assert str(stree.children[0].memlet.dst_subset) == 'N:N + 2' + assert stree.children[1].in_memlets['a'].data == 'tmp' + assert str(stree.children[1].in_memlets['a'].src_subset) == 'M' + assert stree.children[1].out_memlets['b'].data == 'A' + assert str(stree.children[1].out_memlets['b'].dst_subset) == 'N + 101' + + +def test_edgecase_symbol_mapping(): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('M', dace.int64) + sdfg.add_symbol('N', dace.int64) + + state = sdfg.add_state() + state2 = sdfg.add_state_after(state) + + nsdfg = dace.SDFG('nester') + nsdfg.add_symbol('M', dace.int64) + nsdfg.add_symbol('N', dace.int64) + nsdfg.add_symbol('k', dace.int64) + nstate = nsdfg.add_state() + nstate.add_tasklet('dosomething', {}, {}, 'print(k)', side_effects=True) + nstate2 = nsdfg.add_state() + nstate3 = nsdfg.add_state() + nsdfg.add_edge(nstate, nstate2, dace.InterstateEdge(assignments={'k': 'M + 1'})) + nsdfg.add_edge(nstate2, nstate3, dace.InterstateEdge(assignments={'l': 'k'})) + + state2.add_nested_sdfg(nsdfg, None, {}, {}, {'N': 'M', 'M': 'N', 'k': 'M + 1'}) + + stree = as_schedule_tree(sdfg) + + # k is reassigned internally, so that should be preserved + assert len(stree.children) == 3 + assert isinstance(stree.children[0], tn.TaskletNode) + assert 'M + 1' in stree.children[0].node.code.as_string + assert isinstance(stree.children[1], tn.AssignNode) + assert stree.children[1].name == 'k' + assert stree.children[1].value.as_string == '(N + 1)' + assert isinstance(stree.children[2], tn.AssignNode) + assert stree.children[2].name == 'l' + assert stree.children[2].value.as_string in ('k', '(N + 1)') + + +def _check_for_name_clashes(stree: tn.ScheduleTreeNode): + + def _traverse(node: tn.ScheduleTreeScope, scopes: List[str]): + for child in node.children: + if isinstance(child, tn.ForScope): + itervar = child.header.itervar + if itervar in scopes: + raise NameError('Nested scope redefines iteration variable') + _traverse(child, scopes + [itervar]) + elif isinstance(child, tn.MapScope): + itervars = child.node.map.params + if any(itervar in scopes for itervar in itervars): + raise NameError('Nested scope redefines iteration variable') + _traverse(child, scopes + itervars) + elif isinstance(child, tn.ScheduleTreeScope): + _traverse(child, scopes) + + _traverse(stree, []) + + +def test_clash_iteration_symbols(): + sdfg = _nested_irreducible_loops() + + stree = as_schedule_tree(sdfg) + _check_for_name_clashes(stree) + + +if __name__ == '__main__': + test_clash_states() + test_clash_symbol_mapping(False) + 
test_clash_symbol_mapping(True) + test_edgecase_symbol_mapping() + test_clash_iteration_symbols() diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py new file mode 100644 index 0000000000..161f15d6c1 --- /dev/null +++ b/tests/schedule_tree/nesting_test.py @@ -0,0 +1,234 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" +Nesting and dealiasing tests for schedule trees. +""" +import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree +from dace.transformation.dataflow import RemoveSliceView + +import pytest + +N = dace.symbol('N') +T = dace.symbol('T') + + +def test_stree_mpath_multiscope(): + + @dace.program + def tester(A: dace.float64[N, N]): + for i in dace.map[0:N:T]: + for j, k in dace.map[0:T, 0:N]: + for l in dace.map[0:T]: + A[i + j, k + l] = 1 + + # The test should generate different SDFGs for different simplify configurations, + # but the same schedule tree + stree = as_schedule_tree(tester.to_sdfg()) + assert [type(n) for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.MapScope, tn.TaskletNode] + + +def test_stree_mpath_multiscope_dependent(): + + @dace.program + def tester(A: dace.float64[N, N]): + for i in dace.map[0:N:T]: + for j, k in dace.map[0:T, 0:N]: + for l in dace.map[0:k]: + A[i + j, l] = 1 + + # The test should generate different SDFGs for different simplify configurations, + # but the same schedule tree + stree = as_schedule_tree(tester.to_sdfg()) + assert [type(n) for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.MapScope, tn.TaskletNode] + + +def test_stree_mpath_nested(): + + @dace.program + def nester(A, i, k, j): + for l in range(k): + A[i + j, l] = 1 + + @dace.program + def tester(A: dace.float64[N, N]): + for i in dace.map[0:N:T]: + for j, k in dace.map[0:T, 0:N]: + nester(A, i, j, k) + + stree = as_schedule_tree(tester.to_sdfg()) + + # Simplifying yields a different SDFG due to scalars and symbols, so testing is slightly different + simplified = dace.Config.get_bool('optimizer', 'automatic_simplification') + + if simplified: + assert [type(n) + for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.ForScope, tn.TaskletNode] + + tasklet: tn.TaskletNode = list(stree.preorder_traversal())[-1] + + if simplified: + assert str(next(iter(tasklet.out_memlets.values()))) == 'A[i + k, l]' + else: + assert str(next(iter(tasklet.out_memlets.values()))).endswith(', l]') + + +@pytest.mark.parametrize('dst_subset', (False, True)) +def test_stree_copy_same_scope(dst_subset): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [3 * N], dace.float64) + sdfg.add_array('B', [3 * N], dace.float64) + state = sdfg.add_state() + + r = state.add_read('A') + w = state.add_write('B') + if not dst_subset: + state.add_nedge(r, w, dace.Memlet(data='A', subset='2*N:3*N', other_subset='N:2*N')) + else: + state.add_nedge(r, w, dace.Memlet(data='B', subset='N:2*N', other_subset='2*N:3*N')) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 and isinstance(stree.children[0], tn.CopyNode) + assert stree.children[0].target == 'B' + assert stree.children[0].as_string() == 'B[N:2*N] = copy A[2*N:3*N]' + + +@pytest.mark.parametrize('dst_subset', (False, True)) +def test_stree_copy_different_scope(dst_subset): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [3 * N], dace.float64) + sdfg.add_array('B', [3 * N], dace.float64) + state = sdfg.add_state() + + r = 
state.add_read('A') + w = state.add_write('B') + me, mx = state.add_map('something', dict(i='0:1')) + if not dst_subset: + state.add_memlet_path(r, me, w, memlet=dace.Memlet(data='A', subset='2*N:3*N', other_subset='N + i:2*N + i')) + else: + state.add_memlet_path(r, me, w, memlet=dace.Memlet(data='B', subset='N + i:2*N + i', other_subset='2*N:3*N')) + state.add_nedge(w, mx, dace.Memlet()) + + stree = as_schedule_tree(sdfg) + stree_nodes = list(stree.preorder_traversal())[1:] + assert [type(n) for n in stree_nodes] == [tn.MapScope, tn.CopyNode] + assert stree_nodes[-1].target == 'B' + assert stree_nodes[-1].as_string() == 'B[N + i:2*N + i] = copy A[2*N:3*N]' + + +def test_dealias_nested_call(): + + @dace.program + def nester(a, b): + b[:] = a + + @dace.program + def tester(a: dace.float64[40], b: dace.float64[40]): + nester(b[1:21], a[10:30]) + + sdfg = tester.to_sdfg(simplify=False) + sdfg.apply_transformations_repeated(RemoveSliceView) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 + copy = stree.children[0] + assert isinstance(copy, tn.CopyNode) + assert copy.target == 'a' + assert copy.memlet.data == 'b' + assert str(copy.memlet.src_subset) == '1:21' + assert str(copy.memlet.dst_subset) == '10:30' + + +def test_dealias_nested_call_samearray(): + + @dace.program + def nester(a, b): + b[:] = a + + @dace.program + def tester(a: dace.float64[40]): + nester(a[1:21], a[10:30]) + + sdfg = tester.to_sdfg(simplify=False) + sdfg.apply_transformations_repeated(RemoveSliceView) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 + copy = stree.children[0] + assert isinstance(copy, tn.CopyNode) + assert copy.target == 'a' + assert copy.memlet.data == 'a' + assert str(copy.memlet.src_subset) == '1:21' + assert str(copy.memlet.dst_subset) == '10:30' + + +@pytest.mark.parametrize('simplify', (False, True)) +def test_dealias_memlet_composition(simplify): + + def nester2(c): + c[2] = 1 + + def nester1(b): + nester2(b[-5:]) + + @dace.program + def tester(a: dace.float64[N, N]): + nester1(a[:, 1]) + + sdfg = tester.to_sdfg(simplify=simplify) + stree = as_schedule_tree(sdfg) + + # Simplifying yields a different SDFG due to views, so testing is slightly different + if simplify: + assert len(stree.children) == 1 + tasklet = stree.children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert str(next(iter(tasklet.out_memlets.values()))) == 'a[N - 3, 1]' + else: + assert len(stree.children) == 3 + stree_nodes = list(stree.preorder_traversal())[1:] + assert [type(n) for n in stree_nodes] == [tn.ViewNode, tn.ViewNode, tn.TaskletNode] + + +def test_dealias_interstate_edge(): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_array('B', [20], dace.float64) + + nsdfg = dace.SDFG('nester') + nsdfg.add_array('A', [19], dace.float64) + nsdfg.add_array('B', [15], dace.float64) + nsdfg.add_symbol('m', dace.float64) + nstate1 = nsdfg.add_state() + nstate2 = nsdfg.add_state() + nsdfg.add_edge(nstate1, nstate2, dace.InterstateEdge(condition='B[1] > 0', assignments=dict(m='A[2]'))) + + # Connect to nested SDFG both with flipped definitions and offset memlets + state = sdfg.add_state() + nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'A', 'B'}, {}) + ra = state.add_read('A') + rb = state.add_read('B') + state.add_edge(ra, None, nsdfg_node, 'B', dace.Memlet('A[1:20]')) + state.add_edge(rb, None, nsdfg_node, 'A', dace.Memlet('B[2:17]')) + + sdfg.validate() + stree = as_schedule_tree(sdfg) + nodes = list(stree.preorder_traversal())[1:] + assert [type(n) 
for n in nodes] == [tn.StateIfScope, tn.GotoNode, tn.AssignNode] + assert 'A[2]' in nodes[0].condition.as_string + assert 'B[4]' in nodes[-1].value.as_string + + +if __name__ == '__main__': + test_stree_mpath_multiscope() + test_stree_mpath_multiscope_dependent() + test_stree_mpath_nested() + test_stree_copy_same_scope(False) + test_stree_copy_same_scope(True) + test_stree_copy_different_scope(False) + test_stree_copy_different_scope(True) + test_dealias_nested_call() + test_dealias_nested_call_samearray() + test_dealias_memlet_composition(False) + test_dealias_memlet_composition(True) + test_dealias_interstate_edge() diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py new file mode 100644 index 0000000000..09779c670f --- /dev/null +++ b/tests/schedule_tree/schedule_test.py @@ -0,0 +1,289 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree +import numpy as np + + +def test_for_in_map_in_for(): + + @dace.program + def matmul(A: dace.float32[10, 10], B: dace.float32[10, 10], C: dace.float32[10, 10]): + for i in range(10): + for j in dace.map[0:10]: + atile = dace.define_local([10], dace.float32) + atile[:] = A[i] + for k in range(10): + with dace.tasklet: + a << atile[k] + b << B[k, j] + cin << C[i, j] + c >> C[i, j] + c = cin + a * b + + sdfg = matmul.to_sdfg() + stree = as_schedule_tree(sdfg) + + assert len(stree.children) == 1 # for + fornode = stree.children[0] + assert isinstance(fornode, tn.ForScope) + assert len(fornode.children) == 1 # map + mapnode = fornode.children[0] + assert isinstance(mapnode, tn.MapScope) + assert len(mapnode.children) == 2 # copy, for + copynode, fornode = mapnode.children + assert isinstance(copynode, tn.CopyNode) + assert isinstance(fornode, tn.ForScope) + assert len(fornode.children) == 1 # tasklet + tasklet = fornode.children[0] + assert isinstance(tasklet, tn.TaskletNode) + + +def test_libnode(): + M, N, K = (dace.symbol(s) for s in 'MNK') + + @dace.program + def matmul_lib(a: dace.float64[M, K], b: dace.float64[K, N]): + return a @ b + + sdfg = matmul_lib.to_sdfg() + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 + assert isinstance(stree.children[0], tn.LibraryCall) + assert (stree.children[0].as_string() == + '__return[0:M, 0:N] = library MatMul[alpha=1, beta=0](a[0:M, 0:K], b[0:K, 0:N])') + + +def test_nesting(): + + @dace.program + def nest2(a: dace.float64[10]): + a += 1 + + @dace.program + def nest1(a: dace.float64[5, 10]): + for i in range(5): + nest2(a[:, i]) + + @dace.program + def main(a: dace.float64[20, 10]): + nest1(a[:5]) + nest1(a[5:10]) + nest1(a[10:15]) + nest1(a[15:]) + + sdfg = main.to_sdfg(simplify=True) + stree = as_schedule_tree(sdfg) + + # Despite two levels of nesting, immediate children are the 4 for loops + assert len(stree.children) == 4 + offsets = ['', '5', '10', '15'] + for fornode, offset in zip(stree.children, offsets): + assert isinstance(fornode, tn.ForScope) + assert len(fornode.children) == 1 # map + mapnode = fornode.children[0] + assert isinstance(mapnode, tn.MapScope) + assert len(mapnode.children) == 1 # tasklet + tasklet = mapnode.children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert offset in str(next(iter(tasklet.in_memlets.values()))) + + +def test_nesting_view(): + + @dace.program + def nest2(a: dace.float64[40]): + a += 1 + + @dace.program + def nest1(a): + for i 
in range(5): + subset = a[:, i, :] + nest2(subset.reshape((40, ))) + + @dace.program + def main(a: dace.float64[20, 10]): + nest1(a.reshape((4, 5, 10))) + + sdfg = main.to_sdfg() + stree = as_schedule_tree(sdfg) + assert any(isinstance(node, tn.ViewNode) for node in stree.children) + + +def test_nesting_nview(): + + @dace.program + def nest2(a: dace.float64[40]): + a += 1 + + @dace.program + def nest1(a: dace.float64[4, 5, 10]): + for i in range(5): + nest2(a[:, i, :]) + + @dace.program + def main(a: dace.float64[20, 10]): + nest1(a) + + sdfg = main.to_sdfg() + stree = as_schedule_tree(sdfg) + assert isinstance(stree.children[0], tn.NView) + + +def test_irreducible_sub_sdfg(): + sdfg = dace.SDFG('irreducible') + # Add a simple chain + s = sdfg.add_state_after(sdfg.add_state_after(sdfg.add_state())) + # Add an irreducible CFG + s1 = sdfg.add_state() + s2 = sdfg.add_state() + + sdfg.add_edge(s, s1, dace.InterstateEdge('a < b')) + # sdfg.add_edge(s, s2, dace.InterstateEdge('a >= b')) + sdfg.add_edge(s1, s2, dace.InterstateEdge('b > 9')) + sdfg.add_edge(s2, s1, dace.InterstateEdge('b < 19')) + e = sdfg.add_state() + sdfg.add_edge(s1, e, dace.InterstateEdge('a < 0')) + sdfg.add_edge(s2, e, dace.InterstateEdge('b < 0')) + + # Add a loop following general block + sdfg.add_loop(e, sdfg.add_state(), None, 'i', '0', 'i < 10', 'i + 1') + + stree = as_schedule_tree(sdfg) + node_types = [type(n) for n in stree.preorder_traversal()] + assert node_types.count(tn.GBlock) == 1 # Only one gblock + assert node_types[-1] == tn.ForScope # Check that loop was detected + + +def test_irreducible_in_loops(): + sdfg = dace.SDFG('irreducible') + # Add a simple chain of two for loops with goto from second to first's body + s1 = sdfg.add_state_after(sdfg.add_state_after(sdfg.add_state())) + s2 = sdfg.add_state() + e = sdfg.add_state() + + # Add a loop + l1 = sdfg.add_state() + l2 = sdfg.add_state_after(l1) + sdfg.add_loop(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2) + + l3 = sdfg.add_state() + l4 = sdfg.add_state_after(l3) + sdfg.add_loop(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4) + + # Irreducible part + sdfg.add_edge(l3, l1, dace.InterstateEdge('i < 5')) + + # Avoiding undefined behavior + sdfg.edges_between(l3, l4)[0].data.condition.as_string = 'i >= 5' + + stree = as_schedule_tree(sdfg) + node_types = [type(n) for n in stree.preorder_traversal()] + assert node_types.count(tn.GBlock) == 1 + assert node_types.count(tn.ForScope) == 2 + + +def test_reference(): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('n', dace.int32) + sdfg.add_array('A', [20], dace.float64) + sdfg.add_array('B', [20], dace.float64) + sdfg.add_array('C', [20], dace.float64) + sdfg.add_reference('ref', [20], dace.float64) + + init = sdfg.add_state() + s1 = sdfg.add_state() + s2 = sdfg.add_state() + end = sdfg.add_state() + sdfg.add_edge(init, s1, dace.InterstateEdge('n > 0')) + sdfg.add_edge(init, s2, dace.InterstateEdge('n <= 0')) + sdfg.add_edge(s1, end, dace.InterstateEdge()) + sdfg.add_edge(s2, end, dace.InterstateEdge()) + + s1.add_edge(s1.add_access('A'), None, s1.add_access('ref'), 'set', dace.Memlet('A[0:20]')) + s2.add_edge(s2.add_access('B'), None, s2.add_access('ref'), 'set', dace.Memlet('B[0:20]')) + end.add_nedge(end.add_access('ref'), end.add_access('C'), dace.Memlet('ref[0:20]')) + + stree = as_schedule_tree(sdfg) + nodes = list(stree.preorder_traversal())[1:] + assert [type(n) for n in nodes] == [tn.IfScope, tn.RefSetNode, tn.ElseScope, tn.RefSetNode, tn.CopyNode] + assert nodes[1].as_string() == 
'ref = refset to A[0:20]' + assert nodes[3].as_string() == 'ref = refset to B[0:20]' + + +def test_code_to_code(): + sdfg = dace.SDFG('tester') + sdfg.add_scalar('scal', dace.int32, transient=True) + state = sdfg.add_state() + t1 = state.add_tasklet('a', {}, {'out'}, 'out = 5') + t2 = state.add_tasklet('b', {'inp'}, {}, 'print(inp)', side_effects=True) + state.add_edge(t1, 'out', t2, 'inp', dace.Memlet('scal')) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 2 + assert all(isinstance(c, tn.TaskletNode) for c in stree.children) + assert stree.children[1].as_string().startswith('tasklet(scal') + + +def test_dyn_map_range(): + H = dace.symbol() + nnz = dace.symbol('nnz') + W = dace.symbol() + + @dace.program + def spmv(A_row: dace.uint32[H + 1], A_col: dace.uint32[nnz], A_val: dace.float32[nnz], x: dace.float32[W]): + b = np.zeros([H], dtype=np.float32) + + for i in dace.map[0:H]: + for j in dace.map[A_row[i]:A_row[i + 1]]: + b[i] += A_val[j] * x[A_col[j]] + + return b + + sdfg = spmv.to_sdfg() + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 2 + assert all(isinstance(c, tn.MapScope) for c in stree.children) + mapscope = stree.children[1] + start, end, dynrangemap = mapscope.children + assert isinstance(start, tn.DynScopeCopyNode) + assert isinstance(end, tn.DynScopeCopyNode) + assert isinstance(dynrangemap, tn.MapScope) + + +def test_multiview(): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [20, 20], dace.float64) + sdfg.add_array('B', [20, 20], dace.float64) + sdfg.add_view('Av', [400], dace.float64) + sdfg.add_view('Avv', [10, 40], dace.float64) + sdfg.add_view('Bv', [400], dace.float64) + sdfg.add_view('Bvv', [10, 40], dace.float64) + state = sdfg.add_state() + av = state.add_access('Av') + bv = state.add_access('Bv') + bvv = state.add_access('Bvv') + avv = state.add_access('Avv') + state.add_edge(state.add_read('A'), None, av, None, dace.Memlet('A[0:20, 0:20]')) + state.add_edge(av, None, avv, 'views', dace.Memlet('Av[0:400]')) + state.add_edge(avv, None, bvv, None, dace.Memlet('Avv[0:10, 0:40]')) + state.add_edge(bvv, 'views', bv, None, dace.Memlet('Bv[0:400]')) + state.add_edge(bv, 'views', state.add_write('B'), None, dace.Memlet('Bv[0:400]')) + + stree = as_schedule_tree(sdfg) + assert [type(n) for n in stree.children] == [tn.ViewNode, tn.ViewNode, tn.ViewNode, tn.ViewNode, tn.CopyNode] + + +if __name__ == '__main__': + test_for_in_map_in_for() + test_libnode() + test_nesting() + test_nesting_view() + test_nesting_nview() + test_irreducible_sub_sdfg() + test_irreducible_in_loops() + test_reference() + test_code_to_code() + test_dyn_map_range() + test_multiview() diff --git a/tests/sdfg/memlet_utils_test.py b/tests/sdfg/memlet_utils_test.py new file mode 100644 index 0000000000..467838fc56 --- /dev/null +++ b/tests/sdfg/memlet_utils_test.py @@ -0,0 +1,67 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
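The memlet-utils test below drives the new MemletReplacer AST visitor. A condensed usage sketch; the constructor and visit calls are taken from the test, while the callback and SDFG are made-up examples:

import dace
from dace.sdfg import memlet_utils as mu

def bump_zeros(memlet: dace.Memlet) -> dace.Memlet:
    # Example callback: rewrite any 0 index in the subset to 1
    for i, s in enumerate(memlet.subset):
        if s == 0:
            memlet.subset[i] = 1
    return memlet

sdfg = dace.SDFG('demo')
sdfg.add_array('A', [2, 2], dace.float64)

# Restrict replacement to memlets that access 'A', then rewrite every
# interstate-edge condition in place (a no-op on this empty demo SDFG)
replacer = mu.MemletReplacer(sdfg.arrays, bump_zeros, {'A'})
for e in sdfg.edges():
    e.data.condition.code[0] = replacer.visit(e.data.condition.code[0])
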
+ +import dace +import numpy as np +import pytest +from dace.sdfg import memlet_utils as mu + + +def _replace_zero_with_one(memlet: dace.Memlet) -> dace.Memlet: + for i, s in enumerate(memlet.subset): + if s == 0: + memlet.subset[i] = 1 + return memlet + + +@pytest.mark.parametrize('filter_type', ['none', 'same_array', 'different_array']) +def test_replace_memlet(filter_type): + # Prepare SDFG + sdfg = dace.SDFG('replace_memlet') + sdfg.add_array('A', [2, 2], dace.float64) + sdfg.add_array('B', [1], dace.float64) + state1 = sdfg.add_state() + state2 = sdfg.add_state() + state3 = sdfg.add_state() + end_state = sdfg.add_state() + sdfg.add_edge(state1, state2, dace.InterstateEdge('A[0, 0] > 0')) + sdfg.add_edge(state1, state3, dace.InterstateEdge('A[0, 0] <= 0')) + sdfg.add_edge(state2, end_state, dace.InterstateEdge()) + sdfg.add_edge(state3, end_state, dace.InterstateEdge()) + + t2 = state2.add_tasklet('write_one', {}, {'out'}, 'out = 1') + t3 = state3.add_tasklet('write_two', {}, {'out'}, 'out = 2') + w2 = state2.add_write('B') + w3 = state3.add_write('B') + state2.add_memlet_path(t2, w2, src_conn='out', memlet=dace.Memlet('B')) + state3.add_memlet_path(t3, w3, src_conn='out', memlet=dace.Memlet('B')) + + # Filter memlets + if filter_type == 'none': + filter = set() + elif filter_type == 'same_array': + filter = {'A'} + elif filter_type == 'different_array': + filter = {'B'} + + # Replace memlets in conditions + replacer = mu.MemletReplacer(sdfg.arrays, _replace_zero_with_one, filter) + for e in sdfg.edges(): + e.data.condition.code[0] = replacer.visit(e.data.condition.code[0]) + + # Compile and run + sdfg.compile() + + A = np.array([[1, 1], [1, -1]], dtype=np.float64) + B = np.array([0], dtype=np.float64) + sdfg(A=A, B=B) + + if filter_type in {'none', 'same_array'}: + assert B[0] == 2 + else: + assert B[0] == 1 + + +if __name__ == '__main__': + test_replace_memlet('none') + test_replace_memlet('same_array') + test_replace_memlet('different_array') diff --git a/tests/symbol_dependent_transients_test.py b/tests/symbol_dependent_transients_test.py index f718abf379..8033b6b196 100644 --- a/tests/symbol_dependent_transients_test.py +++ b/tests/symbol_dependent_transients_test.py @@ -45,7 +45,7 @@ def _make_sdfg(name, storage=dace.dtypes.StorageType.CPU_Heap, isview=False): body2_state.add_nedge(read_a, read_tmp1, dace.Memlet(f'A[2:{N}-2, 2:{N}-2, i:{N}]')) else: read_tmp1 = body2_state.add_read('tmp1') - rednode = standard.Reduce(wcr='lambda a, b : a + b', identity=0) + rednode = standard.Reduce('sum', wcr='lambda a, b : a + b', identity=0) if storage == dace.dtypes.StorageType.GPU_Global: rednode.implementation = 'CUDA (device)' elif storage == dace.dtypes.StorageType.FPGA_Global:
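One API note for the test fix above: library Reduce nodes are now constructed with an explicit node name as their first positional argument, matching the add_reduce change in dace/sdfg/state.py. A minimal sketch; the name 'sum' is arbitrary:

import dace.libraries.standard as stdlib

# The node name now comes first; wcr and identity keep their keyword form
rednode = stdlib.Reduce('sum', wcr='lambda a, b: a + b', identity=0)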