diff --git a/dace/codegen/compiled_sdfg.py b/dace/codegen/compiled_sdfg.py index dcd529865f..bc1a80f703 100644 --- a/dace/codegen/compiled_sdfg.py +++ b/dace/codegen/compiled_sdfg.py @@ -452,9 +452,10 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: # GPU scalars and return values are pointers, so this is fine if atype.storage != dtypes.StorageType.GPU_Global and not a.startswith('__return'): raise TypeError('Passing an array to a scalar (type %s) in argument "%s"' % (atype.dtype.ctype, a)) - elif not isinstance(atype, dt.Array) and not isinstance(atype.dtype, dtypes.callback) and not isinstance( - arg, - (atype.dtype.type, sp.Basic)) and not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype): + elif (not isinstance(atype, (dt.Array, dt.Structure)) and + not isinstance(atype.dtype, dtypes.callback) and + not isinstance(arg, (atype.dtype.type, sp.Basic)) and + not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)): if isinstance(arg, int) and atype.dtype.type == np.int64: pass elif isinstance(arg, float) and atype.dtype.type == np.float64: @@ -472,7 +473,7 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: else: warnings.warn(f'Casting scalar argument "{a}" from {type(arg).__name__} to {atype.dtype.type}') arglist[i] = atype.dtype.type(arg) - elif (isinstance(atype, dt.Array) and isinstance(arg, np.ndarray) + elif (isinstance(atype, dt.Array) and isinstance(arg, np.ndarray) and not isinstance(atype, dt.StructArray) and atype.dtype.as_numpy_dtype() != arg.dtype): # Make exception for vector types if (isinstance(atype.dtype, dtypes.vector) and atype.dtype.vtype.as_numpy_dtype() == arg.dtype): @@ -521,7 +522,7 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: # Construct init args, which only consist of the symbols symbols = self._free_symbols initargs = tuple( - actype(arg) if (not isinstance(arg, ctypes._SimpleCData)) else arg + actype(arg) if not isinstance(arg, 
ctypes._SimpleCData) else arg for arg, actype, atype, aname in callparams if aname in symbols) # Replace arrays with their base host/device pointers @@ -531,7 +532,8 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: try: newargs = tuple( - actype(arg) if (not isinstance(arg, ctypes._SimpleCData)) else arg for arg, actype, atype in newargs) + actype(arg) if not isinstance(arg, (ctypes._SimpleCData)) else arg + for arg, actype, atype in newargs) except TypeError: # Pinpoint bad argument for i, (arg, actype, _) in enumerate(newargs): diff --git a/dace/codegen/dispatcher.py b/dace/codegen/dispatcher.py index be032556a0..359d3a5853 100644 --- a/dace/codegen/dispatcher.py +++ b/dace/codegen/dispatcher.py @@ -505,11 +505,11 @@ def get_copy_dispatcher(self, src_node, dst_node, edge, sdfg, state): dst_is_data = True # Skip copies to/from views where edge matches - if src_is_data and isinstance(src_node.desc(sdfg), dt.View): + if src_is_data and isinstance(src_node.desc(sdfg), (dt.StructureView, dt.View)): e = sdutil.get_view_edge(state, src_node) if e is edge: return None - if dst_is_data and isinstance(dst_node.desc(sdfg), dt.View): + if dst_is_data and isinstance(dst_node.desc(sdfg), (dt.StructureView, dt.View)): e = sdutil.get_view_edge(state, dst_node) if e is edge: return None diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 264311a45c..d3d4f50ccd 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -370,6 +370,10 @@ def make_const(expr: str) -> str: # Register defined variable dispatcher.defined_vars.add(pointer_name, defined_type, typedef, allow_shadowing=True) + # NOTE: `expr` may only be a name or a sequence of names and dots. The latter indicates nested data and structures. + # NOTE: Since structures are implemented as pointers, we replace dots with arrows. 
+ expr = expr.replace('.', '->') + return (typedef + ref, pointer_name, expr) diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index ef97b0bbad..59d635e14e 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -55,10 +55,29 @@ def __init__(self, frame_codegen, sdfg): # Keep track of generated NestedSDG, and the name of the assigned function self._generated_nested_sdfg = dict() - # Keeps track of generated connectors, so we know how to access them in - # nested scopes + # NOTE: Multi-nesting with StructArrays must be further investigated. + def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''): + for k, v in struct.members.items(): + if isinstance(v, data.Structure): + _visit_structure(v, args, f'{prefix}.{k}') + elif isinstance(v, data.StructArray): + _visit_structure(v.stype, args, f'{prefix}.{k}') + elif isinstance(v, data.Data): + args[f'{prefix}.{k}'] = v + + # Keeps track of generated connectors, so we know how to access them in nested scopes + arglist = dict(self._frame.arglist) for name, arg_type in self._frame.arglist.items(): - if isinstance(arg_type, data.Scalar): + if isinstance(arg_type, data.Structure): + desc = sdfg.arrays[name] + _visit_structure(arg_type, arglist, name) + elif isinstance(arg_type, data.StructArray): + desc = sdfg.arrays[name] + desc = desc.stype + _visit_structure(desc, arglist, name) + + for name, arg_type in arglist.items(): + if isinstance(arg_type, (data.Scalar, data.Structure)): # GPU global memory is only accessed via pointers # TODO(later): Fix workaround somehow if arg_type.storage is dtypes.StorageType.GPU_Global: @@ -195,9 +214,21 @@ def allocate_view(self, sdfg: SDFG, dfg: SDFGState, state_id: int, node: nodes.A ancestor=0, is_write=is_write) if not declared: - declaration_stream.write(f'{atype} {aname};', sdfg, state_id, node) ctypedef = dtypes.pointer(nodedesc.dtype).ctype self._dispatcher.declared_arrays.add(aname, DefinedType.Pointer, ctypedef) + if 
isinstance(nodedesc, data.StructureView): + for k, v in nodedesc.members.items(): + if isinstance(v, data.Data): + ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype + defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer + self._dispatcher.declared_arrays.add(f"{name}.{k}", defined_type, ctypedef) + self._dispatcher.defined_vars.add(f"{name}.{k}", defined_type, ctypedef) + # TODO: Find a better way to do this (the issue is with pointers of pointers) + if atype.endswith('*'): + atype = atype[:-1] + if value.startswith('&'): + value = value[1:] + declaration_stream.write(f'{atype} {aname};', sdfg, state_id, node) allocation_stream.write(f'{aname} = {value};', sdfg, state_id, node) def allocate_reference(self, sdfg: SDFG, dfg: SDFGState, state_id: int, node: nodes.AccessNode, @@ -268,16 +299,19 @@ def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, d name = node.data alloc_name = cpp.ptr(name, nodedesc, sdfg, self._frame) name = alloc_name + # NOTE: `expr` may only be a name or a sequence of names and dots. The latter indicates nested data and + # NOTE: structures. Since structures are implemented as pointers, we replace dots with arrows. 
+ alloc_name = alloc_name.replace('.', '->') if nodedesc.transient is False: return # Check if array is already allocated - if self._dispatcher.defined_vars.has(alloc_name): + if self._dispatcher.defined_vars.has(name): return # Check if array is already declared - declared = self._dispatcher.declared_arrays.has(alloc_name) + declared = self._dispatcher.declared_arrays.has(name) define_var = self._dispatcher.defined_vars.add if nodedesc.lifetime in (dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External): @@ -290,7 +324,18 @@ def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, d if not isinstance(nodedesc.dtype, dtypes.opaque): arrsize_bytes = arrsize * nodedesc.dtype.bytes - if isinstance(nodedesc, data.View): + if isinstance(nodedesc, data.Structure) and not isinstance(nodedesc, data.StructureView): + declaration_stream.write(f"{nodedesc.ctype} {name} = new {nodedesc.dtype.base_type};\n") + define_var(name, DefinedType.Pointer, nodedesc.ctype) + for k, v in nodedesc.members.items(): + if isinstance(v, data.Data): + ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype + defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer + self._dispatcher.declared_arrays.add(f"{name}.{k}", defined_type, ctypedef) + self.allocate_array(sdfg, dfg, state_id, nodes.AccessNode(f"{name}.{k}"), v, function_stream, + declaration_stream, allocation_stream) + return + if isinstance(nodedesc, (data.StructureView, data.View)): return self.allocate_view(sdfg, dfg, state_id, node, function_stream, declaration_stream, allocation_stream) if isinstance(nodedesc, data.Reference): return self.allocate_reference(sdfg, dfg, state_id, node, function_stream, declaration_stream, @@ -455,7 +500,7 @@ def deallocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, dtypes.AllocationLifetime.External) self._dispatcher.declared_arrays.remove(alloc_name, is_global=is_global) - 
if isinstance(nodedesc, (data.Scalar, data.View, data.Stream, data.Reference)): + if isinstance(nodedesc, (data.Scalar, data.StructureView, data.View, data.Stream, data.Reference)): return elif (nodedesc.storage == dtypes.StorageType.CPU_Heap or (nodedesc.storage == dtypes.StorageType.Register and symbolic.issymbolic(arrsize, sdfg.constants))): @@ -1139,6 +1184,9 @@ def memlet_definition(self, if not types: types = self._dispatcher.defined_vars.get(ptr, is_global=True) var_type, ctypedef = types + # NOTE: `expr` may only be a name or a sequence of names and dots. The latter indicates nested data and + # NOTE: structures. Since structures are implemented as pointers, we replace dots with arrows. + ptr = ptr.replace('.', '->') if fpga.is_fpga_array(desc): decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces") diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 56419b9701..9ee5c2ef17 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -153,15 +153,23 @@ def generate_fileheader(self, sdfg: SDFG, global_stream: CodeIOStream, backend: for _, arrname, arr in sdfg.arrays_recursive(): if arr is not None: datatypes.add(arr.dtype) + + def _emit_definitions(dtype: dtypes.typeclass, wrote_something: bool) -> bool: + if isinstance(dtype, dtypes.pointer): + wrote_something = _emit_definitions(dtype._typeclass, wrote_something) + elif isinstance(dtype, dtypes.struct): + for field in dtype.fields.values(): + wrote_something = _emit_definitions(field, wrote_something) + if hasattr(dtype, 'emit_definition'): + if not wrote_something: + global_stream.write("", sdfg) + global_stream.write(dtype.emit_definition(), sdfg) + return wrote_something # Emit unique definitions wrote_something = False for typ in datatypes: - if hasattr(typ, 'emit_definition'): - if not wrote_something: - global_stream.write("", sdfg) - wrote_something = True - 
global_stream.write(typ.emit_definition(), sdfg) + wrote_something = _emit_definitions(typ, wrote_something) if wrote_something: global_stream.write("", sdfg) @@ -741,7 +749,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): instances = access_instances[sdfg.sdfg_id][name] # A view gets "allocated" everywhere it appears - if isinstance(desc, data.View): + if isinstance(desc, (data.StructureView, data.View)): for s, n in instances: self.to_allocate[s].append((sdfg, s, n, False, True, False)) self.to_allocate[s].append((sdfg, s, n, False, False, True)) diff --git a/dace/data.py b/dace/data.py index d492d06258..3b571e6537 100644 --- a/dace/data.py +++ b/dace/data.py @@ -1,10 +1,11 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import copy as cp import ctypes import functools -import re + +from collections import OrderedDict from numbers import Number -from typing import Any, Dict, Optional, Sequence, Set, Tuple +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union import numpy import sympy as sp @@ -17,8 +18,8 @@ import dace.dtypes as dtypes from dace import serialize, symbolic from dace.codegen import cppunparse -from dace.properties import (CodeProperty, DebugInfoProperty, DictProperty, EnumProperty, ListProperty, Property, - ReferenceProperty, ShapeProperty, SubsetProperty, SymbolicProperty, TypeClassProperty, +from dace.properties import (DebugInfoProperty, DictProperty, EnumProperty, ListProperty, NestedDataClassProperty, + OrderedDictProperty, Property, ShapeProperty, SymbolicProperty, TypeClassProperty, make_properties) @@ -354,6 +355,157 @@ def add(X: dace.float32[10, 10] @ dace.StorageType.GPU_Global): return new_desc +def _arrays_to_json(arrays): + if arrays is None: + return None + return [(k, serialize.to_json(v)) for k, v in arrays.items()] + + +def _arrays_from_json(obj, context=None): + if obj is None: + return {} 
+ return OrderedDict((k, serialize.from_json(v, context)) for k, v in obj) + + +@make_properties +class Structure(Data): + """ Base class for structures. """ + + members = OrderedDictProperty(default=OrderedDict(), + desc="Dictionary of structure members", + from_json=_arrays_from_json, + to_json=_arrays_to_json) + name = Property(dtype=str, desc="Structure type name") + + def __init__(self, + members: Union[Dict[str, Data], List[Tuple[str, Data]]], + name: str = 'Structure', + transient: bool = False, + storage: dtypes.StorageType = dtypes.StorageType.Default, + location: Dict[str, str] = None, + lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, + debuginfo: dtypes.DebugInfo = None): + + self.members = OrderedDict(members) + for k, v in self.members.items(): + v.transient = transient + + self.name = name + fields_and_types = OrderedDict() + symbols = set() + for k, v in self.members.items(): + if isinstance(v, Structure): + symbols |= v.free_symbols + fields_and_types[k] = (v.dtype, str(v.total_size)) + elif isinstance(v, Array): + symbols |= v.free_symbols + fields_and_types[k] = (dtypes.pointer(v.dtype), str(_prod(v.shape))) + elif isinstance(v, Scalar): + symbols |= v.free_symbols + fields_and_types[k] = v.dtype + elif isinstance(v, (sp.Basic, symbolic.SymExpr)): + symbols |= v.free_symbols + fields_and_types[k] = symbolic.symtype(v) + elif isinstance(v, (int, numpy.integer)): + fields_and_types[k] = dtypes.typeclass(type(v)) + else: + raise TypeError(f"Attribute {k}'s value {v} has unsupported type: {type(v)}") + + # NOTE: We will not store symbols in the dtype for now, but leaving it as a comment to investigate later. + # NOTE: See discussion about data/object symbols. 
+ # for s in symbols: + # if str(s) in fields_and_types: + # continue + # if hasattr(s, "dtype"): + # fields_and_types[str(s)] = s.dtype + # else: + # fields_and_types[str(s)] = dtypes.int32 + + dtype = dtypes.pointer(dtypes.struct(name, **fields_and_types)) + shape = (1,) + super(Structure, self).__init__(dtype, shape, transient, storage, location, lifetime, debuginfo) + + @staticmethod + def from_json(json_obj, context=None): + if json_obj['type'] != 'Structure': + raise TypeError("Invalid data type") + + # Create dummy object + ret = Structure({}) + serialize.set_properties_from_json(ret, json_obj, context=context) + + return ret + + @property + def total_size(self): + return -1 + + @property + def offset(self): + return [0] + + @property + def start_offset(self): + return 0 + + @property + def strides(self): + return [1] + + @property + def free_symbols(self) -> Set[symbolic.SymbolicType]: + """ Returns a set of undefined symbols in this data descriptor. """ + result = set() + for k, v in self.members.items(): + result |= v.free_symbols + return result + + def __repr__(self): + return f"{self.name} ({', '.join([f'{k}: {v}' for k, v in self.members.items()])})" + + def as_arg(self, with_types=True, for_call=False, name=None): + if self.storage is dtypes.StorageType.GPU_Global: + return Array(self.dtype, [1]).as_arg(with_types, for_call, name) + if not with_types or for_call: + return name + return self.dtype.as_arg(name) + + def __getitem__(self, s): + """ This is syntactic sugar that allows us to define an array type + with the following syntax: ``Structure[N,M]`` + :return: A ``data.StructArray`` data descriptor. + """ + if isinstance(s, list) or isinstance(s, tuple): + return StructArray(self, tuple(s)) + return StructArray(self, (s, )) + + +@make_properties +class StructureView(Structure): + """ + Data descriptor that acts as a reference (or view) of another structure. 
+ """ + + @staticmethod + def from_json(json_obj, context=None): + if json_obj['type'] != 'StructureView': + raise TypeError("Invalid data type") + + # Create dummy object + ret = StructureView({}) + serialize.set_properties_from_json(ret, json_obj, context=context) + + return ret + + def validate(self): + super().validate() + + # We ensure that allocation lifetime is always set to Scope, since the + # view is generated upon "allocation" + if self.lifetime != dtypes.AllocationLifetime.Scope: + raise ValueError('Only Scope allocation lifetime is supported for Views') + + @make_properties class Scalar(Data): """ Data descriptor of a scalar value. """ @@ -920,6 +1072,56 @@ def free_symbols(self): return self.used_symbols(all_symbols=True) +@make_properties +class StructArray(Array): + """ Array of Structures. """ + + stype = NestedDataClassProperty(allow_none=True, default=None) + + def __init__(self, + stype: Structure, + shape, + transient=False, + allow_conflicts=False, + storage=dtypes.StorageType.Default, + location=None, + strides=None, + offset=None, + may_alias=False, + lifetime=dtypes.AllocationLifetime.Scope, + alignment=0, + debuginfo=None, + total_size=-1, + start_offset=None, + optional=None, + pool=False): + + self.stype = stype + if stype: + dtype = stype.dtype + else: + dtype = dtypes.int8 + super(StructArray, self).__init__(dtype, shape, transient, allow_conflicts, storage, location, strides, offset, + may_alias, lifetime, alignment, debuginfo, total_size, start_offset, optional, pool) + + @classmethod + def from_json(cls, json_obj, context=None): + # Create dummy object + ret = cls(None, ()) + serialize.set_properties_from_json(ret, json_obj, context=context) + + # Default shape-related properties + if not ret.offset: + ret.offset = [0] * len(ret.shape) + if not ret.strides: + # Default strides are C-ordered + ret.strides = [_prod(ret.shape[i + 1:]) for i in range(len(ret.shape))] + if ret.total_size == 0: + ret.total_size = _prod(ret.shape) + + 
return ret + + @make_properties class View(Array): """ diff --git a/dace/dtypes.py b/dace/dtypes.py index cbbc4125c1..f0bac23958 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ A module that contains various DaCe type definitions. """ from __future__ import print_function import ctypes @@ -7,6 +7,7 @@ import itertools import numpy import re +from collections import OrderedDict from functools import wraps from typing import Any from dace.config import Config @@ -657,6 +658,8 @@ def from_json(json_obj, context=None): def as_ctypes(self): """ Returns the ctypes version of the typeclass. """ + if isinstance(self._typeclass, struct): + return ctypes.POINTER(self._typeclass.as_ctypes()) return ctypes.POINTER(_FFI_CTYPES[self.type]) def as_numpy_dtype(self): @@ -772,10 +775,8 @@ def to_json(self): return { 'type': 'struct', 'name': self.name, - 'data': {k: v.to_json() - for k, v in self._data.items()}, - 'length': {k: v - for k, v in self._length.items()}, + 'data': [(k, v.to_json()) for k, v in self._data.items()], + 'length': [(k, v) for k, v in self._length.items()], 'bytes': self.bytes } @@ -787,23 +788,28 @@ def from_json(json_obj, context=None): import dace.serialize # Avoid import loop ret = struct(json_obj['name']) - ret._data = {k: json_to_typeclass(v, context) for k, v in json_obj['data'].items()} - ret._length = {k: v for k, v in json_obj['length'].items()} + ret._data = {k: json_to_typeclass(v, context) for k, v in json_obj['data']} + ret._length = {k: v for k, v in json_obj['length']} ret.bytes = json_obj['bytes'] return ret def _parse_field_and_types(self, **fields_and_types): - self._data = dict() - self._length = dict() + # from dace.symbolic import pystr_to_symbolic + self._data = OrderedDict() + self._length = OrderedDict() self.bytes = 0 for k, v in fields_and_types.items(): if 
isinstance(v, tuple): t, l = v if not isinstance(t, pointer): raise TypeError("Only pointer types may have a length.") - if l not in fields_and_types.keys(): - raise ValueError("Length {} not a field of struct {}".format(l, self.name)) + # TODO: Do we need the free symbols of the length in the struct? + # NOTE: It is needed for the old use of dtype.struct. Are we deprecating that? + # sym_tokens = pystr_to_symbolic(l).free_symbols + # for sym in sym_tokens: + # if str(sym) not in fields_and_types.keys(): + # raise ValueError(f"Symbol {sym} in {k}'s length {l} is not a field of struct {self.name}") self._data[k] = t self._length[k] = l self.bytes += t.bytes @@ -815,16 +821,24 @@ def _parse_field_and_types(self, **fields_and_types): def as_ctypes(self): """ Returns the ctypes version of the typeclass. """ + if self in _FFI_CTYPES: + return _FFI_CTYPES[self] # Populate the ctype fields for the struct class. fields = [] for k, v in self._data.items(): if isinstance(v, pointer): - fields.append((k, ctypes.c_void_p)) # ctypes.POINTER(_FFI_CTYPES[v.type]))) + if isinstance(v._typeclass, struct): + fields.append((k, ctypes.POINTER(v._typeclass.as_ctypes()))) + else: + fields.append((k, ctypes.c_void_p)) + elif isinstance(v, struct): + fields.append((k, v.as_ctypes())) else: fields.append((k, _FFI_CTYPES[v.type])) - fields = sorted(fields, key=lambda f: f[0]) # Create new struct class. struct_class = type("NewStructClass", (ctypes.Structure, ), {"_fields_": fields}) + # NOTE: Each call to `type` returns a different class, so we need to cache it to ensure uniqueness. 
+ _FFI_CTYPES[self] = struct_class return struct_class def as_numpy_dtype(self): @@ -835,7 +849,7 @@ def emit_definition(self): {typ} }};""".format( name=self.name, - typ='\n'.join([" %s %s;" % (t.ctype, tname) for tname, t in sorted(self._data.items())]), + typ='\n'.join([" %s %s;" % (t.ctype, tname) for tname, t in self._data.items()]), ) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 6d1be7138a..d7112892fe 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -463,6 +463,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, if i.type == "ALL": shape.append(array.shape[indices]) mysize = mysize * array.shape[indices] + index_list.append(None) else: raise NotImplementedError("Index in ParDecl should be ALL") else: diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 009f45ca10..9643d51c1f 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -986,7 +986,7 @@ def _argminmax(pv: ProgramVisitor, reduced_shape = list(copy.deepcopy(a_arr.shape)) reduced_shape.pop(axis) - val_and_idx = dace.struct('_val_and_idx', val=a_arr.dtype, idx=result_type) + val_and_idx = dace.struct('_val_and_idx', idx=result_type, val=a_arr.dtype) # HACK: since identity cannot be specified for structs, we have to init the output array reduced_structs, reduced_struct_arr = sdfg.add_temp_transient(reduced_shape, val_and_idx) diff --git a/dace/properties.py b/dace/properties.py index 951a0564cc..61e569341f 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
import ast from collections import OrderedDict import copy @@ -145,11 +145,15 @@ def fs(obj, *args, **kwargs): self._from_json = lambda *args, **kwargs: dace.serialize.from_json(*args, known_type=dtype, **kwargs) else: self._from_json = from_json + if self.from_json != from_json: + self.from_json = from_json if to_json is None: self._to_json = dace.serialize.to_json else: self._to_json = to_json + if self.to_json != to_json: + self.to_json = to_json if meta_to_json is None: @@ -412,8 +416,7 @@ def initialize_properties(obj, *args, **kwargs): except AttributeError: if not prop.unmapped: raise PropertyError("Property {} is unassigned in __init__ for {}".format(name, cls.__name__)) - # Assert that there are no fields in the object not captured by - # properties, unless they are prefixed with "_" + # Assert that there are no fields in the object not captured by properties, unless they are prefixed with "_" for name, prop in obj.__dict__.items(): if (name not in properties and not name.startswith("_") and name not in dir(type(obj))): raise PropertyError("{} : Variable {} is neither a Property nor " @@ -1385,6 +1388,47 @@ def from_json(obj, context=None): raise TypeError("Cannot parse type from: {}".format(obj)) +class NestedDataClassProperty(Property): + """ Custom property type for nested data. 
""" + + def __get__(self, obj, objtype=None) -> 'Data': + return super().__get__(obj, objtype) + + @property + def dtype(self): + from dace import data as dt + return dt.Data + + @staticmethod + def from_string(s): + from dace import data as dt + dtype = getattr(dt, s, None) + if dtype is None or not isinstance(dtype, dt.Data): + raise ValueError("Not a valid data type: {}".format(s)) + return dtype + + @staticmethod + def to_string(obj): + return obj.to_string() + + def to_json(self, obj): + if obj is None: + return None + return obj.to_json() + + @staticmethod + def from_json(obj, context=None): + if obj is None: + return None + elif isinstance(obj, str): + return NestedDataClassProperty.from_string(obj) + elif isinstance(obj, dict): + # Let the deserializer handle this + return dace.serialize.from_json(obj) + else: + raise TypeError("Cannot parse type from: {}".format(obj)) + + class LibraryImplementationProperty(Property): """ Property for choosing an implementation type for a library node. On the diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 89ba6928c7..0fec4812b7 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -10,7 +10,7 @@ import itertools import functools import sympy -from sympy import ceiling +from sympy import ceiling, Symbol from sympy.concrete.summations import Sum import warnings import networkx as nx @@ -564,8 +564,7 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): Annotate each valid for loop construct with its loop variable ranges. :param sdfg: The SDFG in which to look. - :param unannotated_cycle_states: List of states in cycles without valid - for loop ranges. + :param unannotated_cycle_states: List of lists. Each sub-list contains the states of one unannotated cycle. """ # We import here to avoid cyclic imports. 
@@ -652,7 +651,7 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): res = find_for_loop(sdfg, guard, begin, itervar=itvar) if res is None: # No range detected, mark as unbounded. - unannotated_cycle_states.extend(cycle) + unannotated_cycle_states.append(cycle) else: itervar, rng, _ = res @@ -674,10 +673,10 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): else: # There's no guard state, so this cycle marks all states in it as # dynamically unbounded. - unannotated_cycle_states.extend(cycle) + unannotated_cycle_states.append(cycle) -def propagate_states(sdfg) -> None: +def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: """ Annotate the states of an SDFG with the number of executions. @@ -728,6 +727,9 @@ def propagate_states(sdfg) -> None: once. :param sdfg: The SDFG to annotate. + :param concretize_dynamic_unbounded: If True, we annotate dyncamic unbounded states with symbols of the + form "num_execs_{sdfg_id}_{loop_start_state_id}". Hence, for each + unbounded loop its states will have the same number of symbolic executions. :note: This operates on the SDFG in-place. """ @@ -759,6 +761,9 @@ def propagate_states(sdfg) -> None: # cycle should be marked as unannotated. unannotated_cycle_states = [] _annotate_loop_ranges(sdfg, unannotated_cycle_states) + if not concretize_dynamic_unbounded: + # Flatten the list. This keeps the old behavior of propagate_states. + unannotated_cycle_states = [state for cycle in unannotated_cycle_states for state in cycle] # Keep track of states that fully merge a previous conditional split. We do # this so we can remove the dynamic executions flag for those states. @@ -800,7 +805,7 @@ def propagate_states(sdfg) -> None: # The only exception to this rule: If the state is in an # unannotated loop, i.e. should be annotated as dynamic # unbounded instead, we do that. 
- if (state in unannotated_cycle_states): + if (not concretize_dynamic_unbounded) and state in unannotated_cycle_states: state.executions = 0 state.dynamic_executions = True else: @@ -872,17 +877,39 @@ def propagate_states(sdfg) -> None: else: # Conditional split or unannotated (dynamic unbounded) loop. unannotated_loop_edge = None - for oedge in out_edges: - if oedge.dst in unannotated_cycle_states: - # This is an unannotated loop down this branch. - unannotated_loop_edge = oedge + if concretize_dynamic_unbounded: + to_remove = [] + for oedge in out_edges: + for cycle in unannotated_cycle_states: + if oedge.dst in cycle: + # This is an unannotated loop down this branch. + unannotated_loop_edge = oedge + # remove cycle, since it is now annotated with symbol + to_remove.append(cycle) + + for c in to_remove: + unannotated_cycle_states.remove(c) + else: + for oedge in out_edges: + if oedge.dst in unannotated_cycle_states: + # This is an unannotated loop down this branch. + unannotated_loop_edge = oedge if unannotated_loop_edge is not None: # Traverse as an unbounded loop. out_edges.remove(unannotated_loop_edge) for oedge in out_edges: traversal_q.append((oedge.dst, state.executions, False, itvar_stack)) - traversal_q.append((unannotated_loop_edge.dst, 0, True, itvar_stack)) + if concretize_dynamic_unbounded: + # Here we introduce the num_exec symbol and propagate it down the loop. + # We can always assume these symbols to be non-negative. + traversal_q.append( + (unannotated_loop_edge.dst, + Symbol(f'num_execs_{sdfg.sdfg_id}_{sdfg.node_id(unannotated_loop_edge.dst)}', + nonnegative=True), False, itvar_stack)) + else: + # Propagate dynamic unbounded. + traversal_q.append((unannotated_loop_edge.dst, 0, True, itvar_stack)) else: # Traverse as a conditional split. 
proposed_executions = state.executions diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index bbdf7de041..aecaf91a75 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -48,6 +48,41 @@ from dace.codegen.compiled_sdfg import CompiledSDFG +class NestedDict(dict): + + def __init__(self, mapping=None): + mapping = mapping or {} + super(NestedDict, self).__init__(mapping) + + def __getitem__(self, key): + tokens = key.split('.') if isinstance(key, str) else [key] + token = tokens.pop(0) + result = super(NestedDict, self).__getitem__(token) + while tokens: + token = tokens.pop(0) + result = result.members[token] + return result + + def __setitem__(self, key, val): + if isinstance(key, str) and '.' in key: + raise KeyError('NestedDict does not support setting nested keys') + super(NestedDict, self).__setitem__(key, val) + + def __contains__(self, key): + tokens = key.split('.') if isinstance(key, str) else [key] + token = tokens.pop(0) + result = super(NestedDict, self).__contains__(token) + desc = None + while tokens and result: + if desc is None: + desc = super(NestedDict, self).__getitem__(token) + else: + desc = desc.members[token] + token = tokens.pop(0) + result = token in desc.members + return result + + def _arrays_to_json(arrays): if arrays is None: return None @@ -60,6 +95,12 @@ def _arrays_from_json(obj, context=None): return {k: dace.serialize.from_json(v, context) for k, v in obj.items()} +def _nested_arrays_from_json(obj, context=None): + if obj is None: + return NestedDict({}) + return NestedDict({k: dace.serialize.from_json(v, context) for k, v in obj.items()}) + + def _replace_dict_keys(d, old, new): if old in d: if new in d: @@ -379,10 +420,10 @@ class SDFG(OrderedDiGraph[SDFGState, InterstateEdge]): name = Property(dtype=str, desc="Name of the SDFG") arg_names = ListProperty(element_type=str, desc='Ordered argument names (used for calling conventions).') constants_prop = Property(dtype=dict, default={}, desc="Compile-time constants") - _arrays = 
Property(dtype=dict, + _arrays = Property(dtype=NestedDict, desc="Data descriptors for this SDFG", to_json=_arrays_to_json, - from_json=_arrays_from_json) + from_json=_nested_arrays_from_json) symbols = DictProperty(str, dtypes.typeclass, desc="Global symbols for this SDFG") instrument = EnumProperty(dtype=dtypes.InstrumentationType, @@ -460,7 +501,7 @@ def __init__(self, self._sdfg_list = [self] self._start_state: Optional[int] = None self._cached_start_state: Optional[SDFGState] = None - self._arrays = {} # type: Dict[str, dt.Array] + self._arrays = NestedDict() # type: Dict[str, dt.Array] self._labels: Set[str] = set() self.global_code = {'frame': CodeBlock("", dtypes.Language.CPP)} self.init_code = {'frame': CodeBlock("", dtypes.Language.CPP)} @@ -1987,10 +2028,17 @@ def add_datadesc(self, name: str, datadesc: dt.Data, find_new_name=False) -> str raise NameError(f'Array or Stream with name "{name}" already exists in SDFG') self._arrays[name] = datadesc + def _add_symbols(desc: dt.Data): + if isinstance(desc, dt.Structure): + for v in desc.members.values(): + if isinstance(v, dt.Data): + _add_symbols(v) + for sym in desc.free_symbols: + if sym.name not in self.symbols: + self.add_symbol(sym.name, sym.dtype) + # Add free symbols to the SDFG global symbol storage - for sym in datadesc.free_symbols: - if sym.name not in self.symbols: - self.add_symbol(sym.name, sym.dtype) + _add_symbols(datadesc) return name diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index d08518b10c..3396335ece 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1396,7 +1396,7 @@ def is_nonfree_sym_dependent(node: nd.AccessNode, desc: dt.Data, state: SDFGStat :param state: the state that contains the node :param fsymbols: the free symbols to check against """ - if isinstance(desc, dt.View): + if isinstance(desc, (dt.StructureView, dt.View)): # Views can be non-free symbol dependent due to the adjacent edges. 
e = get_view_edge(state, node) if e.data: diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index aa7674ca45..0bb3e9a64e 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -604,9 +604,14 @@ def validate_state(state: 'dace.sdfg.SDFGState', break # Check if memlet data matches src or dst nodes - if (e.data.data is not None and (isinstance(src_node, nd.AccessNode) or isinstance(dst_node, nd.AccessNode)) - and (not isinstance(src_node, nd.AccessNode) or e.data.data != src_node.data) - and (not isinstance(dst_node, nd.AccessNode) or e.data.data != dst_node.data)): + name = e.data.data + if isinstance(src_node, nd.AccessNode) and isinstance(sdfg.arrays[src_node.data], dt.Structure): + name = None + if isinstance(dst_node, nd.AccessNode) and isinstance(sdfg.arrays[dst_node.data], dt.Structure): + name = None + if (name is not None and (isinstance(src_node, nd.AccessNode) or isinstance(dst_node, nd.AccessNode)) + and (not isinstance(src_node, nd.AccessNode) or (name != src_node.data and name != e.src_conn)) + and (not isinstance(dst_node, nd.AccessNode) or (name != dst_node.data and name != e.dst_conn))): raise InvalidSDFGEdgeError( "Memlet data does not match source or destination " "data nodes)", diff --git a/dace/sdfg/work_depth_analysis/helpers.py b/dace/sdfg/work_depth_analysis/helpers.py new file mode 100644 index 0000000000..a80e769f64 --- /dev/null +++ b/dace/sdfg/work_depth_analysis/helpers.py @@ -0,0 +1,331 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Helper functions used by the work depth analysis. 
""" + +from dace import SDFG, SDFGState, nodes +from collections import deque +from typing import List, Dict, Set, Tuple, Optional, Union +import networkx as nx + +NodeT = str +EdgeT = Tuple[NodeT, NodeT] + + +class NodeCycle: + + nodes: Set[NodeT] = [] + + def __init__(self, nodes: List[NodeT]) -> None: + self.nodes = set(nodes) + + @property + def length(self) -> int: + return len(self.nodes) + + +UUID_SEPARATOR = '/' + + +def ids_to_string(sdfg_id, state_id=-1, node_id=-1, edge_id=-1): + return (str(sdfg_id) + UUID_SEPARATOR + str(state_id) + UUID_SEPARATOR + str(node_id) + UUID_SEPARATOR + + str(edge_id)) + + +def get_uuid(element, state=None): + if isinstance(element, SDFG): + return ids_to_string(element.sdfg_id) + elif isinstance(element, SDFGState): + return ids_to_string(element.parent.sdfg_id, element.parent.node_id(element)) + elif isinstance(element, nodes.Node): + return ids_to_string(state.parent.sdfg_id, state.parent.node_id(state), state.node_id(element)) + else: + return ids_to_string(-1) + + +def get_domtree(graph: nx.DiGraph, start_node: str, idom: Dict[str, str] = None): + idom = idom or nx.immediate_dominators(graph, start_node) + + alldominated = {n: set() for n in graph.nodes} + domtree = nx.DiGraph() + + for node, dom in idom.items(): + if node is dom: + continue + domtree.add_edge(dom, node) + alldominated[dom].add(node) + + nextidom = idom[dom] + ndom = nextidom if nextidom != dom else None + + while ndom: + alldominated[ndom].add(node) + nextidom = idom[ndom] + ndom = nextidom if nextidom != ndom else None + + # 'Rank' the tree, i.e., annotate each node with the level it is on. 
+ q = deque() + q.append((start_node, 0)) + while q: + node, level = q.popleft() + domtree.add_node(node, level=level) + for s in domtree.successors(node): + q.append((s, level + 1)) + + return alldominated, domtree + + +def get_backedges(graph: nx.DiGraph, + start: Optional[NodeT], + strict: bool = False) -> Union[Set[EdgeT], Tuple[Set[EdgeT], Set[EdgeT]]]: + '''Find all backedges in a directed graph. + + Note: + This algorithm has an algorithmic complexity of O((|V|+|E|)*C) for a + graph with vertices V, edges E, and C cycles. + + Args: + graph (nx.DiGraph): The graph for which to search backedges. + start (str): Start node of the graph. If no start is provided, a node + with no incoming edges is used as the start. If no such node can + be found, a `ValueError` is raised. + + Returns: + A set of backedges in the graph. + + Raises: + ValueError: If no `start` is provided and the graph contains no nodes + with no incoming edges. + ''' + backedges = set() + eclipsed_backedges = set() + + if start is None: + for node in graph.nodes(): + if graph.in_degree(node) == 0: + start = node + break + if start is None: + raise ValueError('No start node provided and no start node could ' + 'be determined automatically') + + # Gather all cycles in the graph. Cycles are represented as a sequence of + # nodes. + # O((|V|+|E|)*(C+1)), for C cycles. + all_cycles_nx: List[List[NodeT]] = nx.cycles.simple_cycles(graph) + #all_cycles_nx: List[List[NodeT]] = nx.simple_cycles(graph) + all_cycles: Set[NodeCycle] = set() + for cycle in all_cycles_nx: + all_cycles.add(NodeCycle(cycle)) + + # Construct a dictionary mapping a node to the cycles containing that node. + # O(|V|*|C|) + cycle_map: Dict[NodeT, Set[NodeCycle]] = dict() + for cycle in all_cycles: + for node in cycle.nodes: + try: + cycle_map[node].add(cycle) + except KeyError: + cycle_map[node] = set([cycle]) + + # Do a BFS traversal of the graph to detect the back edges. 
+ # For each node that is part of an (unhandled) cycle, find the longest + # still unhandled cycle and try to use it to find the back edge for it. + bfs_frontier = [start] + visited: Set[NodeT] = set([start]) + handled_cycles: Set[NodeCycle] = set() + unhandled_cycles = all_cycles + while bfs_frontier: + node = bfs_frontier.pop(0) + pred = [p for p in graph.predecessors(node) if p not in visited] + longest_cycles: Dict[NodeT, NodeCycle] = dict() + try: + cycles = cycle_map[node] + remove_cycles = set() + for cycle in cycles: + if cycle not in handled_cycles: + for p in pred: + if p in cycle.nodes: + if p not in longest_cycles: + longest_cycles[p] = cycle + else: + if cycle.length > longest_cycles[p].length: + longest_cycles[p] = cycle + else: + remove_cycles.add(cycle) + for cycle in remove_cycles: + cycles.remove(cycle) + except KeyError: + longest_cycles = dict() + + # For the current node, find the incoming edge which belongs to the + # cycle and has not been visited yet, which indicates a backedge. + node_backedge_candidates: Set[Tuple[EdgeT, NodeCycle]] = set() + for p, longest_cycle in longest_cycles.items(): + handled_cycles.add(longest_cycle) + unhandled_cycles.remove(longest_cycle) + cycle_map[node].remove(longest_cycle) + backedge_candidates = graph.in_edges(node) + for candidate in backedge_candidates: + src = candidate[0] + dst = candidate[0] + if src not in visited and src in longest_cycle.nodes: + node_backedge_candidates.add((candidate, longest_cycle)) + if not strict: + backedges.add(candidate) + + # Make sure that any cycle containing this back edge is + # not evaluated again, i.e., mark as handled. 
+ remove_cycles = set() + for cycle in unhandled_cycles: + if src in cycle.nodes and dst in cycle.nodes: + handled_cycles.add(cycle) + remove_cycles.add(cycle) + for cycle in remove_cycles: + unhandled_cycles.remove(cycle) + + # If strict is set, we only report the longest cycle's back edges for + # any given node, and separately return any other backedges as + # 'eclipsed' backedges. In the case of a while-loop, for example, + # the loop edge is considered a backedge, while a continue inside the + # loop is considered an 'eclipsed' backedge. + if strict: + longest_candidate: Tuple[EdgeT, NodeCycle] = None + eclipsed_candidates = set() + for be_candidate in node_backedge_candidates: + if longest_candidate is None: + longest_candidate = be_candidate + elif longest_candidate[1].length < be_candidate[1].length: + eclipsed_candidates.add(longest_candidate[0]) + longest_candidate = be_candidate + else: + eclipsed_candidates.add(be_candidate[0]) + if longest_candidate is not None: + backedges.add(longest_candidate[0]) + if eclipsed_candidates: + eclipsed_backedges.update(eclipsed_candidates) + + # Continue BFS. + for neighbour in graph.successors(node): + if neighbour not in visited: + visited.add(neighbour) + bfs_frontier.append(neighbour) + + if strict: + return backedges, eclipsed_backedges + else: + return backedges + + +def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): + """ + Detects loops in a SDFG. For each loop, it identifies (node, oNode, exit). + We know that there is a backedge from oNode to node that creates the loop and that exit is the exit state of the loop. + + :param sdfg_nx: The networkx representation of a SDFG. 
+ """ + + # preparation phase: compute dominators, backedges etc + for node in sdfg_nx.nodes(): + if sdfg_nx.in_degree(node) == 0: + start = node + break + if start is None: + raise ValueError('No start node could be determined') + + # sdfg can have multiple end nodes --> not good for postDomTree + # --> add a new end node + artificial_end_node = 'artificial_end_node' + sdfg_nx.add_node(artificial_end_node) + for node in sdfg_nx.nodes(): + if sdfg_nx.out_degree(node) == 0 and node != artificial_end_node: + # this is an end node of the sdfg + sdfg_nx.add_edge(node, artificial_end_node) + + # sanity check: + if sdfg_nx.in_degree(artificial_end_node) == 0: + raise ValueError('No end node could be determined in the SDFG') + + # compute dominators and backedges + iDoms = nx.immediate_dominators(sdfg_nx, start) + allDom, domTree = get_domtree(sdfg_nx, start, iDoms) + + reversed_sdfg_nx = sdfg_nx.reverse() + iPostDoms = nx.immediate_dominators(reversed_sdfg_nx, artificial_end_node) + allPostDoms, postDomTree = get_domtree(reversed_sdfg_nx, artificial_end_node, iPostDoms) + + backedges = get_backedges(sdfg_nx, start) + backedgesDstDict = {} + for be in backedges: + if be[1] in backedgesDstDict: + backedgesDstDict[be[1]].add(be) + else: + backedgesDstDict[be[1]] = set([be]) + + # This list will be filled with triples (node, oNode, exit), one triple for each loop construct in the SDFG. + # There will always be a backedge from oNode to node. Either node or oNode will be the corresponding loop guard, + # depending on whether it is a while-do or a do-while loop. exit will always be the exit state of the loop. + nodes_oNodes_exits = [] + + # iterate over all nodes + for node in sdfg_nx.nodes(): + # Check if any backedge ends in node. 
+ if node in backedgesDstDict: + inc_backedges = backedgesDstDict[node] + + # gather all successors of node that are not reached by backedges + successors = [] + for edge in sdfg_nx.out_edges(node): + if not edge in backedges: + successors.append(edge[1]) + + # For each incoming backedge, we want to find oNode and exit. There can be multiple backedges, in case + # we have a continue statement in the original code. But we can handle these backedges normally. + for be in inc_backedges: + # since node has an incoming backedge, it is either a loop guard or loop tail + # oNode will exactly be the other thing + oNode = be[0] + exitCandidates = set() + # search for exit candidates: + # a state is a exit candidate if: + # - it is in successor and it does not dominate oNode (else it dominates + # the last loop state, and hence is inside the loop itself) + # - is is a successor of oNode (but not node) + # This handles both cases of while-do and do-while loops + for succ in successors: + if succ != oNode and oNode not in allDom[succ]: + exitCandidates.add(succ) + for succ in sdfg_nx.successors(oNode): + if succ != node: + exitCandidates.add(succ) + + if len(exitCandidates) == 0: + raise ValueError('failed to find any exit nodes') + elif len(exitCandidates) > 1: + # Find the exit candidate that sits highest up in the + # postdominator tree (i.e., has the lowest level). + # That must be the exit node (it must post-dominate) + # everything inside the loop. If there are multiple + # candidates on the lowest level (i.e., disjoint set of + # postdominated nodes), there are multiple exit paths, + # and they all share one level. 
+ cand = exitCandidates.pop() + minSet = set([cand]) + minLevel = nx.get_node_attributes(postDomTree, 'level')[cand] + for cand in exitCandidates: + curr_level = nx.get_node_attributes(postDomTree, 'level')[cand] + if curr_level < minLevel: + # new minimum found + minLevel = curr_level + minSet.clear() + minSet.add(cand) + elif curr_level == minLevel: + # add cand to curr set + minSet.add(cand) + + if len(minSet) > 0: + exitCandidates = minSet + else: + raise ValueError('failed to find exit minSet') + + # now we have a triple (node, oNode, exitCandidates) + nodes_oNodes_exits.append((node, oNode, exitCandidates)) + + return nodes_oNodes_exits diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py new file mode 100644 index 0000000000..a05fe10266 --- /dev/null +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -0,0 +1,653 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Work depth analysis for any input SDFG. Can be used with the DaCe VS Code extension or +from command line as a Python script. """ + +import argparse +from collections import deque +from dace.sdfg import nodes as nd, propagation, InterstateEdge +from dace import SDFG, SDFGState, dtypes +from dace.subsets import Range +from typing import Tuple, Dict +import os +import sympy as sp +from copy import deepcopy +from dace.libraries.blas import MatMul +from dace.libraries.standard import Reduce, Transpose +from dace.symbolic import pystr_to_symbolic +import ast +import astunparse +import warnings + +from dace.sdfg.work_depth_analysis.helpers import get_uuid, find_loop_guards_tails_exits + + +def get_array_size_symbols(sdfg): + """ + Returns all symbols that appear isolated in shapes of the SDFG's arrays. + These symbols can then be assumed to be positive. + + :note: This only works if a symbol appears in isolation, i.e. array A[N]. + If we have A[N+1], we cannot assume N to be positive. 
+ :param sdfg: The SDFG in which it searches for symbols. + :return: A set containing symbols which we can assume to be positive. + """ + symbols = set() + for _, _, arr in sdfg.arrays_recursive(): + for s in arr.shape: + if isinstance(s, sp.Symbol): + symbols.add(s) + return symbols + + +def posify_certain_symbols(expr, syms_to_posify): + """ + Takes an expression and evaluates it while assuming that certain symbols are positive. + + :param expr: The expression to evaluate. + :param syms_to_posify: List of symbols we assume to be positive. + :note: This is adapted from the Sympy function posify. + """ + + expr = sp.sympify(expr) + + reps = {s: sp.Dummy(s.name, positive=True, **s.assumptions0) for s in syms_to_posify if s.is_positive is None} + expr = expr.subs(reps) + return expr.subs({r: s for s, r in reps.items()}) + + +def symeval(val, symbols): + """ + Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}. + + :param val: The expression we are updating. + :param symbols: Dictionary of key value pairs { old_symbol: new_symbol}. 
+ """ + first_replacement = {pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) for k in symbols.keys()} + second_replacement = {pystr_to_symbolic('__REPLSYM_' + k): v for k, v in symbols.items()} + return val.subs(first_replacement).subs(second_replacement) + + +def evaluate_symbols(base, new): + result = {} + for k, v in new.items(): + result[k] = symeval(v, base) + return result + + +def count_work_matmul(node, symbols, state): + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') + C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') + result = 2 # Multiply, add + # Batch + if len(C_memlet.data.subset) == 3: + result *= symeval(C_memlet.data.subset.size()[0], symbols) + # M*N + result *= symeval(C_memlet.data.subset.size()[-2], symbols) + result *= symeval(C_memlet.data.subset.size()[-1], symbols) + # K + result *= symeval(A_memlet.data.subset.size()[-1], symbols) + return result + + +def count_work_reduce(node, symbols, state): + result = 0 + if node.wcr is not None: + result += count_arithmetic_ops_code(node.wcr) + in_memlet = None + in_edges = state.in_edges(node) + if in_edges is not None and len(in_edges) == 1: + in_memlet = in_edges[0] + if in_memlet is not None and in_memlet.data.volume is not None: + result *= in_memlet.data.volume + else: + result = 0 + return result + + +LIBNODES_TO_WORK = { + MatMul: count_work_matmul, + Transpose: lambda *args: 0, + Reduce: count_work_reduce, +} + + +def count_depth_matmul(node, symbols, state): + # For now we set it equal to work: see comments in count_depth_reduce just below + return count_work_matmul(node, symbols, state) + + +def count_depth_reduce(node, symbols, state): + # depth of reduction is log2 of the work + # TODO: Can we actually assume this? Or is it equal to the work? + # Another thing to consider is that we essetially do NOT count wcr edges as operations for now... 
+ + # return sp.ceiling(sp.log(count_work_reduce(node, symbols, state), 2)) + # set it equal to work for now + return count_work_reduce(node, symbols, state) + + +LIBNODES_TO_DEPTH = { + MatMul: count_depth_matmul, + Transpose: lambda *args: 0, + Reduce: count_depth_reduce, +} + +bigo = sp.Function('bigo') +PYFUNC_TO_ARITHMETICS = { + 'float': 0, + 'dace.float64': 0, + 'dace.int64': 0, + 'math.exp': 1, + 'exp': 1, + 'math.tanh': 1, + 'sin': 1, + 'cos': 1, + 'tanh': 1, + 'math.sqrt': 1, + 'sqrt': 1, + 'atan2:': 1, + 'min': 0, + 'max': 0, + 'ceiling': 0, + 'floor': 0, + 'abs': 0 +} + + +class ArithmeticCounter(ast.NodeVisitor): + + def __init__(self): + self.count = 0 + + def visit_BinOp(self, node): + if isinstance(node.op, ast.MatMult): + raise NotImplementedError('MatMult op count requires shape ' + 'inference') + self.count += 1 + return self.generic_visit(node) + + def visit_UnaryOp(self, node): + self.count += 1 + return self.generic_visit(node) + + def visit_Call(self, node): + fname = astunparse.unparse(node.func)[:-1] + if fname not in PYFUNC_TO_ARITHMETICS: + print( + 'WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' + % fname) + return self.generic_visit(node) + self.count += PYFUNC_TO_ARITHMETICS[fname] + return self.generic_visit(node) + + def visit_AugAssign(self, node): + return self.visit_BinOp(node) + + def visit_For(self, node): + raise NotImplementedError + + def visit_While(self, node): + raise NotImplementedError + + +def count_arithmetic_ops_code(code): + ctr = ArithmeticCounter() + if isinstance(code, (tuple, list)): + for stmt in code: + ctr.visit(stmt) + elif isinstance(code, str): + ctr.visit(ast.parse(code)) + else: + ctr.visit(code) + return ctr.count + + +class DepthCounter(ast.NodeVisitor): + # so far this is identical to the ArithmeticCounter above. 
+ def __init__(self): + self.count = 0 + + def visit_BinOp(self, node): + if isinstance(node.op, ast.MatMult): + raise NotImplementedError('MatMult op count requires shape ' + 'inference') + self.count += 1 + return self.generic_visit(node) + + def visit_UnaryOp(self, node): + self.count += 1 + return self.generic_visit(node) + + def visit_Call(self, node): + fname = astunparse.unparse(node.func)[:-1] + if fname not in PYFUNC_TO_ARITHMETICS: + print( + 'WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' + % fname) + return self.generic_visit(node) + self.count += PYFUNC_TO_ARITHMETICS[fname] + return self.generic_visit(node) + + def visit_AugAssign(self, node): + return self.visit_BinOp(node) + + def visit_For(self, node): + raise NotImplementedError + + def visit_While(self, node): + raise NotImplementedError + + +def count_depth_code(code): + # so far this is the same as the work counter, since work = depth for each tasklet, as we can't assume any parallelism + ctr = ArithmeticCounter() + if isinstance(code, (tuple, list)): + for stmt in code: + ctr.visit(stmt) + elif isinstance(code, str): + ctr.visit(ast.parse(code)) + else: + ctr.visit(code) + return ctr.count + + +def tasklet_work(tasklet_node, state): + if tasklet_node.code.language == dtypes.Language.CPP: + for oedge in state.out_edges(tasklet_node): + return bigo(oedge.data.num_accesses) + + elif tasklet_node.code.language == dtypes.Language.Python: + return count_arithmetic_ops_code(tasklet_node.code.code) + else: + # other languages not implemented, count whole tasklet as work of 1 + warnings.warn('Work of tasklets only properly analyzed for Python or CPP. For all other ' + 'languages work = 1 will be counted for each tasklet.') + return 1 + + +def tasklet_depth(tasklet_node, state): + # TODO: how to get depth of CPP tasklets? 
+ # For now we use depth == work: + if tasklet_node.code.language == dtypes.Language.CPP: + for oedge in state.out_edges(tasklet_node): + return bigo(oedge.data.num_accesses) + if tasklet_node.code.language == dtypes.Language.Python: + return count_depth_code(tasklet_node.code.code) + else: + # other languages not implemented, count whole tasklet as work of 1 + warnings.warn('Depth of tasklets only properly analyzed for Python code. For all other ' + 'languages depth = 1 will be counted for each tasklet.') + return 1 + + +def get_tasklet_work(node, state): + return tasklet_work(node, state), -1 + + +def get_tasklet_work_depth(node, state): + return tasklet_work(node, state), tasklet_depth(node, state) + + +def get_tasklet_avg_par(node, state): + return tasklet_work(node, state), tasklet_depth(node, state) + + +def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, + symbols) -> Tuple[sp.Expr, sp.Expr]: + """ + Analyze the work and depth of a given SDFG. + First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. + Lastly, we compute the path with most work and the path with the most depth in order to get the total work depth. + + :param sdfg: The SDFG to analyze. + :param w_d_map: Dictionary which will save the result. + :param analyze_tasklet: Function used to analyze tasklet nodes. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :return: A tuple containing the work and depth of the SDFG. + """ + + # First determine the work and depth of each state individually. + # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number + # of times the state will be executed. 
+ state_depths: Dict[SDFGState, sp.Expr] = {} + state_works: Dict[SDFGState, sp.Expr] = {} + for state in sdfg.nodes(): + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols) + state_works[state] = sp.simplify(state_work * state.executions) + state_depths[state] = sp.simplify(state_depth * state.executions) + w_d_map[get_uuid(state)] = (state_works[state], state_depths[state]) + + # Prepare the SDFG for a depth analysis by breaking loops. This removes the edge between the last loop state and + # the guard, and instead places an edge between the last loop state and the exit state. + # This transforms the state machine into a DAG. Hence, we can find the "heaviest" and "deepest" paths in linear time. + # Additionally, construct a dummy exit state and connect every state that has no outgoing edges to it. + + # identify all loops in the SDFG + nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) + + # Now we need to go over each triple (node, oNode, exits). For each triple, we + # - remove edge (oNode, node), i.e. the backward edge + # - for all exits e, add edge (oNode, e). This edge may already exist + for node, oNode, exits in nodes_oNodes_exits: + sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) + for e in exits: + if len(sdfg.edges_between(oNode, e)) == 0: + # no edge there yet + sdfg.add_edge(oNode, e, InterstateEdge()) + + # add a dummy exit to the SDFG, such that each path ends there. + dummy_exit = sdfg.add_state('dummy_exit') + for state in sdfg.nodes(): + if len(sdfg.out_edges(state)) == 0 and state != dummy_exit: + sdfg.add_edge(state, dummy_exit, InterstateEdge()) + + # These two dicts save the current length of the "heaviest", resp. "deepest", paths at each state. + work_map: Dict[SDFGState, sp.Expr] = {} + depth_map: Dict[SDFGState, sp.Expr] = {} + # The dummy state has 0 work and depth. 
+ state_depths[dummy_exit] = sp.sympify(0) + state_works[dummy_exit] = sp.sympify(0) + + # Perform a BFS traversal of the state machine and calculate the maximum work / depth at each state. Only advance to + # the next state in the BFS if all incoming edges have been visited, to ensure the maximum work / depth expressions + # have been calculated. + traversal_q = deque() + traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None)) + visited = set() + while traversal_q: + state, depth, work, ie = traversal_q.popleft() + + if ie is not None: + visited.add(ie) + + n_depth = sp.simplify(depth + state_depths[state]) + n_work = sp.simplify(work + state_works[state]) + + # If we are analysing average parallelism, we don't search "heaviest" and "deepest" paths separately, but we want one + # single path with the least average parallelsim (of all paths with more than 0 work). + if analyze_tasklet == get_tasklet_avg_par: + if state in depth_map: # and hence als state in work_map + # if current path has 0 depth, we don't do anything. 
def scope_work_depth(state: SDFGState,
                     w_d_map: Dict[str, sp.Expr],
                     analyze_tasklet,
                     symbols,
                     entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]:
    """
    Analyze the work and depth of a scope.
    This works by traversing through the scope analyzing the work and depth of each encountered node.
    Depending on what kind of node we encounter, we do the following:
        - EntryNode: Recursively analyze work depth of scope.
        - Tasklet: use analyze_tasklet to get work depth of tasklet node.
        - NestedSDFG: After translating its local symbols to global symbols, we analyze the nested SDFG recursively.
        - LibraryNode: Library nodes are analyzed with special functions depending on their type.
    Work inside a state can simply be summed up, but for the depth we need to find the longest path. Since dataflow
    is a DAG, this can be done in linear time by traversing the graph in topological order.

    :param state: The state in which the scope to analyze is contained.
    :param w_d_map: Dictionary mapping element UUIDs to (work, depth) tuples. Results are saved in here.
    :param analyze_tasklet: Function used to analyze tasklet nodes. It is called as ``analyze_tasklet(node, state)``
                            and must return a (work, depth) pair.
    :param symbols: A dictionary mapping local nested SDFG symbols to global symbols.
    :param entry: The entry node of the scope to analyze. If None, the entire state is analyzed.
    :return: A tuple containing the work and depth of the scope.
    """

    # find the work and depth of each node
    # for maps and nested SDFG, we do it recursively
    work = sp.sympify(0)
    scope_nodes = state.scope_children()[entry]
    scope_exit = None if entry is None else state.exit_node(entry)
    for node in scope_nodes:
        # add node to map
        w_d_map[get_uuid(node, state)] = (sp.sympify(0), sp.sympify(0))
        if isinstance(node, nd.EntryNode):
            # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node
            # first. The resulting work/depth are summarized into the entry node
            s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, node)
            # add up work for whole state, but also save work for this sub-scope in w_d_map
            work += s_work
            w_d_map[get_uuid(node, state)] = (s_work, s_depth)
        elif node == scope_exit:
            # don't do anything for exit nodes, everything handled already in the corresponding entry node.
            pass
        elif isinstance(node, nd.Tasklet):
            # add up work for whole state, but also save work for this node in w_d_map
            t_work, t_depth = analyze_tasklet(node, state)
            work += t_work
            w_d_map[get_uuid(node, state)] = (sp.sympify(t_work), sp.sympify(t_depth))
        elif isinstance(node, nd.NestedSDFG):
            # keep track of nested symbols: "symbols" maps local nested SDFG symbols to global symbols.
            # We only want global symbols in our final work depth expressions.
            nested_syms = {}
            nested_syms.update(symbols)
            nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping))
            # Nested SDFGs are recursively analyzed first.
            nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms)

            # add up work for whole state, but also save work for this nested SDFG in w_d_map
            work += nsdfg_work
            w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth)
        elif isinstance(node, nd.LibraryNode):
            lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state)
            work += lib_node_work
            lib_node_depth = -1  # not analyzed
            if analyze_tasklet != get_tasklet_work:
                # we are analyzing depth
                lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state)
            w_d_map[get_uuid(node, state)] = (lib_node_work, lib_node_depth)

    if entry is not None:
        # If the scope being analyzed is a map, multiply the work by the number of iterations of the map.
        if isinstance(entry, nd.MapEntry):
            nmap: nd.Map = entry.map
            # Named `map_range` (not `range`) so the builtin is not shadowed for the rest of the function.
            map_range: Range = nmap.range
            n_exec = map_range.num_elements_exact()
            work = work * sp.simplify(n_exec)
        else:
            print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.')

    # Work inside a state can simply be summed up. But now we need to find the depth of a state (i.e. longest path).
    # Since dataflow graph is a DAG, this can be done in linear time.
    max_depth = sp.sympify(0)
    # only do this if we are analyzing depth
    if analyze_tasklet in (get_tasklet_work_depth, get_tasklet_avg_par):
        # Calculate the maximum depth of the scope by finding the 'deepest' path from the source to the sink. This is
        # done by a traversal in topological order, where each node propagates its current max depth for all
        # incoming paths.
        traversal_q = deque()
        visited = set()
        # find all starting nodes
        if entry is not None:
            # the entry is the starting node
            traversal_q.append((entry, sp.sympify(0), None))
        else:
            for node in scope_nodes:
                if len(state.in_edges(node)) == 0:
                    # This node is a start node of the traversal
                    traversal_q.append((node, sp.sympify(0), None))
        # this map keeps track of the length of the longest path ending at each node seen so far.
        depth_map = {}
        while traversal_q:
            node, in_depth, in_edge = traversal_q.popleft()

            if in_edge is not None:
                visited.add(in_edge)

            n_depth = sp.simplify(in_depth + w_d_map[get_uuid(node, state)][1])

            if node in depth_map:
                depth_map[node] = sp.Max(depth_map[node], n_depth)
            else:
                depth_map[node] = n_depth

            out_edges = state.out_edges(node)
            # Only advance to next node, if all incoming edges have been visited or the current node is the entry
            # (aka starting node). If the current node is the exit of the scope, we stop, such that we don't leave
            # the scope.
            if (all(iedge in visited for iedge in state.in_edges(node)) or node == entry) and node != scope_exit:
                # If we encounter a nested map, we must not analyze its contents (as they have already been
                # recursively analyzed). Hence, we continue from the outgoing edges of the corresponding exit.
                if isinstance(node, nd.EntryNode) and node != entry:
                    exit_node = state.exit_node(node)
                    # replace out_edges with the out_edges of the scope exit node
                    out_edges = state.out_edges(exit_node)
                for oedge in out_edges:
                    traversal_q.append((oedge.dst, depth_map[node], oedge))
            if len(out_edges) == 0 or node == scope_exit:
                # We have reached an end node --> update max_depth
                max_depth = sp.Max(max_depth, depth_map[node])

    # summarise work / depth of the whole scope in the dictionary
    # NOTE: the key is the state's UUID even when a sub-scope is analyzed; the recursive caller finishes later and
    # writes this key again, so the state-level result is the one that remains.
    scope_result = (sp.simplify(work), sp.simplify(max_depth))
    w_d_map[get_uuid(state)] = scope_result
    return scope_result


