-
Notifications
You must be signed in to change notification settings - Fork 229
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
compiler: Avoid allocating Bundles on the host if transient #2503
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,14 +10,16 @@ | |
import numpy as np | ||
|
||
from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction, | ||
FindSymbols, MapExprStmts, Transformer, make_callable) | ||
FindNodes, FindSymbols, MapExprStmts, Transformer, | ||
make_callable) | ||
from devito.passes import is_gpu_create | ||
from devito.passes.iet.engine import iet_pass | ||
from devito.passes.iet.langbase import LangBB | ||
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, | ||
SizeOf, VOID, Keyword, pow_to_mul) | ||
from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten | ||
from devito.types import Array, CustomDimension, DeviceMap, DeviceRM, Eq, Symbol | ||
from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap, | ||
DeviceRM, Eq, Symbol) | ||
|
||
__all__ = ['DataManager', 'DeviceAwareDataManager', 'Storage'] | ||
|
||
|
@@ -214,6 +216,38 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args): | |
|
||
storage.update(obj, site, allocs=alloc, frees=free, efuncs=(efunc0, efunc1)) | ||
|
||
def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage): | ||
""" | ||
Allocate a Bundle struct in the host high bandwidth memory. | ||
""" | ||
decl = Definition(obj) | ||
|
||
memptr = VOID(Byref(obj._C_symbol), '**') | ||
alignment = obj._data_alignment | ||
nbytes = SizeOf(obj._C_typedata) | ||
alloc = self.lang['host-alloc'](memptr, alignment, nbytes) | ||
|
||
nbytes_param = Symbol(name='nbytes', dtype=np.uint64, is_const=True) | ||
nbytes_arg = SizeOf(obj.indexed._C_typedata)*obj.size | ||
|
||
ffp1 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol) | ||
init0 = DummyExpr(ffp1, nbytes_param) | ||
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol) | ||
init1 = DummyExpr(ffp2, 0) | ||
|
||
free = self.lang['host-free'](obj._C_symbol) | ||
|
||
ret = Return(obj._C_symbol) | ||
|
||
name = self.sregistry.make_name(prefix='alloc') | ||
body = (decl, alloc, init0, init1, ret) | ||
efunc0 = make_callable(name, body, retval=obj) | ||
args = list(efunc0.parameters) | ||
args[args.index(nbytes_param)] = nbytes_arg | ||
alloc = Call(name, args, retobj=obj) | ||
|
||
storage.update(obj, site, allocs=alloc, frees=free, efuncs=efunc0) | ||
|
||
def _alloc_object_array_on_low_lat_mem(self, site, obj, storage): | ||
""" | ||
Allocate an Array of Objects in the low latency memory. | ||
|
@@ -340,9 +374,22 @@ def place_definitions(self, iet, globs=None, **kwargs): | |
for i in FindSymbols().visit(iet): | ||
if i in defines: | ||
continue | ||
|
||
elif i.is_LocalObject: | ||
self._alloc_object_on_low_lat_mem(iet, i, storage) | ||
elif i.is_Array or i.is_Bundle: | ||
|
||
elif i.is_Bundle: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any way to reduce the "death by conditionals" here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After spending a couple of hours on this very matter this morning, I'm inclined to say that it's pretty challenging to come up with one. The problem is that there's a combinatorial explosion of possibilities (luckily with small exponent), and all those There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suppose it's all attributes too, which makes it even more challenging without obfuscating what's going on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if the implicit question was "can you use singledispatch", yeah, you got it right |
||
if i._mem_heap: | ||
if i.is_transient: | ||
self._alloc_bundle_struct_on_high_bw_mem(iet, i, storage) | ||
elif i._mem_local: | ||
self._alloc_local_array_on_high_bw_mem(iet, i, storage) | ||
elif i._mem_mapped: | ||
self._alloc_mapped_array_on_high_bw_mem(iet, i, storage) | ||
elif i._mem_stack: | ||
self._alloc_array_on_low_lat_mem(iet, i, storage) | ||
|
||
elif i.is_Array: | ||
if i._mem_heap: | ||
if i._mem_host: | ||
self._alloc_host_array_on_high_bw_mem(iet, i, storage) | ||
|
@@ -355,8 +402,10 @@ def place_definitions(self, iet, globs=None, **kwargs): | |
elif globs is not None: | ||
# Track, to be handled by the EntryFunction being a global obj! | ||
globs.add(i) | ||
|
||
elif i.is_ObjectArray: | ||
self._alloc_object_array_on_low_lat_mem(iet, i, storage) | ||
|
||
elif i.is_PointerArray: | ||
self._alloc_pointed_array_on_high_bw_mem(iet, i, storage) | ||
|
||
|
@@ -571,16 +620,23 @@ def make_zero_init(obj, rcompile, sregistry): | |
cdims.append(CustomDimension(name=d.name, parent=d, | ||
symbolic_min=m, symbolic_max=M)) | ||
|
||
eq = Eq(obj[cdims], 0) | ||
if obj.is_Bundle: | ||
eqns = [Eq(ComponentAccess(obj[cdims], i), 0) for i in range(obj.ncomp)] | ||
else: | ||
eqns = [Eq(obj[cdims], 0)] | ||
|
||
irs, byproduct = rcompile(eq) | ||
irs, byproduct = rcompile(eqns) | ||
|
||
init = irs.iet.body.body[0] | ||
|
||
name = sregistry.make_name(prefix='init') | ||
efunc = make_callable(name, init) | ||
init = Call(name, efunc.parameters) | ||
|
||
efuncs = [efunc] + [i.root for i in byproduct.funcs] | ||
efuncs = [efunc] | ||
|
||
# Also the called device kernels, if any | ||
calls = [i.name for i in FindNodes(Call).visit(efunc)] | ||
efuncs.extend([i.root for i in byproduct.funcs if i.root.name in calls]) | ||
|
||
return efuncs, init |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: some comments would help reduce the bus factor of this code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, if it turns green I will add it to the to-be-rebased #2500