Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Avoid allocating Bundles on the host if transient #2503

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
import numpy as np

from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
FindSymbols, MapExprStmts, Transformer, make_callable)
FindNodes, FindSymbols, MapExprStmts, Transformer,
make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
SizeOf, VOID, Keyword, pow_to_mul)
from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten
from devito.types import Array, CustomDimension, DeviceMap, DeviceRM, Eq, Symbol
from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap,
DeviceRM, Eq, Symbol)

__all__ = ['DataManager', 'DeviceAwareDataManager', 'Storage']

Expand Down Expand Up @@ -214,6 +216,38 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):

storage.update(obj, site, allocs=alloc, frees=free, efuncs=(efunc0, efunc1))

def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: some comments would help reduce the bus factor of this code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, if it turns green I will add it to the to-be-rebased #2500

"""
Allocate a Bundle struct in the host high bandwidth memory.

A new callable is generated which declares `obj`, allocates
`SizeOf(obj._C_typedata)` bytes for it through the language's
'host-alloc' routine, initializes the struct's `nbytes` and `size`
fields, and returns the resulting pointer. The call to this callable,
the matching 'host-free', and the callable itself are then registered
in `storage` for the given `site`.

Parameters
----------
site
    The location the allocation is attached to; forwarded as-is to
    `storage.update`.
obj
    The Bundle whose struct must be allocated.
storage
    The container collecting allocs, frees and efuncs.
"""
# Declaration of `obj` inside the generated callable
decl = Definition(obj)

# Arguments for 'host-alloc': the address of the C symbol wrapped as
# `void**` (presumably a cast -- confirm against `VOID`), the required
# alignment, and the number of bytes to allocate
memptr = VOID(Byref(obj._C_symbol), '**')
alignment = obj._data_alignment
# NOTE: the allocation size is that of `obj._C_typedata` alone; it does
# not depend on `obj.size`
nbytes = SizeOf(obj._C_typedata)
alloc = self.lang['host-alloc'](memptr, alignment, nbytes)

# `nbytes_param` is the symbol used within the callable's body, while
# `nbytes_arg` is the expression substituted for it at the call site below
nbytes_param = Symbol(name='nbytes', dtype=np.uint64, is_const=True)
nbytes_arg = SizeOf(obj.indexed._C_typedata)*obj.size

# Initialize the struct fields: `nbytes` takes the value passed in by
# the caller, `size` is set to 0
# NOTE(review): presumably `size` is 0 because no data is attached to the
# struct on the host -- confirm
ffp1 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
init0 = DummyExpr(ffp1, nbytes_param)
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
init1 = DummyExpr(ffp2, 0)

# The matching deallocation, registered in `storage` below
free = self.lang['host-free'](obj._C_symbol)

# The callable hands the allocated pointer back to its caller
ret = Return(obj._C_symbol)

# Build the callable under a fresh 'alloc'-prefixed name
name = self.sregistry.make_name(prefix='alloc')
body = (decl, alloc, init0, init1, ret)
efunc0 = make_callable(name, body, retval=obj)
# The call takes the callable's own parameters, except that `nbytes_param`
# is replaced by the actual size expression
args = list(efunc0.parameters)
args[args.index(nbytes_param)] = nbytes_arg
# NOTE: `alloc` is rebound here -- from now on it is the Call to the
# callable, not the 'host-alloc' node that was placed in `body`
alloc = Call(name, args, retobj=obj)

storage.update(obj, site, allocs=alloc, frees=free, efuncs=efunc0)

def _alloc_object_array_on_low_lat_mem(self, site, obj, storage):
"""
Allocate an Array of Objects in the low latency memory.
Expand Down Expand Up @@ -340,9 +374,22 @@ def place_definitions(self, iet, globs=None, **kwargs):
for i in FindSymbols().visit(iet):
if i in defines:
continue

elif i.is_LocalObject:
self._alloc_object_on_low_lat_mem(iet, i, storage)
elif i.is_Array or i.is_Bundle:

elif i.is_Bundle:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any way to reduce the "death by conditionals" here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After spending a couple of hours on this very matter this morning, I'm inclined to say that it's pretty challenging to come up with one. The problem is that there's a combinatorial explosion of possibilities (luckily with a small exponent), and all those ifs enumerate that space.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose it's all attributes too, which makes it even more challenging without obfuscating what's going on

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the implicit question was "can you use singledispatch", then yeah, you got it right.

if i._mem_heap:
if i.is_transient:
self._alloc_bundle_struct_on_high_bw_mem(iet, i, storage)
elif i._mem_local:
self._alloc_local_array_on_high_bw_mem(iet, i, storage)
elif i._mem_mapped:
self._alloc_mapped_array_on_high_bw_mem(iet, i, storage)
elif i._mem_stack:
self._alloc_array_on_low_lat_mem(iet, i, storage)

elif i.is_Array:
if i._mem_heap:
if i._mem_host:
self._alloc_host_array_on_high_bw_mem(iet, i, storage)
Expand All @@ -355,8 +402,10 @@ def place_definitions(self, iet, globs=None, **kwargs):
elif globs is not None:
# Track, to be handled by the EntryFunction being a global obj!
globs.add(i)

elif i.is_ObjectArray:
self._alloc_object_array_on_low_lat_mem(iet, i, storage)

elif i.is_PointerArray:
self._alloc_pointed_array_on_high_bw_mem(iet, i, storage)

Expand Down Expand Up @@ -571,16 +620,23 @@ def make_zero_init(obj, rcompile, sregistry):
cdims.append(CustomDimension(name=d.name, parent=d,
symbolic_min=m, symbolic_max=M))

eq = Eq(obj[cdims], 0)
if obj.is_Bundle:
eqns = [Eq(ComponentAccess(obj[cdims], i), 0) for i in range(obj.ncomp)]
else:
eqns = [Eq(obj[cdims], 0)]

irs, byproduct = rcompile(eq)
irs, byproduct = rcompile(eqns)

init = irs.iet.body.body[0]

name = sregistry.make_name(prefix='init')
efunc = make_callable(name, init)
init = Call(name, efunc.parameters)

efuncs = [efunc] + [i.root for i in byproduct.funcs]
efuncs = [efunc]

# Also the called device kernels, if any
calls = [i.name for i in FindNodes(Call).visit(efunc)]
efuncs.extend([i.root for i in byproduct.funcs if i.root.name in calls])

return efuncs, init
Loading