def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet,
                     symbols) -> Tuple[sp.Expr, sp.Expr]:
    """
    Analyze the work and depth of a state.

    :param state: The state to analyze.
    :param w_d_map: The result will be saved to this map.
    :param analyze_tasklet: Function used to analyze tasklet nodes.
    :param symbols: A dictionary mapping local nested SDFG symbols to global symbols.
    :return: A tuple containing the work and depth of the state.
    """
    # A state is simply the scope without an entry node.
    return scope_work_depth(state, w_d_map, analyze_tasklet, symbols, None)


def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> None:
    """
    Analyze a given SDFG. We can either analyze work, work and depth or average parallelism.

    :note: SDFGs should have split interstate edges. This means there should be no interstate edges containing both a
           condition and an assignment.
    :param sdfg: The SDFG to analyze. It is deep-copied first, so the caller's SDFG is not modified.
    :param w_d_map: Dictionary of SDFG elements to (work, depth) tuples. Result will be saved in here.
    :param analyze_tasklet: The function used to analyze tasklet nodes. Analyzes either just work, work and depth or
                            average parallelism.
    """

    # deepcopy such that original sdfg not changed
    sdfg = deepcopy(sdfg)

    # Run state propagation for all SDFGs recursively. This is necessary to determine the number of times each state
    # will be executed, or to determine upper bounds for that number (such as in the case of branching)
    for sd in sdfg.all_sdfgs_recursive():
        propagation.propagate_states(sd, concretize_dynamic_unbounded=True)

    # Analyze the work and depth of the SDFG.
    symbols = {}
    sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols)

    # Note: This posify could be done more often to improve performance.
    array_symbols = get_array_size_symbols(sdfg)
    # Only values of existing keys are replaced, so the dictionary can be updated while iterating over it.
    for k, (v_w, v_d) in w_d_map.items():
        # The symeval replaces nested SDFG symbols with their global counterparts.
        v_w = posify_certain_symbols(symeval(v_w, symbols), array_symbols)
        v_d = posify_certain_symbols(symeval(v_d, symbols), array_symbols)
        w_d_map[k] = (v_w, v_d)


