diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 28981d3bf1..b5e761e28e 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -10,14 +10,16 @@ import numpy as np from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction, - FindSymbols, MapExprStmts, Transformer, make_callable) + FindNodes, FindSymbols, MapExprStmts, Transformer, + make_callable) from devito.passes import is_gpu_create from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, SizeOf, VOID, Keyword, pow_to_mul) from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten -from devito.types import Array, CustomDimension, DeviceMap, DeviceRM, Eq, Symbol +from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap, + DeviceRM, Eq, Symbol) __all__ = ['DataManager', 'DeviceAwareDataManager', 'Storage'] @@ -214,6 +216,38 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args): storage.update(obj, site, allocs=alloc, frees=free, efuncs=(efunc0, efunc1)) + def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage): + """ + Allocate a Bundle struct in the host high bandwidth memory. + """ + decl = Definition(obj) + + memptr = VOID(Byref(obj._C_symbol), '**') + alignment = obj._data_alignment + nbytes = SizeOf(obj._C_typedata) + alloc = self.lang['host-alloc'](memptr, alignment, nbytes) + + nbytes_param = Symbol(name='nbytes', dtype=np.uint64, is_const=True) + nbytes_arg = SizeOf(obj.indexed._C_typedata)*obj.size + + ffp1 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol) + init0 = DummyExpr(ffp1, nbytes_param) + ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol) + init1 = DummyExpr(ffp2, 0) + + free = self.lang['host-free'](obj._C_symbol) + + ret = Return(obj._C_symbol) + + name = self.sregistry.make_name(prefix='alloc') + body = (decl, alloc, init0, init1, ret) + efunc0 = make_callable(name, body, retval=obj) + args = list(efunc0.parameters) + args[args.index(nbytes_param)] = nbytes_arg + alloc = Call(name, args, retobj=obj) + + storage.update(obj, site, allocs=alloc, frees=free, efuncs=efunc0) + def _alloc_object_array_on_low_lat_mem(self, site, obj, storage): """ Allocate an Array of Objects in the low latency memory. @@ -340,9 +374,22 @@ def place_definitions(self, iet, globs=None, **kwargs): for i in FindSymbols().visit(iet): if i in defines: continue + elif i.is_LocalObject: self._alloc_object_on_low_lat_mem(iet, i, storage) - elif i.is_Array or i.is_Bundle: + + elif i.is_Bundle: + if i._mem_heap: + if i.is_transient: + self._alloc_bundle_struct_on_high_bw_mem(iet, i, storage) + elif i._mem_local: + self._alloc_local_array_on_high_bw_mem(iet, i, storage) + elif i._mem_mapped: + self._alloc_mapped_array_on_high_bw_mem(iet, i, storage) + elif i._mem_stack: + self._alloc_array_on_low_lat_mem(iet, i, storage) + + elif i.is_Array: if i._mem_heap: if i._mem_host: self._alloc_host_array_on_high_bw_mem(iet, i, storage) @@ -355,8 +402,10 @@ def place_definitions(self, iet, globs=None, **kwargs): elif globs is not None: # Track, to be handled by the EntryFunction being a global obj! globs.add(i) + elif i.is_ObjectArray: self._alloc_object_array_on_low_lat_mem(iet, i, storage) + elif i.is_PointerArray: self._alloc_pointed_array_on_high_bw_mem(iet, i, storage) @@ -571,9 +620,12 @@ def make_zero_init(obj, rcompile, sregistry): cdims.append(CustomDimension(name=d.name, parent=d, symbolic_min=m, symbolic_max=M)) - eq = Eq(obj[cdims], 0) + if obj.is_Bundle: + eqns = [Eq(ComponentAccess(obj[cdims], i), 0) for i in range(obj.ncomp)] + else: + eqns = [Eq(obj[cdims], 0)] - irs, byproduct = rcompile(eq) + irs, byproduct = rcompile(eqns) init = irs.iet.body.body[0] @@ -581,6 +633,10 @@ def make_zero_init(obj, rcompile, sregistry): efunc = make_callable(name, init) init = Call(name, efunc.parameters) - efuncs = [efunc] + [i.root for i in byproduct.funcs] + efuncs = [efunc] + + # Also the called device kernels, if any + calls = [i.name for i in FindNodes(Call).visit(efunc)] + efuncs.extend([i.root for i in byproduct.funcs if i.root.name in calls]) return efuncs, init