Skip to content

Commit

Permalink
Schedule Trees (#1145)
Browse files Browse the repository at this point in the history
This PR adds support for a scheduling-oriented view of SDFGs. Upon conversion, the SDFG and its nested SDFGs keep the same array names and are organized in one tree, where each node corresponds to a schedulable concept (map scope, copy, tasklet, for-loop scope, etc.). The graph structure can be converted to sequential text with `as_string`. Useful for inspecting and analyzing schedules. 

---------

Co-authored-by: Alexandros Nikolaos Ziogas <[email protected]>
  • Loading branch information
tbennun and alexnick83 authored Sep 26, 2023
1 parent f6263b5 commit a582261
Show file tree
Hide file tree
Showing 27 changed files with 2,354 additions and 96 deletions.
4 changes: 2 additions & 2 deletions dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,8 +886,8 @@ def generate_code(self,

# NOTE: NestedSDFGs frequently contain tautologies in their symbol mapping, e.g., `'i': i`. Do not
# redefine the symbols in such cases.
if (not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping.keys()
and str(sdfg.parent_nsdfg_node.symbol_mapping[isvarName] == isvarName)):
if (not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping
and str(sdfg.parent_nsdfg_node.symbol_mapping[isvarName]) == str(isvarName)):
continue
isvar = data.Scalar(isvarType)
callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg)
Expand Down
18 changes: 18 additions & 0 deletions dace/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ def __hash__(self):
def as_arg(self, with_types=True, for_call=False, name=None):
    """Returns a string for a C++ function signature (e.g., `int *A`).

    Abstract on the base data descriptor; concrete subclasses (e.g., Scalar,
    Array) must override this.

    :param with_types: If True, includes the C++ type in the result.
    :param for_call: If True, formats the string for a call site rather than
                     a declaration.
    :param name: The argument name to use in the signature.
    :raises NotImplementedError: Always, on the base class.
    """
    raise NotImplementedError

def as_python_arg(self, with_types=True, for_call=False, name=None):
    """Returns a string for a Data-Centric Python function signature (e.g., `A: dace.int32[M]`).

    Abstract on the base data descriptor; concrete subclasses (e.g., Scalar,
    Array) must override this.

    :param with_types: If True, includes the type annotation in the result.
    :param for_call: If True, formats the string for a call site rather than
                     a declaration.
    :param name: The argument name to use in the signature.
    :raises NotImplementedError: Always, on the base class.
    """
    raise NotImplementedError

def used_symbols(self, all_symbols: bool) -> Set[symbolic.SymbolicType]:
"""
Expand Down Expand Up @@ -583,6 +587,13 @@ def as_arg(self, with_types=True, for_call=False, name=None):
if not with_types or for_call:
return name
return self.dtype.as_arg(name)

def as_python_arg(self, with_types=True, for_call=False, name=None):
    """Returns a string for a Data-Centric Python function signature (e.g., `A: dace.int32`).

    :param with_types: If True, includes the dace type annotation in the result.
    :param for_call: If True, returns only the name (suitable for a call site).
    :param name: The argument name to use in the signature.
    :return: The formatted signature (or plain name) string.
    """
    # A scalar residing in GPU global memory cannot be passed by value from
    # Python; represent it as a one-element array instead.
    if self.storage is dtypes.StorageType.GPU_Global:
        return Array(self.dtype, [1]).as_python_arg(with_types, for_call, name)
    if not with_types or for_call:
        return name
    # TYPECLASS_TO_STRING yields C++-style names ('::'); convert to the
    # Python-style dotted form (e.g., 'dace.int32').
    return f"{name}: {dtypes.TYPECLASS_TO_STRING[self.dtype].replace('::', '.')}"

def sizes(self):
return None
Expand Down Expand Up @@ -849,6 +860,13 @@ def as_arg(self, with_types=True, for_call=False, name=None):
if self.may_alias:
return str(self.dtype.ctype) + ' *' + arrname
return str(self.dtype.ctype) + ' * __restrict__ ' + arrname

def as_python_arg(self, with_types=True, for_call=False, name=None):
    """Returns a string for a Data-Centric Python function signature (e.g., `A: dace.int32[M]`).

    :param with_types: If True, includes the dace type and shape annotation
                       in the result.
    :param for_call: If True, returns only the name (suitable for a call site).
    :param name: The argument name to use in the signature.
    :return: The formatted signature (or plain name) string.
    """
    arrname = name

    if not with_types or for_call:
        return arrname
    # TYPECLASS_TO_STRING yields C++-style names ('::'); convert to the
    # Python-style dotted form and append the shape (e.g., 'dace.int32[M, N]').
    return f"{arrname}: {dtypes.TYPECLASS_TO_STRING[self.dtype].replace('::', '.')}{list(self.shape)}"

def sizes(self):
return [d.name if isinstance(d, symbolic.symbol) else str(d) for d in self.shape]
Expand Down
4 changes: 2 additions & 2 deletions dace/frontend/python/memlet_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices):
def parse_memlet_subset(array: data.Data,
node: Union[ast.Name, ast.Subscript],
das: Dict[str, Any],
parsed_slice: Any = None) -> Tuple[subsets.Range, List[int]]:
parsed_slice: Any = None) -> Tuple[subsets.Range, List[int], List[int]]:
"""
Parses an AST subset and returns access range, as well as new dimensions to
add.
Expand All @@ -209,7 +209,7 @@ def parse_memlet_subset(array: data.Data,
e.g., negative indices or empty shapes).
:param node: AST node representing whole array or subset thereof.
:param das: Dictionary of defined arrays and symbols mapped to their values.
:return: A 2-tuple of (subset, list of new axis indices).
:return: A 3-tuple of (subset, list of new axis indices, list of index-to-array-dimension correspondence).
"""
# Get memlet range
ndslice = [(0, s - 1, 1) for s in array.shape]
Expand Down
6 changes: 6 additions & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3177,6 +3177,12 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):

if (not is_return and isinstance(target, ast.Name) and true_name and not op
and not isinstance(true_array, data.Scalar) and not (true_array.shape == (1, ))):
if true_name in self.views:
if result in self.sdfg.arrays and self.views[true_name] == (
result, Memlet.from_array(result, self.sdfg.arrays[result])):
continue
else:
raise DaceSyntaxError(self, target, 'Cannot reassign View "{}"'.format(name))
if (isinstance(result, str) and result in self.sdfg.arrays
and self.sdfg.arrays[result].is_equivalent(true_array)):
# Skip error if the arrays are defined exactly in the same way.
Expand Down
47 changes: 33 additions & 14 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,16 +617,21 @@ def _elementwise(pv: 'ProgramVisitor',

def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: dace.typeclass = None):
""" Implements a simple call of the form `out = func(inp)`. """
create_input = True
if isinstance(inpname, (list, tuple)): # TODO investigate this
inpname = inpname[0]
if not isinstance(inpname, str):
if not isinstance(inpname, str) and not symbolic.issymbolic(inpname):
# Constant parameter
cst = inpname
inparr = data.create_datadescriptor(cst)
inpname = sdfg.temp_data_name()
inparr.transient = True
sdfg.add_constant(inpname, cst, inparr)
sdfg.add_datadesc(inpname, inparr)
elif symbolic.issymbolic(inpname):
dtype = symbolic.symtype(inpname)
inparr = data.Scalar(dtype)
create_input = False
else:
inparr = sdfg.arrays[inpname]

Expand All @@ -636,10 +641,17 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype:
outarr.dtype = restype
num_elements = data._prod(inparr.shape)
if num_elements == 1:
inp = state.add_read(inpname)
if create_input:
inp = state.add_read(inpname)
inconn_name = '__inp'
else:
inconn_name = symbolic.symstr(inpname)

out = state.add_write(outname)
tasklet = state.add_tasklet(func, {'__inp'}, {'__out'}, '__out = {f}(__inp)'.format(f=func))
state.add_edge(inp, None, tasklet, '__inp', Memlet.from_array(inpname, inparr))
tasklet = state.add_tasklet(func, {'__inp'} if create_input else {}, {'__out'},
f'__out = {func}({inconn_name})')
if create_input:
state.add_edge(inp, None, tasklet, '__inp', Memlet.from_array(inpname, inparr))
state.add_edge(tasklet, '__out', out, None, Memlet.from_array(outname, outarr))
else:
state.add_mapped_tasklet(
Expand Down Expand Up @@ -2158,8 +2170,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op

res = symbolic.equal(arr1.shape[-1], arr2.shape[-2])
if res is None:
warnings.warn(f'Last mode of first tesnsor/matrix {arr1.shape[-1]} and second-last mode of '
f'second tensor/matrix {arr2.shape[-2]} may not match', UserWarning)
warnings.warn(
f'Last mode of first tesnsor/matrix {arr1.shape[-1]} and second-last mode of '
f'second tensor/matrix {arr2.shape[-2]} may not match', UserWarning)
elif not res:
raise SyntaxError('Matrix dimension mismatch %s != %s' % (arr1.shape[-1], arr2.shape[-2]))

Expand All @@ -2176,8 +2189,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op

res = symbolic.equal(arr1.shape[-1], arr2.shape[0])
if res is None:
warnings.warn(f'Number of matrix columns {arr1.shape[-1]} and length of vector {arr2.shape[0]} '
f'may not match', UserWarning)
warnings.warn(
f'Number of matrix columns {arr1.shape[-1]} and length of vector {arr2.shape[0]} '
f'may not match', UserWarning)
elif not res:
raise SyntaxError("Number of matrix columns {} must match"
"size of vector {}.".format(arr1.shape[1], arr2.shape[0]))
Expand All @@ -2188,8 +2202,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op

res = symbolic.equal(arr1.shape[0], arr2.shape[0])
if res is None:
warnings.warn(f'Length of vector {arr1.shape[0]} and number of matrix rows {arr2.shape[0]} '
f'may not match', UserWarning)
warnings.warn(
f'Length of vector {arr1.shape[0]} and number of matrix rows {arr2.shape[0]} '
f'may not match', UserWarning)
elif not res:
raise SyntaxError("Size of vector {} must match number of matrix "
"rows {} must match".format(arr1.shape[0], arr2.shape[0]))
Expand All @@ -2200,8 +2215,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op

res = symbolic.equal(arr1.shape[0], arr2.shape[0])
if res is None:
warnings.warn(f'Length of first vector {arr1.shape[0]} and length of second vector {arr2.shape[0]} '
f'may not match', UserWarning)
warnings.warn(
f'Length of first vector {arr1.shape[0]} and length of second vector {arr2.shape[0]} '
f'may not match', UserWarning)
elif not res:
raise SyntaxError("Vectors in vector product must have same size: "
"{} vs. {}".format(arr1.shape[0], arr2.shape[0]))
Expand Down Expand Up @@ -4401,10 +4417,13 @@ def _datatype_converter(sdfg: SDFG, state: SDFGState, arg: UfuncInput, dtype: dt

# Set tasklet parameters
impl = {
'name': "_convert_to_{}_".format(dtype.to_string()),
'name':
"_convert_to_{}_".format(dtype.to_string()),
'inputs': ['__inp'],
'outputs': ['__out'],
'code': "__out = dace.{}(__inp)".format(dtype.to_string())
'code':
"__out = {}(__inp)".format(f"dace.{dtype.to_string()}" if dtype not in (dace.bool,
dace.bool_) else dtype.to_string())
}
if dtype in (dace.bool, dace.bool_):
impl['code'] = "__out = dace.bool_(__inp)"
Expand Down
4 changes: 3 additions & 1 deletion dace/libraries/blas/nodes/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,5 +217,7 @@ class MatMul(dace.sdfg.nodes.LibraryNode):
default=0,
desc="A scalar which will be multiplied with C before adding C")

def __init__(self, name, location=None):
def __init__(self, name, location=None, alpha=1, beta=0):
    """Initializes a MatMul library node computing `C = alpha * (A @ B) + beta * C`.

    :param name: Name of the library node.
    :param location: Execution location descriptor (passed to the base node).
    :param alpha: Scalar multiplied with the A @ B product.
    :param beta: Scalar multiplied with C before adding the product.
    """
    # NOTE(review): alpha/beta are assigned before super().__init__() runs —
    # confirm the base LibraryNode initializer does not reset these properties.
    self.alpha = alpha
    self.beta = beta
    super().__init__(name, location=location, inputs={"_a", "_b"}, outputs={"_c"})
5 changes: 3 additions & 2 deletions dace/libraries/standard/nodes/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -1562,13 +1562,14 @@ class Reduce(dace.sdfg.nodes.LibraryNode):
identity = Property(allow_none=True)

def __init__(self,
name,
wcr='lambda a, b: a',
axes=None,
identity=None,
schedule=dtypes.ScheduleType.Default,
debuginfo=None,
**kwargs):
super().__init__(name='Reduce', **kwargs)
super().__init__(name=name, **kwargs)
self.wcr = wcr
self.axes = axes
self.identity = identity
Expand All @@ -1577,7 +1578,7 @@ def __init__(self,

@staticmethod
def from_json(json_obj, context=None):
ret = Reduce("lambda a, b: a", None)
ret = Reduce('reduce', 'lambda a, b: a', None)
dace.serialize.set_properties_from_json(ret, json_obj, context=context)
return ret

Expand Down
7 changes: 5 additions & 2 deletions dace/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,8 +1001,11 @@ def get_free_symbols(self, defined_syms: Set[str] = None) -> Set[str]:
if self.language == dace.dtypes.Language.Python:
visitor = TaskletFreeSymbolVisitor(defined_syms)
if self.code:
for stmt in self.code:
visitor.visit(stmt)
if isinstance(self.code, list):
for stmt in self.code:
visitor.visit(stmt)
else:
visitor.visit(self.code)
return visitor.free_symbols

return set()
Expand Down
Empty file.
60 changes: 60 additions & 0 deletions dace/sdfg/analysis/schedule_tree/passes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
"""
Assortment of passes for schedule trees.
"""

from dace.sdfg.analysis.schedule_tree import treenodes as tn
from typing import Set


def remove_unused_and_duplicate_labels(stree: tn.ScheduleTreeScope):
    """
    Removes unused and duplicate labels from the schedule tree.

    A label is kept only if some goto in the tree targets it, and only its
    first occurrence survives; all other label nodes are pruned.

    :param stree: The schedule tree to remove labels from.
    :return: The schedule tree with unused and duplicate labels removed.
    """

    class GotoTargetCollector(tn.ScheduleNodeVisitor):
        """Gathers the set of label names that gotos actually jump to."""

        def __init__(self):
            self.targets: Set[str] = set()

        def visit_GotoNode(self, node: tn.GotoNode):
            # Untargeted gotos do not reference any label.
            if node.target is not None:
                self.targets.add(node.target)

    class LabelPruner(tn.ScheduleNodeTransformer):
        """Drops labels that are untargeted or repeated."""

        def __init__(self, used_labels: Set[str]) -> None:
            self.used_labels = used_labels
            self.encountered: Set[str] = set()

        def visit_StateLabel(self, node: tn.StateLabel):
            label = node.state.name
            # Keep only the first occurrence of a label that is jumped to;
            # returning None removes the node from the tree.
            if label in self.used_labels and label not in self.encountered:
                self.encountered.add(label)
                return node
            return None

    collector = GotoTargetCollector()
    collector.visit(stree)
    return LabelPruner(collector.targets).visit(stree)


def remove_empty_scopes(stree: tn.ScheduleTreeScope):
    """
    Removes empty scopes from the schedule tree.

    :param stree: The schedule tree to remove empty scopes from.
    :return: The schedule tree with empty scopes removed.
    :warning: This pass is not safe to use for for-loops, as it will remove indices that may be used after the loop.
    """

    class EmptyScopePruner(tn.ScheduleNodeTransformer):
        """Deletes any scope node with no children."""

        def visit_scope(self, node: tn.ScheduleTreeScope):
            # Returning None removes the scope; otherwise recurse so that
            # scopes emptied by pruning descendants are handled too.
            return self.generic_visit(node) if node.children else None

    return EmptyScopePruner().visit(stree)
Loading

0 comments on commit a582261

Please sign in to comment.