################################################################################
# Utility functions for running the analysis from the command line #############
################################################################################


def main() -> None:
    """
    Command-line entry point: analyze the SDFG file given as argument and print the result for the whole SDFG.

    ``--analyze`` selects what is computed (``work``, ``workDepth`` or ``avgPar``). If the file does not exist, the
    error is reported through argparse (usage message on stderr, non-zero exit status).
    """

    parser = argparse.ArgumentParser('work_depth',
                                     usage='python work_depth.py [-h] filename --analyze {work,workDepth,avgPar}',
                                     description='Analyze the work/depth of an SDFG.')

    parser.add_argument('filename', type=str, help='The SDFG file to analyze.')
    parser.add_argument('--analyze',
                        choices=['work', 'workDepth', 'avgPar'],
                        default='workDepth',
                        help='Choose what to analyze. Default: workDepth')

    args = parser.parse_args()

    if not os.path.exists(args.filename):
        # parser.error() terminates the process with a failure status; the site-provided exit() builtin used
        # previously signalled success and is not guaranteed to be available.
        parser.error(f'{args.filename} does not exist.')

    # argparse restricts --analyze to exactly these keys.
    tasklet_analyzers = {
        'workDepth': get_tasklet_work_depth,
        'avgPar': get_tasklet_avg_par,
        'work': get_tasklet_work,
    }
    analyze_tasklet = tasklet_analyzers[args.analyze]

    sdfg = SDFG.from_file(args.filename)
    work_depth_map = {}
    analyze_sdfg(sdfg, work_depth_map, analyze_tasklet)

    # Convert every entry to its printable form (only values of existing keys are replaced).
    if args.analyze == 'workDepth':
        for k, v in work_depth_map.items():
            work_depth_map[k] = (str(sp.simplify(v[0])), str(sp.simplify(v[1])))
    elif args.analyze == 'work':
        for k, v in work_depth_map.items():
            work_depth_map[k] = str(sp.simplify(v[0]))
    elif args.analyze == 'avgPar':
        for k, v in work_depth_map.items():
            # work / depth = avg par; a depth of zero is reported as 0 instead of dividing by zero
            work_depth_map[k] = str(sp.simplify(v[0] / v[1]) if str(v[1]) != '0' else 0)

    result_whole_sdfg = work_depth_map[get_uuid(sdfg)]

    print(80 * '-')
    if args.analyze == 'workDepth':
        print("Work:\t", result_whole_sdfg[0])
        print("Depth:\t", result_whole_sdfg[1])
    elif args.analyze == 'work':
        print("Work:\t", result_whole_sdfg)
    elif args.analyze == 'avgPar':
        print("Average Parallelism:\t", result_whole_sdfg)
    print(80 * '-')


