From 15fb33ca8867bd433b12249fe34790a9aaa58acb Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 13 Nov 2023 18:54:22 +0100 Subject: [PATCH] Clean up. --- dace/codegen/targets/cpp.py | 6 +- dace/codegen/targets/cpu.py | 3 + dace/codegen/targets/cuda.py | 10 +-- dace/codegen/targets/framecode.py | 3 +- dace/frontend/python/newast.py | 47 +++++++------- dace/frontend/python/replacements.py | 92 +++++++++++++++++----------- dace/sdfg/sdfg.py | 19 +++--- 7 files changed, 101 insertions(+), 79 deletions(-) diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 68df157269..b0d7c2779e 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -61,7 +61,8 @@ def copy_expr( packed_types=False, ): data_desc = sdfg.arrays[data_name] - # TODO: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs? + # NOTE: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs? + # TODO: Study this when changing Structures to be (optionally?) non-pointers. tokens = data_name.split('.') if len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure): name = data_name.replace('.', '->') @@ -585,7 +586,8 @@ def cpp_array_expr(sdfg, desc = (sdfg.arrays[memlet.data] if referenced_array is None else referenced_array) offset_cppstr = cpp_offset_expr(desc, s, o, packed_veclen, indices=indices) - # TODO: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs? + # NOTE: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs? + # TODO: Study this when changing Structures to be (optionally?) non-pointers. tokens = memlet.data.split('.') if len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure): name = memlet.data.replace('.', '->') diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 8feb7184ff..c2b79fb8e6 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -309,6 +309,8 @@ def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, d tokens = node.data.split('.') top_desc = sdfg.arrays[tokens[0]] + # NOTE: Assuming here that all Structure members share transient/storage/lifetime properties. + # TODO: Study what is needed in the DaCe stuck to ensure this assumption is correct. top_transient = top_desc.transient top_storage = top_desc.storage top_lifetime = top_desc.lifetime @@ -644,6 +646,7 @@ def _emit_copy( ############################################# # Corner cases + # NOTE: This looks obsolete but keeping it commented out in case tests fail. # Writing one index # if (isinstance(memlet.subset, subsets.Indices) and memlet.wcr is None # and self._dispatcher.defined_vars.get(vconn)[0] == DefinedType.Scalar): diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index fd2a7e0c67..ad4aae8522 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -202,12 +202,6 @@ def preprocess(self, sdfg: SDFG) -> None: and node.map.schedule in (dtypes.ScheduleType.GPU_Device, dtypes.ScheduleType.GPU_Persistent)): if state.parent not in shared_transients: shared_transients[state.parent] = state.parent.shared_transients() - # sgraph = state.scope_subgraph(node) - # used_symbols = sgraph.used_symbols(all_symbols=False) - # arglist = sgraph.arglist(defined_syms, shared_transients[state.parent]) - # arglist = {k: v for k, v in arglist.items() if not k in defined_syms or k in used_symbols} - # self._arglists[node] = arglist - # TODO/NOTE: Did we change defined_syms? self._arglists[node] = state.scope_subgraph(node).arglist(defined_syms, shared_transients[state.parent]) def _compute_pool_release(self, top_sdfg: SDFG): @@ -1029,11 +1023,11 @@ def _emit_copy(self, state_id, src_node, src_storage, dst_node, dst_storage, dst if issubclass(node_dtype.type, ctypes.Structure): callsite_stream.write('for (size_t __idx = 0; __idx < {arrlen}; ++__idx) ' '{{'.format(arrlen=array_length)) - # for field_name, field_type in node_dtype._data.items(): + # TODO: Study further when tackling Structures on GPU. for field_name, field_type in node_dtype._typeclass.fields.items(): if isinstance(field_type, dtypes.pointer): tclass = field_type.type - # length = node_dtype._length[field_name] + length = node_dtype._typeclass._length[field_name] size = 'sizeof({})*{}[__idx].{}'.format(dtypes._CTYPES[tclass], str(src_node), length) callsite_stream.write('DACE_GPU_CHECK({backend}Malloc(&{dst}[__idx].{fname}, ' diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 32e37eb24f..eb6bbd5750 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -565,7 +565,8 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): access_instances[sdfg.sdfg_id] = instances for sdfg, name, desc in top_sdfg.arrays_recursive(include_nested_data=True): - # NOTE/TODO: Temporary fix for nested data not having the same attributes as their parent + # NOTE: Assuming here that all Structure members share transient/storage/lifetime properties. + # TODO: Study what is needed in the DaCe stuck to ensure this assumption is correct. top_desc = sdfg.arrays[name.split('.')[0]] top_transient = top_desc.transient top_storage = top_desc.storage diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index c5fe0e6134..ce62535c50 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -49,7 +49,6 @@ Shape = Union[ShapeTuple, ShapeList] DependencyType = Dict[str, Tuple[SDFGState, Union[Memlet, nodes.Tasklet], Tuple[int]]] - if sys.version_info < (3, 8): _simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num) BytesConstant = ast.Bytes @@ -65,7 +64,6 @@ NumConstant = ast.Constant StrConstant = ast.Constant - if sys.version_info < (3, 9): Index = ast.Index ExtSlice = ast.ExtSlice @@ -73,7 +71,6 @@ Index = type(None) ExtSlice = type(None) - if sys.version_info < (3, 12): TypeAlias = type(None) else: @@ -452,10 +449,11 @@ def add_indirection_subgraph(sdfg: SDFG, for i, r in enumerate(memlet.subset): if i in nonsqz_dims: mapped_rng.append(r) - ind_entry, ind_exit = graph.add_map( - 'indirection', {'__i%d' % i: '%s:%s+1:%s' % (s, e, t) - for i, (s, e, t) in enumerate(mapped_rng)}, - debuginfo=pvisitor.current_lineinfo) + ind_entry, ind_exit = graph.add_map('indirection', { + '__i%d' % i: '%s:%s+1:%s' % (s, e, t) + for i, (s, e, t) in enumerate(mapped_rng) + }, + debuginfo=pvisitor.current_lineinfo) inp_base_path.insert(0, ind_entry) out_base_path.append(ind_exit) @@ -1339,9 +1337,10 @@ def defined(self): result.update(self.sdfg.arrays) # MPI-related stuff - result.update( - {k: self.sdfg.process_grids[v] - for k, v in self.variables.items() if v in self.sdfg.process_grids}) + result.update({ + k: self.sdfg.process_grids[v] + for k, v in self.variables.items() if v in self.sdfg.process_grids + }) try: from mpi4py import MPI result.update({k: v for k, v in self.globals.items() if isinstance(v, MPI.Comm)}) @@ -3218,8 +3217,9 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): if (not is_return and isinstance(target, ast.Name) and true_name and not op and not isinstance(true_array, data.Scalar) and not (true_array.shape == (1, ))): if true_name in self.views: - if result in self.sdfg.arrays and self.views[true_name] == ( - result, Memlet.from_array(result, self.sdfg.arrays[result])): + if result in self.sdfg.arrays and self.views[true_name] == (result, + Memlet.from_array( + result, self.sdfg.arrays[result])): continue else: raise DaceSyntaxError(self, target, 'Cannot reassign View "{}"'.format(name)) @@ -3762,14 +3762,12 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no from dace.frontend.python.parser import infer_symbols_from_datadescriptor # Map internal SDFG symbols by adding keyword arguments - # symbols = set(sdfg.symbols.keys()) - # symbols = sdfg.free_symbols symbols = sdfg.used_symbols(all_symbols=False) try: - mapping = infer_symbols_from_datadescriptor( - sdfg, {k: self.sdfg.arrays[v] - for k, v in args if v in self.sdfg.arrays}, - set(sym.arg for sym in node.keywords if sym.arg in symbols)) + mapping = infer_symbols_from_datadescriptor(sdfg, { + k: self.sdfg.arrays[v] + for k, v in args if v in self.sdfg.arrays + }, set(sym.arg for sym in node.keywords if sym.arg in symbols)) except ValueError as ex: raise DaceSyntaxError(self, node, str(ex)) if len(mapping) == 0: # Default to same-symbol mapping @@ -4733,7 +4731,7 @@ def visit_Dict(self, node: ast.Dict): def visit_Lambda(self, node: ast.Lambda): # Return a string representation of the function return astutils.unparse(node) - + def visit_TypeAlias(self, node: TypeAlias): raise NotImplementedError('Type aliases are not supported in DaCe') @@ -4922,11 +4920,12 @@ def _add_read_slice(self, array: str, node: ast.Subscript, expr: MemletExpr): # NOTE: We convert the subsets to string because keeping the original symbolic information causes # equality check failures, e.g., in LoopToMap. self.last_state.add_nedge( - rnode, wnode, Memlet(data=array, - subset=str(expr.subset), - other_subset=str(other_subset), - volume=expr.accesses, - wcr=expr.wcr)) + rnode, wnode, + Memlet(data=array, + subset=str(expr.subset), + other_subset=str(other_subset), + volume=expr.accesses, + wcr=expr.wcr)) return tmp def _parse_subscript_slice(self, diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 92d76b21a2..4775c572b5 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -93,9 +93,9 @@ def _define_local_structure(pv: ProgramVisitor, """ Defines a local structure in a DaCe program. """ name = sdfg.temp_data_name() desc = copy.deepcopy(dtype) - desc.transient=True - desc.storage=storage - desc.lifetime=lifetime + desc.transient = True + desc.storage = storage + desc.lifetime = lifetime sdfg.add_datadesc(name, desc) pv.variables[name] = name return name @@ -318,16 +318,20 @@ def _numpy_full(pv: ProgramVisitor, if is_data: state.add_mapped_tasklet( - '_numpy_full_', {"__i{}".format(i): "0: {}".format(s) - for i, s in enumerate(shape)}, + '_numpy_full_', { + "__i{}".format(i): "0: {}".format(s) + for i, s in enumerate(shape) + }, dict(__inp=dace.Memlet(data=fill_value, subset='0')), "__out = __inp", dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))), external_edges=True) else: state.add_mapped_tasklet( - '_numpy_full_', {"__i{}".format(i): "0: {}".format(s) - for i, s in enumerate(shape)}, {}, + '_numpy_full_', { + "__i{}".format(i): "0: {}".format(s) + for i, s in enumerate(shape) + }, {}, "__out = {}".format(fill_value), dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))), external_edges=True) @@ -447,8 +451,10 @@ def _numpy_flip(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, axis inpidx = ','.join([f'__i{i}' for i in range(ndim)]) outidx = ','.join([f'{s} - __i{i} - 1' if a else f'__i{i}' for i, (a, s) in enumerate(zip(axis, desc.shape))]) state.add_mapped_tasklet(name="_numpy_flip_", - map_ranges={f'__i{i}': f'0:{s}:1' - for i, s in enumerate(desc.shape)}, + map_ranges={ + f'__i{i}': f'0:{s}:1' + for i, s in enumerate(desc.shape) + }, inputs={'__inp': Memlet(f'{arr}[{inpidx}]')}, code='__out = __inp', outputs={'__out': Memlet(f'{arr_copy}[{outidx}]')}, @@ -518,8 +524,10 @@ def _numpy_rot90(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, k=1 outidx = ','.join(out_indices) state.add_mapped_tasklet(name="_rot90_", - map_ranges={f'__i{i}': f'0:{s}:1' - for i, s in enumerate(desc.shape)}, + map_ranges={ + f'__i{i}': f'0:{s}:1' + for i, s in enumerate(desc.shape) + }, inputs={'__inp': Memlet(f'{arr}[{inpidx}]')}, code='__out = __inp', outputs={'__out': Memlet(f'{arr_copy}[{outidx}]')}, @@ -623,8 +631,10 @@ def _elementwise(pv: 'ProgramVisitor', else: state.add_mapped_tasklet( name="_elementwise_", - map_ranges={'__i%d' % i: '0:%s' % n - for i, n in enumerate(inparr.shape)}, + map_ranges={ + '__i%d' % i: '0:%s' % n + for i, n in enumerate(inparr.shape) + }, inputs={'__inp': Memlet.simple(in_array, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))}, code=code, outputs={'__out': Memlet.simple(out_array, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))}, @@ -674,8 +684,10 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: else: state.add_mapped_tasklet( name=func, - map_ranges={'__i%d' % i: '0:%s' % n - for i, n in enumerate(inparr.shape)}, + map_ranges={ + '__i%d' % i: '0:%s' % n + for i, n in enumerate(inparr.shape) + }, inputs={'__inp': Memlet.simple(inpname, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))}, code='__out = {f}(__inp)'.format(f=func), outputs={'__out': Memlet.simple(outname, ','.join(['__i%d' % i for i in range(len(inparr.shape))]))}, @@ -1024,22 +1036,27 @@ def _argminmax(pv: ProgramVisitor, code = "__init = _val_and_idx(val={}, idx=-1)".format( dtypes.min_value(a_arr.dtype) if func == 'max' else dtypes.max_value(a_arr.dtype)) - nest.add_state().add_mapped_tasklet( - name="_arg{}_convert_".format(func), - map_ranges={'__i%d' % i: '0:%s' % n - for i, n in enumerate(a_arr.shape) if i != axis}, - inputs={}, - code=code, - outputs={ - '__init': Memlet.simple(reduced_structs, - ','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis)) - }, - external_edges=True) + nest.add_state().add_mapped_tasklet(name="_arg{}_convert_".format(func), + map_ranges={ + '__i%d' % i: '0:%s' % n + for i, n in enumerate(a_arr.shape) if i != axis + }, + inputs={}, + code=code, + outputs={ + '__init': + Memlet.simple( + reduced_structs, + ','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis)) + }, + external_edges=True) nest.add_state().add_mapped_tasklet( name="_arg{}_reduce_".format(func), - map_ranges={'__i%d' % i: '0:%s' % n - for i, n in enumerate(a_arr.shape)}, + map_ranges={ + '__i%d' % i: '0:%s' % n + for i, n in enumerate(a_arr.shape) + }, inputs={'__in': Memlet.simple(a, ','.join('__i%d' % i for i in range(len(a_arr.shape))))}, code="__out = _val_and_idx(idx={}, val=__in)".format("__i%d" % axis), outputs={ @@ -1059,8 +1076,10 @@ def _argminmax(pv: ProgramVisitor, nest.add_state().add_mapped_tasklet( name="_arg{}_extract_".format(func), - map_ranges={'__i%d' % i: '0:%s' % n - for i, n in enumerate(a_arr.shape) if i != axis}, + map_ranges={ + '__i%d' % i: '0:%s' % n + for i, n in enumerate(a_arr.shape) if i != axis + }, inputs={ '__in': Memlet.simple(reduced_structs, ','.join('__i%d' % i for i in range(len(a_arr.shape)) if i != axis)) @@ -1183,9 +1202,10 @@ def _unop(sdfg: SDFG, state: SDFGState, op1: str, opcode: str, opname: str): opcode = 'not' name, _ = sdfg.add_temp_transient(arr1.shape, restype, arr1.storage) - state.add_mapped_tasklet("_%s_" % opname, {'__i%d' % i: '0:%s' % s - for i, s in enumerate(arr1.shape)}, - {'__in1': Memlet.simple(op1, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))}, + state.add_mapped_tasklet("_%s_" % opname, { + '__i%d' % i: '0:%s' % s + for i, s in enumerate(arr1.shape) + }, {'__in1': Memlet.simple(op1, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))}, '__out = %s __in1' % opcode, {'__out': Memlet.simple(name, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))}, external_edges=True) @@ -4709,8 +4729,10 @@ def _cupy_full(pv: ProgramVisitor, name, _ = sdfg.add_temp_transient(shape, dtype, storage=dtypes.StorageType.GPU_Global) state.add_mapped_tasklet( - '_cupy_full_', {"__i{}".format(i): "0: {}".format(s) - for i, s in enumerate(shape)}, {}, + '_cupy_full_', { + "__i{}".format(i): "0: {}".format(s) + for i, s in enumerate(shape) + }, {}, "__out = {}".format(fill_value), dict(__out=dace.Memlet.simple(name, ",".join(["__i{}".format(i) for i in range(len(shape))]))), external_edges=True) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index eb37fa3d7a..8af5f2bcb0 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -81,7 +81,7 @@ def __contains__(self, key): token = tokens.pop(0) result = hasattr(desc, 'members') and token in desc.members return result - + def keys(self): result = super(NestedDict, self).keys() for k, v in self.items(): @@ -1279,10 +1279,10 @@ def _yield_nested_data(name, arr): def _used_symbols_internal(self, all_symbols: bool, - defined_syms: Optional[Set]=None, - free_syms: Optional[Set]=None, - used_before_assignment: Optional[Set]=None, - keep_defined_in_mapping: bool=False) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms: Optional[Set] = None, + free_syms: Optional[Set] = None, + used_before_assignment: Optional[Set] = None, + keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment @@ -1299,10 +1299,11 @@ def _used_symbols_internal(self, for code in self.exit_code.values(): free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) - return super()._used_symbols_internal( - all_symbols=all_symbols, keep_defined_in_mapping=keep_defined_in_mapping, - defined_syms=defined_syms, free_syms=free_syms, used_before_assignment=used_before_assignment - ) + return super()._used_symbols_internal(all_symbols=all_symbols, + keep_defined_in_mapping=keep_defined_in_mapping, + defined_syms=defined_syms, + free_syms=free_syms, + used_before_assignment=used_before_assignment) def get_all_toplevel_symbols(self) -> Set[str]: """