if __name__ == '__main__':
    main()
def test_fortran_frontend_memlet_in_map_test():
    """
    Tests that no assumption is made where the iteration variable is inside a memlet subset.

    The Fortran program calls ``inner_loops`` on the row slices ``INP(I, :)`` / ``OUT(I, :)`` from inside a
    ``DO I`` loop. After simplification, every access node of ``INP`` / ``OUT`` (in states with more than one node)
    must have exactly one edge, whose memlet subset uses the iteration variable in the first dimension and the
    range ``1:10`` in the second.
    """
    test_string = """
                    PROGRAM memlet_range_test
                    implicit None
                    REAL INP(100, 10)
                    REAL OUT(100, 10)
                    CALL memlet_range_test_routine(INP, OUT)
                    END PROGRAM

                    SUBROUTINE memlet_range_test_routine(INP, OUT)
                        REAL INP(100, 10)
                        REAL OUT(100, 10)
                        DO I=1,100
                            CALL inner_loops(INP(I, :), OUT(I, :))
                        ENDDO
                    END SUBROUTINE memlet_range_test_routine

                    SUBROUTINE inner_loops(INP, OUT)
                        REAL INP(10)
                        REAL OUT(10)
                        DO J=1,10
                            OUT(J) = INP(J) + 1
                        ENDDO
                    END SUBROUTINE inner_loops

                    """
    sdfg = fortran_parser.create_sdfg_from_string(test_string, "memlet_range_test")
    sdfg.simplify()
    # Expect that start is begin of for loop -> only one out edge to guard defining iterator variable
    assert len(sdfg.out_edges(sdfg.start_state)) == 1
    # The first assignment on that edge is taken to be the loop iterator.
    iter_var = symbolic.symbol(list(sdfg.out_edges(sdfg.start_state)[0].data.assignments.keys())[0])

    for state in sdfg.states():
        # States with at most one node are not inspected.
        if len(state.nodes()) > 1:
            for node in state.nodes():
                if isinstance(node, AccessNode) and node.data in ['INP', 'OUT']:
                    edges = [*state.in_edges(node), *state.out_edges(node)]
                    # There should be only one edge in/to the access node
                    assert len(edges) == 1
                    memlet = edges[0].data
                    # Check that the correct memlet has the iteration variable
                    assert memlet.subset[0] == (iter_var, iter_var, 1)
                    assert memlet.subset[1] == (1, 10, 1)
import ctypes
import dace
import numpy as np

from scipy import sparse


def test_read_struct_array():
    """
    Read from an array of CSR structures and write the matrices into a dense 3D array.

    The SDFG ``array_of_csr_to_dense`` takes ``A`` (an array of ``L`` ``CSRMatrix`` structures) and ``B`` (a dense
    ``[L, M, N]`` array). A sequential map over ``b`` views ``A[b]`` through the structure view ``vcsr``; its members
    are in turn viewed as ``vindptr`` / ``vindices`` / ``vdata`` and scattered into ``B`` by the ``indirection``
    tasklet. The result is compared against ``scipy``'s ``toarray()``.
    """

    L, M, N, nnz = (dace.symbol(s) for s in ('L', 'M', 'N', 'nnz'))
    csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
                                  name='CSRMatrix')
    csr_obj_view = dace.data.StructureView(
        [('indptr', dace.int32[M + 1]), ('indices', dace.int32[nnz]), ('data', dace.float32[nnz])],
        name='CSRMatrix',
        transient=True)

    sdfg = dace.SDFG('array_of_csr_to_dense')

    # `csr_obj[L]` is the descriptor registered for the array of structures.
    sdfg.add_datadesc('A', csr_obj[L])
    sdfg.add_array('B', [L, M, N], dace.float32)

    # One view for the current structure and one view per structure member.
    sdfg.add_datadesc('vcsr', csr_obj_view)
    sdfg.add_view('vindptr', csr_obj.members['indptr'].shape, csr_obj.members['indptr'].dtype)
    sdfg.add_view('vindices', csr_obj.members['indices'].shape, csr_obj.members['indices'].dtype)
    sdfg.add_view('vdata', csr_obj.members['data'].shape, csr_obj.members['data'].dtype)

    state = sdfg.add_state()

    A = state.add_access('A')
    B = state.add_access('B')

    # Outer map over the structures in A.
    bme, bmx = state.add_map('b', dict(b='0:L'))
    bme.map.schedule = dace.ScheduleType.Sequential

    vcsr = state.add_access('vcsr')
    indptr = state.add_access('vindptr')
    indices = state.add_access('vindices')
    data = state.add_access('vdata')

    # A[b] -> vcsr -> member views
    state.add_memlet_path(A, bme, vcsr, dst_conn='views', memlet=dace.Memlet(data='A', subset='b'))
    state.add_edge(vcsr, None, indptr, 'views', memlet=dace.Memlet.from_array('vcsr.indptr', csr_obj.members['indptr']))
    state.add_edge(vcsr, None, indices, 'views', memlet=dace.Memlet.from_array('vcsr.indices', csr_obj.members['indices']))
    state.add_edge(vcsr, None, data, 'views', memlet=dace.Memlet.from_array('vcsr.data', csr_obj.members['data']))

    # Map over rows `i`, and over the index range `start:stop` read from indptr[i] and indptr[i+1].
    ime, imx = state.add_map('i', dict(i='0:M'))
    jme, jmx = state.add_map('idx', dict(idx='start:stop'))
    jme.add_in_connector('start')
    jme.add_in_connector('stop')
    t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val')

    state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='vindptr', subset='i'), dst_conn='start')
    state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='vindptr', subset='i+1'), dst_conn='stop')
    state.add_memlet_path(indices, ime, jme, t, memlet=dace.Memlet(data='vindices', subset='idx'), dst_conn='j')
    state.add_memlet_path(data, ime, jme, t, memlet=dace.Memlet(data='vdata', subset='idx'), dst_conn='__val')
    state.add_memlet_path(t, jmx, imx, bmx, B, memlet=dace.Memlet(data='B', subset='b, 0:M, 0:N', volume=1), src_conn='__out')

    func = sdfg.compile()

    rng = np.random.default_rng(42)
    # From here on, A and B are the host-side inputs/outputs (the names of the access nodes above are reused).
    A = np.ndarray((10,), dtype=sparse.csr_matrix)
    dace_A = np.ndarray((10,), dtype=ctypes.c_void_p)
    B = np.zeros((10, 20, 20), dtype=np.float32)

    ctypes_A = []
    for b in range(10):
        A[b] = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng)
        # Build the C struct from the raw buffer addresses of the scipy matrix.
        ctypes_obj = csr_obj.dtype._typeclass.as_ctypes()(indptr=A[b].indptr.__array_interface__['data'][0],
                                                          indices=A[b].indices.__array_interface__['data'][0],
                                                          data=A[b].data.__array_interface__['data'][0])
        ctypes_A.append(ctypes_obj)  # This is needed to keep the object alive ...
        # The array passed to the SDFG holds the addresses of the structs.
        dace_A[b] = ctypes.addressof(ctypes_obj)

    # NOTE(review): nnz is taken from the first matrix only; presumably all generated matrices have the same
    # number of non-zeros - confirm.
    func(A=dace_A, B=B, L=A.shape[0], M=A[0].shape[0], N=A[0].shape[1], nnz=A[0].nnz)
    ref = np.ndarray((10, 20, 20), dtype=np.float32)
    for b in range(10):
        ref[b] = A[b].toarray()

    assert np.allclose(B, ref)


def test_write_struct_array():
    """
    Write a dense 3D array into an array of CSR structures.

    The SDFG ``array_dense_to_csr`` is a state machine with three nested loops (``k`` over the matrices, ``i`` over
    rows, ``j`` over columns) around an if on ``A[k, i, j] != 0``. Member views of the structure view ``vcsr`` are
    written and ``vcsr`` is in turn connected to ``B[k]``. The CSR buffers are pre-filled with -1 and, after the
    call, each ``B[b].toarray()`` is compared against ``A[b]``.
    """

    L, M, N, nnz = (dace.symbol(s) for s in ('L', 'M', 'N', 'nnz'))
    csr_obj = dace.data.Structure(
        [('indptr', dace.int32[M + 1]), ('indices', dace.int32[nnz]), ('data', dace.float32[nnz])],
        name='CSRMatrix')
    csr_obj_view = dace.data.StructureView(
        dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
        name='CSRMatrix',
        transient=True)

    sdfg = dace.SDFG('array_dense_to_csr')

    sdfg.add_array('A', [L, M, N], dace.float32)
    sdfg.add_datadesc('B', csr_obj[L])

    # One view for the current structure and one view per structure member.
    sdfg.add_datadesc('vcsr', csr_obj_view)
    sdfg.add_view('vindptr', csr_obj.members['indptr'].shape, csr_obj.members['indptr'].dtype)
    sdfg.add_view('vindices', csr_obj.members['indices'].shape, csr_obj.members['indices'].dtype)
    sdfg.add_view('vdata', csr_obj.members['data'].shape, csr_obj.members['data'].dtype)

    # Make If
    if_before = sdfg.add_state('if_before')
    if_guard = sdfg.add_state('if_guard')
    if_body = sdfg.add_state('if_body')
    if_after = sdfg.add_state('if_after')
    sdfg.add_edge(if_before, if_guard, dace.InterstateEdge())
    sdfg.add_edge(if_guard, if_body, dace.InterstateEdge(condition='A[k, i, j] != 0'))
    # `idx` counts the non-zeros written so far; it only advances when the if-body was taken.
    sdfg.add_edge(if_body, if_after, dace.InterstateEdge(assignments={'idx': 'idx + 1'}))
    sdfg.add_edge(if_guard, if_after, dace.InterstateEdge(condition='A[k, i, j] == 0'))
    # If-body: A[k, i, j] goes to vdata[idx] and the column index j to vindices[idx].
    A = if_body.add_access('A')
    vcsr = if_body.add_access('vcsr')
    B = if_body.add_access('B')
    indices = if_body.add_access('vindices')
    data = if_body.add_access('vdata')
    if_body.add_edge(A, None, data, None, dace.Memlet(data='A', subset='k, i, j', other_subset='idx'))
    if_body.add_edge(data, 'views', vcsr, None, dace.Memlet(data='vcsr.data', subset='0:nnz'))
    t = if_body.add_tasklet('set_indices', {}, {'__out'}, '__out = j')
    if_body.add_edge(t, '__out', indices, None, dace.Memlet(data='vindices', subset='idx'))
    if_body.add_edge(indices, 'views', vcsr, None, dace.Memlet(data='vcsr.indices', subset='0:nnz'))
    if_body.add_edge(vcsr, 'views', B, None, dace.Memlet(data='B', subset='k'))
    # Make For Loop for j
    j_before, j_guard, j_after = sdfg.add_loop(None,
                                               if_before,
                                               None,
                                               'j',
                                               '0',
                                               'j < N',
                                               'j + 1',
                                               loop_end_state=if_after)
    # Make For Loop for i
    i_before, i_guard, i_after = sdfg.add_loop(None, j_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=j_after)
    sdfg.start_state = sdfg.node_id(i_before)
    # `idx` is reset on the edge that enters the i-loop.
    i_before_guard = sdfg.edges_between(i_before, i_guard)[0]
    i_before_guard.data.assignments['idx'] = '0'
    # i-loop guard: vindptr[i] = idx
    vcsr = i_guard.add_access('vcsr')
    B = i_guard.add_access('B')
    indptr = i_guard.add_access('vindptr')
    t = i_guard.add_tasklet('set_indptr', {}, {'__out'}, '__out = idx')
    i_guard.add_edge(t, '__out', indptr, None, dace.Memlet(data='vindptr', subset='i'))
    i_guard.add_edge(indptr, 'views', vcsr, None, dace.Memlet(data='vcsr.indptr', subset='0:M+1'))
    i_guard.add_edge(vcsr, 'views', B, None, dace.Memlet(data='B', subset='k'))
    # After the i-loop: vindptr[M] = nnz
    vcsr = i_after.add_access('vcsr')
    B = i_after.add_access('B')
    indptr = i_after.add_access('vindptr')
    t = i_after.add_tasklet('set_indptr', {}, {'__out'}, '__out = nnz')
    i_after.add_edge(t, '__out', indptr, None, dace.Memlet(data='vindptr', subset='M'))
    i_after.add_edge(indptr, 'views', vcsr, None, dace.Memlet(data='vcsr.indptr', subset='0:M+1'))
    i_after.add_edge(vcsr, 'views', B, None, dace.Memlet(data='B', subset='k'))

    # Outermost loop over the matrices, wrapped around the i-loop.
    k_before, k_guard, k_after = sdfg.add_loop(None, i_before, None, 'k', '0', 'k < L', 'k + 1', loop_end_state=i_after)

    func = sdfg.compile()

    rng = np.random.default_rng(42)
    # From here on, A and B are the host-side inputs/outputs (the names of the access nodes above are reused).
    B = np.ndarray((10,), dtype=sparse.csr_matrix)
    dace_B = np.ndarray((10,), dtype=ctypes.c_void_p)
    A = np.empty((10, 20, 20), dtype=np.float32)

    ctypes_B = []
    for b in range(10):
        B[b] = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng)
        A[b] = B[b].toarray()
        # This rebinds the Python name `nnz` (previously the dace symbol) to an int.
        nnz = B[b].nnz
        # Overwrite the CSR buffers in place, so that the final comparison only passes if the SDFG filled them.
        B[b].indptr[:] = -1
        B[b].indices[:] = -1
        B[b].data[:] = -1
        ctypes_obj = csr_obj.dtype._typeclass.as_ctypes()(indptr=B[b].indptr.__array_interface__['data'][0],
                                                          indices=B[b].indices.__array_interface__['data'][0],
                                                          data=B[b].data.__array_interface__['data'][0])
        ctypes_B.append(ctypes_obj)  # This is needed to keep the object alive ...
        dace_B[b] = ctypes.addressof(ctypes_obj)

    # NOTE(review): nnz is the value left over from the last loop iteration; presumably all generated matrices
    # have the same number of non-zeros - confirm.
    func(A=A, B=dace_B, L=B.shape[0], M=B[0].shape[0], N=B[0].shape[1], nnz=nnz)
    for b in range(10):
        assert np.allclose(A[b], B[b].toarray())


if __name__ == '__main__':
    test_read_struct_array()
    test_write_struct_array()
+import dace +import numpy as np +import pytest + +from dace import serialize +from dace.properties import make_properties +from scipy import sparse + + +def test_read_structure(): + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + name='CSRMatrix') + + sdfg = dace.SDFG('csr_to_dense') + + sdfg.add_datadesc('A', csr_obj) + sdfg.add_array('B', [M, N], dace.float32) + + sdfg.add_view('vindptr', csr_obj.members['indptr'].shape, csr_obj.members['indptr'].dtype) + sdfg.add_view('vindices', csr_obj.members['indices'].shape, csr_obj.members['indices'].dtype) + sdfg.add_view('vdata', csr_obj.members['data'].shape, csr_obj.members['data'].dtype) + + state = sdfg.add_state() + + A = state.add_access('A') + B = state.add_access('B') + + indptr = state.add_access('vindptr') + indices = state.add_access('vindices') + data = state.add_access('vdata') + + state.add_edge(A, None, indptr, 'views', dace.Memlet.from_array('A.indptr', csr_obj.members['indptr'])) + state.add_edge(A, None, indices, 'views', dace.Memlet.from_array('A.indices', csr_obj.members['indices'])) + state.add_edge(A, None, data, 'views', dace.Memlet.from_array('A.data', csr_obj.members['data'])) + + ime, imx = state.add_map('i', dict(i='0:M')) + jme, jmx = state.add_map('idx', dict(idx='start:stop')) + jme.add_in_connector('start') + jme.add_in_connector('stop') + t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val') + + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='vindptr', subset='i'), dst_conn='start') + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='vindptr', subset='i+1'), dst_conn='stop') + state.add_memlet_path(indices, ime, jme, t, memlet=dace.Memlet(data='vindices', subset='idx'), dst_conn='j') + state.add_memlet_path(data, ime, jme, t, memlet=dace.Memlet(data='vdata', subset='idx'), dst_conn='__val') + 
state.add_memlet_path(t, jmx, imx, B, memlet=dace.Memlet(data='B', subset='0:M, 0:N', volume=1), src_conn='__out') + + func = sdfg.compile() + + rng = np.random.default_rng(42) + A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) + B = np.zeros((20, 20), dtype=np.float32) + + inpA = csr_obj.dtype._typeclass.as_ctypes()(indptr=A.indptr.__array_interface__['data'][0], + indices=A.indices.__array_interface__['data'][0], + data=A.data.__array_interface__['data'][0]) + + func(A=inpA, B=B, M=A.shape[0], N=A.shape[1], nnz=A.nnz) + ref = A.toarray() + + assert np.allclose(B, ref) + + +def test_write_structure(): + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + name='CSRMatrix') + + sdfg = dace.SDFG('dense_to_csr') + + sdfg.add_array('A', [M, N], dace.float32) + sdfg.add_datadesc('B', csr_obj) + + sdfg.add_view('vindptr', csr_obj.members['indptr'].shape, csr_obj.members['indptr'].dtype) + sdfg.add_view('vindices', csr_obj.members['indices'].shape, csr_obj.members['indices'].dtype) + sdfg.add_view('vdata', csr_obj.members['data'].shape, csr_obj.members['data'].dtype) + + # Make If + if_before = sdfg.add_state('if_before') + if_guard = sdfg.add_state('if_guard') + if_body = sdfg.add_state('if_body') + if_after = sdfg.add_state('if_after') + sdfg.add_edge(if_before, if_guard, dace.InterstateEdge()) + sdfg.add_edge(if_guard, if_body, dace.InterstateEdge(condition='A[i, j] != 0')) + sdfg.add_edge(if_body, if_after, dace.InterstateEdge(assignments={'idx': 'idx + 1'})) + sdfg.add_edge(if_guard, if_after, dace.InterstateEdge(condition='A[i, j] == 0')) + A = if_body.add_access('A') + B = if_body.add_access('B') + indices = if_body.add_access('vindices') + data = if_body.add_access('vdata') + if_body.add_edge(A, None, data, None, dace.Memlet(data='A', subset='i, j', other_subset='idx')) + if_body.add_edge(data, 'views', B, 
None, dace.Memlet(data='B.data', subset='0:nnz')) + t = if_body.add_tasklet('set_indices', {}, {'__out'}, '__out = j') + if_body.add_edge(t, '__out', indices, None, dace.Memlet(data='vindices', subset='idx')) + if_body.add_edge(indices, 'views', B, None, dace.Memlet(data='B.indices', subset='0:nnz')) + # Make For Loop for j + j_before, j_guard, j_after = sdfg.add_loop(None, + if_before, + None, + 'j', + '0', + 'j < N', + 'j + 1', + loop_end_state=if_after) + # Make For Loop for i + i_before, i_guard, i_after = sdfg.add_loop(None, j_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=j_after) + sdfg.start_state = sdfg.node_id(i_before) + i_before_guard = sdfg.edges_between(i_before, i_guard)[0] + i_before_guard.data.assignments['idx'] = '0' + B = i_guard.add_access('B') + indptr = i_guard.add_access('vindptr') + t = i_guard.add_tasklet('set_indptr', {}, {'__out'}, '__out = idx') + i_guard.add_edge(t, '__out', indptr, None, dace.Memlet(data='vindptr', subset='i')) + i_guard.add_edge(indptr, 'views', B, None, dace.Memlet(data='B.indptr', subset='0:M+1')) + B = i_after.add_access('B') + indptr = i_after.add_access('vindptr') + t = i_after.add_tasklet('set_indptr', {}, {'__out'}, '__out = nnz') + i_after.add_edge(t, '__out', indptr, None, dace.Memlet(data='vindptr', subset='M')) + i_after.add_edge(indptr, 'views', B, None, dace.Memlet(data='B.indptr', subset='0:M+1')) + + func = sdfg.compile() + + rng = np.random.default_rng(42) + tmp = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) + A = tmp.toarray() + B = tmp.tocsr(copy=True) + B.indptr[:] = -1 + B.indices[:] = -1 + B.data[:] = -1 + + outB = csr_obj.dtype._typeclass.as_ctypes()(indptr=B.indptr.__array_interface__['data'][0], + indices=B.indices.__array_interface__['data'][0], + data=B.data.__array_interface__['data'][0]) + + func(A=A, B=outB, M=tmp.shape[0], N=tmp.shape[1], nnz=tmp.nnz) + + assert np.allclose(A, B.toarray()) + + +def test_local_structure(): + + M, N, nnz = 
(dace.symbol(s) for s in ('M', 'N', 'nnz')) + csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + name='CSRMatrix') + tmp_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + name='CSRMatrix', + transient=True) + + sdfg = dace.SDFG('dense_to_csr_local') + + sdfg.add_array('A', [M, N], dace.float32) + sdfg.add_datadesc('B', csr_obj) + sdfg.add_datadesc('tmp', tmp_obj) + + sdfg.add_view('vindptr', csr_obj.members['indptr'].shape, csr_obj.members['indptr'].dtype) + sdfg.add_view('vindices', csr_obj.members['indices'].shape, csr_obj.members['indices'].dtype) + sdfg.add_view('vdata', csr_obj.members['data'].shape, csr_obj.members['data'].dtype) + + sdfg.add_view('tmp_vindptr', tmp_obj.members['indptr'].shape, tmp_obj.members['indptr'].dtype) + sdfg.add_view('tmp_vindices', tmp_obj.members['indices'].shape, tmp_obj.members['indices'].dtype) + sdfg.add_view('tmp_vdata', tmp_obj.members['data'].shape, tmp_obj.members['data'].dtype) + + # Make If + if_before = sdfg.add_state('if_before') + if_guard = sdfg.add_state('if_guard') + if_body = sdfg.add_state('if_body') + if_after = sdfg.add_state('if_after') + sdfg.add_edge(if_before, if_guard, dace.InterstateEdge()) + sdfg.add_edge(if_guard, if_body, dace.InterstateEdge(condition='A[i, j] != 0')) + sdfg.add_edge(if_body, if_after, dace.InterstateEdge(assignments={'idx': 'idx + 1'})) + sdfg.add_edge(if_guard, if_after, dace.InterstateEdge(condition='A[i, j] == 0')) + A = if_body.add_access('A') + tmp = if_body.add_access('tmp') + indices = if_body.add_access('tmp_vindices') + data = if_body.add_access('tmp_vdata') + if_body.add_edge(A, None, data, None, dace.Memlet(data='A', subset='i, j', other_subset='idx')) + if_body.add_edge(data, 'views', tmp, None, dace.Memlet(data='tmp.data', subset='0:nnz')) + t = if_body.add_tasklet('set_indices', {}, {'__out'}, '__out = j') + if_body.add_edge(t, '__out', indices, None, 
dace.Memlet(data='tmp_vindices', subset='idx')) + if_body.add_edge(indices, 'views', tmp, None, dace.Memlet(data='tmp.indices', subset='0:nnz')) + # Make For Loop for j + j_before, j_guard, j_after = sdfg.add_loop(None, + if_before, + None, + 'j', + '0', + 'j < N', + 'j + 1', + loop_end_state=if_after) + # Make For Loop for i + i_before, i_guard, i_after = sdfg.add_loop(None, j_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=j_after) + sdfg.start_state = sdfg.node_id(i_before) + i_before_guard = sdfg.edges_between(i_before, i_guard)[0] + i_before_guard.data.assignments['idx'] = '0' + tmp = i_guard.add_access('tmp') + indptr = i_guard.add_access('tmp_vindptr') + t = i_guard.add_tasklet('set_indptr', {}, {'__out'}, '__out = idx') + i_guard.add_edge(t, '__out', indptr, None, dace.Memlet(data='tmp_vindptr', subset='i')) + i_guard.add_edge(indptr, 'views', tmp, None, dace.Memlet(data='tmp.indptr', subset='0:M+1')) + tmp = i_after.add_access('tmp') + indptr = i_after.add_access('tmp_vindptr') + t = i_after.add_tasklet('set_indptr', {}, {'__out'}, '__out = nnz') + i_after.add_edge(t, '__out', indptr, None, dace.Memlet(data='tmp_vindptr', subset='M')) + i_after.add_edge(indptr, 'views', tmp, None, dace.Memlet(data='tmp.indptr', subset='0:M+1')) + + set_B = sdfg.add_state('set_B') + sdfg.add_edge(i_after, set_B, dace.InterstateEdge()) + tmp = set_B.add_access('tmp') + tmp_indptr = set_B.add_access('tmp_vindptr') + tmp_indices = set_B.add_access('tmp_vindices') + tmp_data = set_B.add_access('tmp_vdata') + set_B.add_edge(tmp, None, tmp_indptr, 'views', dace.Memlet(data='tmp.indptr', subset='0:M+1')) + set_B.add_edge(tmp, None, tmp_indices, 'views', dace.Memlet(data='tmp.indices', subset='0:nnz')) + set_B.add_edge(tmp, None, tmp_data, 'views', dace.Memlet(data='tmp.data', subset='0:nnz')) + B = set_B.add_access('B') + B_indptr = set_B.add_access('vindptr') + B_indices = set_B.add_access('vindices') + B_data = set_B.add_access('vdata') + set_B.add_edge(B_indptr, 
'views', B, None, dace.Memlet(data='B.indptr', subset='0:M+1')) + set_B.add_edge(B_indices, 'views', B, None, dace.Memlet(data='B.indices', subset='0:nnz')) + set_B.add_edge(B_data, 'views', B, None, dace.Memlet(data='B.data', subset='0:nnz')) + set_B.add_edge(tmp_indptr, None, B_indptr, None, dace.Memlet(data='tmp_vindptr', subset='0:M+1')) + set_B.add_edge(tmp_indices, None, B_indices, None, dace.Memlet(data='tmp_vindices', subset='0:nnz')) + t, me, mx = set_B.add_mapped_tasklet('set_data', {'idx': '0:nnz'}, + {'__inp': dace.Memlet(data='tmp_vdata', subset='idx')}, + '__out = 2 * __inp', {'__out': dace.Memlet(data='vdata', subset='idx')}, + external_edges=True, + input_nodes={'tmp_vdata': tmp_data}, + output_nodes={'vdata': B_data}) + + func = sdfg.compile() + + rng = np.random.default_rng(42) + tmp = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) + A = tmp.toarray() + B = tmp.tocsr(copy=True) + B.indptr[:] = -1 + B.indices[:] = -1 + B.data[:] = -1 + + outB = csr_obj.dtype._typeclass.as_ctypes()(indptr=B.indptr.__array_interface__['data'][0], + indices=B.indices.__array_interface__['data'][0], + data=B.data.__array_interface__['data'][0]) + + func(A=A, B=outB, M=tmp.shape[0], N=tmp.shape[1], nnz=tmp.nnz) + + assert np.allclose(A * 2, B.toarray()) + + +def test_read_nested_structure(): + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + name='CSRMatrix') + wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper') + + sdfg = dace.SDFG('nested_csr_to_dense') + + sdfg.add_datadesc('A', wrapper_obj) + sdfg.add_array('B', [M, N], dace.float32) + + spmat = wrapper_obj.members['csr'] + sdfg.add_view('vindptr', spmat.members['indptr'].shape, spmat.members['indptr'].dtype) + sdfg.add_view('vindices', spmat.members['indices'].shape, spmat.members['indices'].dtype) + sdfg.add_view('vdata', 
spmat.members['data'].shape, spmat.members['data'].dtype) + + state = sdfg.add_state() + + A = state.add_access('A') + B = state.add_access('B') + + indptr = state.add_access('vindptr') + indices = state.add_access('vindices') + data = state.add_access('vdata') + + state.add_edge(A, None, indptr, 'views', dace.Memlet.from_array('A.csr.indptr', spmat.members['indptr'])) + state.add_edge(A, None, indices, 'views', dace.Memlet.from_array('A.csr.indices', spmat.members['indices'])) + state.add_edge(A, None, data, 'views', dace.Memlet.from_array('A.csr.data', spmat.members['data'])) + + ime, imx = state.add_map('i', dict(i='0:M')) + jme, jmx = state.add_map('idx', dict(idx='start:stop')) + jme.add_in_connector('start') + jme.add_in_connector('stop') + t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val') + + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='vindptr', subset='i'), dst_conn='start') + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='vindptr', subset='i+1'), dst_conn='stop') + state.add_memlet_path(indices, ime, jme, t, memlet=dace.Memlet(data='vindices', subset='idx'), dst_conn='j') + state.add_memlet_path(data, ime, jme, t, memlet=dace.Memlet(data='vdata', subset='idx'), dst_conn='__val') + state.add_memlet_path(t, jmx, imx, B, memlet=dace.Memlet(data='B', subset='0:M, 0:N', volume=1), src_conn='__out') + + func = sdfg.compile() + + rng = np.random.default_rng(42) + A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) + B = np.zeros((20, 20), dtype=np.float32) + + structclass = csr_obj.dtype._typeclass.as_ctypes() + inpCSR = structclass(indptr=A.indptr.__array_interface__['data'][0], + indices=A.indices.__array_interface__['data'][0], + data=A.data.__array_interface__['data'][0]) + import ctypes + inpW = wrapper_obj.dtype._typeclass.as_ctypes()(csr=ctypes.pointer(inpCSR)) + + func(A=inpW, B=B, M=A.shape[0], N=A.shape[1], nnz=A.nnz) + ref = A.toarray() + + 
assert np.allclose(B, ref) + + +def test_write_nested_structure(): + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + name='CSRMatrix') + wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper') + + sdfg = dace.SDFG('dense_to_csr') + + sdfg.add_array('A', [M, N], dace.float32) + sdfg.add_datadesc('B', wrapper_obj) + + spmat = wrapper_obj.members['csr'] + sdfg.add_view('vindptr', spmat.members['indptr'].shape, spmat.members['indptr'].dtype) + sdfg.add_view('vindices', spmat.members['indices'].shape, spmat.members['indices'].dtype) + sdfg.add_view('vdata', spmat.members['data'].shape, spmat.members['data'].dtype) + + # Make If + if_before = sdfg.add_state('if_before') + if_guard = sdfg.add_state('if_guard') + if_body = sdfg.add_state('if_body') + if_after = sdfg.add_state('if_after') + sdfg.add_edge(if_before, if_guard, dace.InterstateEdge()) + sdfg.add_edge(if_guard, if_body, dace.InterstateEdge(condition='A[i, j] != 0')) + sdfg.add_edge(if_body, if_after, dace.InterstateEdge(assignments={'idx': 'idx + 1'})) + sdfg.add_edge(if_guard, if_after, dace.InterstateEdge(condition='A[i, j] == 0')) + A = if_body.add_access('A') + B = if_body.add_access('B') + indices = if_body.add_access('vindices') + data = if_body.add_access('vdata') + if_body.add_edge(A, None, data, None, dace.Memlet(data='A', subset='i, j', other_subset='idx')) + if_body.add_edge(data, 'views', B, None, dace.Memlet(data='B.csr.data', subset='0:nnz')) + t = if_body.add_tasklet('set_indices', {}, {'__out'}, '__out = j') + if_body.add_edge(t, '__out', indices, None, dace.Memlet(data='vindices', subset='idx')) + if_body.add_edge(indices, 'views', B, None, dace.Memlet(data='B.csr.indices', subset='0:nnz')) + # Make For Loop for j + j_before, j_guard, j_after = sdfg.add_loop(None, + if_before, + None, + 'j', + '0', + 'j < N', + 'j + 1', + loop_end_state=if_after) + # Make For 
Loop for i + i_before, i_guard, i_after = sdfg.add_loop(None, j_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=j_after) + sdfg.start_state = sdfg.node_id(i_before) + i_before_guard = sdfg.edges_between(i_before, i_guard)[0] + i_before_guard.data.assignments['idx'] = '0' + B = i_guard.add_access('B') + indptr = i_guard.add_access('vindptr') + t = i_guard.add_tasklet('set_indptr', {}, {'__out'}, '__out = idx') + i_guard.add_edge(t, '__out', indptr, None, dace.Memlet(data='vindptr', subset='i')) + i_guard.add_edge(indptr, 'views', B, None, dace.Memlet(data='B.csr.indptr', subset='0:M+1')) + B = i_after.add_access('B') + indptr = i_after.add_access('vindptr') + t = i_after.add_tasklet('set_indptr', {}, {'__out'}, '__out = nnz') + i_after.add_edge(t, '__out', indptr, None, dace.Memlet(data='vindptr', subset='M')) + i_after.add_edge(indptr, 'views', B, None, dace.Memlet(data='B.csr.indptr', subset='0:M+1')) + + func = sdfg.compile() + + rng = np.random.default_rng(42) + tmp = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) + A = tmp.toarray() + B = tmp.tocsr(copy=True) + B.indptr[:] = -1 + B.indices[:] = -1 + B.data[:] = -1 + + outCSR = csr_obj.dtype._typeclass.as_ctypes()(indptr=B.indptr.__array_interface__['data'][0], + indices=B.indices.__array_interface__['data'][0], + data=B.data.__array_interface__['data'][0]) + import ctypes + outW = wrapper_obj.dtype._typeclass.as_ctypes()(csr=ctypes.pointer(outCSR)) + + func(A=A, B=outW, M=tmp.shape[0], N=tmp.shape[1], nnz=tmp.nnz) + + assert np.allclose(A, B.toarray()) + + +def test_direct_read_structure(): + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + name='CSRMatrix') + + sdfg = dace.SDFG('csr_to_dense_direct') + + sdfg.add_datadesc('A', csr_obj) + sdfg.add_array('B', [M, N], dace.float32) + + state = sdfg.add_state() + + indptr = 
state.add_access('A.indptr') + indices = state.add_access('A.indices') + data = state.add_access('A.data') + B = state.add_access('B') + + ime, imx = state.add_map('i', dict(i='0:M')) + jme, jmx = state.add_map('idx', dict(idx='start:stop')) + jme.add_in_connector('start') + jme.add_in_connector('stop') + t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val') + + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='A.indptr', subset='i'), dst_conn='start') + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='A.indptr', subset='i+1'), dst_conn='stop') + state.add_memlet_path(indices, ime, jme, t, memlet=dace.Memlet(data='A.indices', subset='idx'), dst_conn='j') + state.add_memlet_path(data, ime, jme, t, memlet=dace.Memlet(data='A.data', subset='idx'), dst_conn='__val') + state.add_memlet_path(t, jmx, imx, B, memlet=dace.Memlet(data='B', subset='0:M, 0:N', volume=1), src_conn='__out') + + func = sdfg.compile() + + rng = np.random.default_rng(42) + A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) + B = np.zeros((20, 20), dtype=np.float32) + + inpA = csr_obj.dtype._typeclass.as_ctypes()(indptr=A.indptr.__array_interface__['data'][0], + indices=A.indices.__array_interface__['data'][0], + data=A.data.__array_interface__['data'][0], + rows=A.shape[0], + cols=A.shape[1], + M=A.shape[0], + N=A.shape[1], + nnz=A.nnz) + + func(A=inpA, B=B, M=20, N=20, nnz=A.nnz) + ref = A.toarray() + + assert np.allclose(B, ref) + + +def test_direct_read_nested_structure(): + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + name='CSRMatrix') + wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper') + + sdfg = dace.SDFG('nested_csr_to_dense_direct') + + sdfg.add_datadesc('A', wrapper_obj) + sdfg.add_array('B', [M, N], dace.float32) + + spmat = 
wrapper_obj.members['csr'] + sdfg.add_view('vindptr', spmat.members['indptr'].shape, spmat.members['indptr'].dtype) + sdfg.add_view('vindices', spmat.members['indices'].shape, spmat.members['indices'].dtype) + sdfg.add_view('vdata', spmat.members['data'].shape, spmat.members['data'].dtype) + + state = sdfg.add_state() + + indptr = state.add_access('A.csr.indptr') + indices = state.add_access('A.csr.indices') + data = state.add_access('A.csr.data') + B = state.add_access('B') + + ime, imx = state.add_map('i', dict(i='0:M')) + jme, jmx = state.add_map('idx', dict(idx='start:stop')) + jme.add_in_connector('start') + jme.add_in_connector('stop') + t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val') + + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='A.csr.indptr', subset='i'), dst_conn='start') + state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data='A.csr.indptr', subset='i+1'), dst_conn='stop') + state.add_memlet_path(indices, ime, jme, t, memlet=dace.Memlet(data='A.csr.indices', subset='idx'), dst_conn='j') + state.add_memlet_path(data, ime, jme, t, memlet=dace.Memlet(data='A.csr.data', subset='idx'), dst_conn='__val') + state.add_memlet_path(t, jmx, imx, B, memlet=dace.Memlet(data='B', subset='0:M, 0:N', volume=1), src_conn='__out') + + func = sdfg.compile() + + rng = np.random.default_rng(42) + A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) + B = np.zeros((20, 20), dtype=np.float32) + + structclass = csr_obj.dtype._typeclass.as_ctypes() + inpCSR = structclass(indptr=A.indptr.__array_interface__['data'][0], + indices=A.indices.__array_interface__['data'][0], + data=A.data.__array_interface__['data'][0]) + import ctypes + inpW = wrapper_obj.dtype._typeclass.as_ctypes()(csr=ctypes.pointer(inpCSR)) + + func(A=inpW, B=B, M=A.shape[0], N=A.shape[1], nnz=A.nnz) + ref = A.toarray() + + assert np.allclose(B, ref) + + +if __name__ == "__main__": + test_read_structure() + 
test_write_structure() + test_local_structure() + test_read_nested_structure() + test_write_nested_structure() + test_direct_read_structure() + test_direct_read_nested_structure() diff --git a/tests/sdfg/work_depth_tests.py b/tests/sdfg/work_depth_tests.py new file mode 100644 index 0000000000..133afe8ae4 --- /dev/null +++ b/tests/sdfg/work_depth_tests.py @@ -0,0 +1,201 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Contains test cases for the work depth analysis. """ +import dace as dc +from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth +from dace.sdfg.work_depth_analysis.helpers import get_uuid +import sympy as sp + +from dace.transformation.interstate import NestSDFG +from dace.transformation.dataflow import MapExpansion + +# TODO: add tests for library nodes (e.g. reduce, matMul) + +N = dc.symbol('N') +M = dc.symbol('M') +K = dc.symbol('K') + + +@dc.program +def single_map(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): + z[:] = x + y + + +@dc.program +def single_for_loop(x: dc.float64[N], y: dc.float64[N]): + for i in range(N): + x[i] += y[i] + + +@dc.program +def if_else(x: dc.int64[1000], y: dc.int64[1000], z: dc.int64[1000], sum: dc.int64[1]): + if x[10] > 50: + z[:] = x + y # 1000 work, 1 depth + else: + for i in range(100): # 100 work, 100 depth + sum += x[i] + + +@dc.program +def if_else_sym(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): + if x[10] > 50: + z[:] = x + y # N work, 1 depth + else: + for i in range(K): # K work, K depth + sum += x[i] + + +@dc.program +def nested_sdfg(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): + single_map(x, y, z) + single_for_loop(x, y) + + +@dc.program +def nested_maps(x: dc.float64[N, M], y: dc.float64[N, M], z: dc.float64[N, M]): + z[:, :] = x + y + + +@dc.program +def nested_for_loops(x: dc.float64[N], y: dc.float64[K]): + for i in range(N): + for j in range(K): + x[i] += y[j] + + +@dc.program +def 
nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): + if x[10] > 50: + if x[9] > 50: + z[:] = x + y # N work, 1 depth + z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth + else: + if y[9] > 50: + for i in range(K): + sum += x[i] # K work, K depth + else: + for j in range(M): + sum += x[j] # M work, M depth + z[:] = x + y # N work, depth 1 --> total inner else: M+N work, M+1 depth + # --> total outer else: Max(K, M+N) work, Max(K, M+1) depth + # --> total over both branches: Max(K, M+N, 3*N) work, Max(K, M+1, 3) depth + + +@dc.program +def max_of_positive_symbol(x: dc.float64[N]): + if x[0] > 0: + for i in range(2 * N): # work 2*N^2, depth 2*N + x += 1 + else: + for j in range(3 * N): # work 3*N^2, depth 3*N + x += 1 + # total is work 3*N^2, depth 3*N without any max + + +@dc.program +def multiple_array_sizes(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], x2: dc.int64[M], y2: dc.int64[M], + z2: dc.int64[M], x3: dc.int64[K], y3: dc.int64[K], z3: dc.int64[K]): + if x[0] > 0: + z[:] = 2 * x + y # work 2*N, depth 2 + elif x[1] > 0: + z2[:] = 2 * x2 + y2 # work 2*M + 3, depth 5 + z2[0] += 3 + z[1] + z[2] + elif x[2] > 0: + z3[:] = 2 * x3 + y3 # work 2*K, depth 2 + elif x[3] > 0: + z[:] = 3 * x + y + 1 # work 3*N, depth 3 + # --> work= Max(3*N, 2*M, 2*K) and depth = 5 + + +@dc.program +def unbounded_while_do(x: dc.float64[N]): + while x[0] < 100: + x += 1 + + +@dc.program +def unbounded_do_while(x: dc.float64[N]): + while True: + x += 1 + if x[0] >= 100: + break + + +@dc.program +def unbounded_nonnegify(x: dc.float64[N]): + while x[0] < 100: + if x[1] < 42: + x += 3 * x + else: + x += x + + +@dc.program +def continue_for_loop(x: dc.float64[N]): + for i in range(N): + if x[i] > 100: + continue + x += 1 + + +@dc.program +def break_for_loop(x: dc.float64[N]): + for i in range(N): + if x[i] > 100: + break + x += 1 + + +@dc.program +def break_while_loop(x: dc.float64[N]): + while x[0] > 10: + if x[1] > 100: + break + x += 1 
# Each entry pairs a program with the (work, depth) tuple the analysis is expected to report for it.
# Loops without a statically known trip count are expected to appear as nonnegative
# ``num_execs_*`` symbols.
# NOTE(review): the numeric suffixes must match the symbol names produced by the analysis; they
# presumably depend on SDFG/state numbering -- re-check them if the frontend output changes.
tests_cases = [
    (single_map, (N, 1)),
    (single_for_loop, (N, N)),
    (if_else, (1000, 100)),
    (if_else_sym, (sp.Max(K, N), sp.Max(1, K))),
    (nested_sdfg, (2 * N, N + 1)),
    (nested_maps, (M * N, 1)),
    (nested_for_loops, (K * N, K * N)),
    (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))),
    (max_of_positive_symbol, (3 * N**2, 3 * N)),
    (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)),
    (unbounded_while_do, (
        sp.Symbol('num_execs_0_2', nonnegative=True) * N,
        sp.Symbol('num_execs_0_2', nonnegative=True),
    )),
    # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this.
    (unbounded_do_while, (
        sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)) * N,
        sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)),
    )),
    (unbounded_nonnegify, (
        2 * sp.Symbol('num_execs_0_7', nonnegative=True) * N,
        2 * sp.Symbol('num_execs_0_7', nonnegative=True),
    )),
    (continue_for_loop, (
        sp.Symbol('num_execs_0_6', nonnegative=True) * N,
        sp.Symbol('num_execs_0_6', nonnegative=True),
    )),
    (break_for_loop, (N**2, N)),
    (break_while_loop, (
        sp.Symbol('num_execs_0_5', nonnegative=True) * N,
        sp.Symbol('num_execs_0_5', nonnegative=True),
    )),
]


def test_work_depth():
    """Check the work/depth analysis result of every program in ``tests_cases``.

    Each program is converted to an SDFG, optionally transformed (see below), and analyzed.
    The entry stored for the top-level SDFG must equal the expected ``(work, depth)`` tuple.

    :raises AssertionError: if the analysis result of any program differs from its expectation;
                            the message names the offending program.
    """
    for test, correct in tests_cases:
        # Filled in by ``analyze_sdfg``; a fresh map per program keeps the cases independent.
        w_d_map = {}
        sdfg = test.to_sdfg()

        # These two programs are additionally transformed before analysis, so that nested SDFGs
        # and expanded (nested) maps are exercised as well.
        if 'nested_sdfg' in test.name:
            sdfg.apply_transformations(NestSDFG)
        if 'nested_maps' in test.name:
            sdfg.apply_transformations(MapExpansion)

        analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth)
        res = w_d_map[get_uuid(sdfg)]

        # Name the failing program: with a bare assert inside the loop it is otherwise unclear
        # which of the cases produced the mismatch.
        assert correct == res, f'{test.name}: expected (work, depth) {correct}, got {res}'


if __name__ == '__main__':
    test_work_depth()