Skip to content

Commit

Permalink
compiler: class HaloSchemeEntry(Reconstructable)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Nov 6, 2024
1 parent fee66c8 commit c620588
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 24 deletions.
4 changes: 2 additions & 2 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Forward, WithLock, PrefetchUpdate, detect_io)
from devito.symbolics import ListInitializer, CallFromPointer, ccode
from devito.tools import (Signer, as_tuple, filter_ordered, filter_sorted, flatten,
ctypes_to_cstr)
ctypes_to_cstr, OrderedSet)
from devito.types.basic import (AbstractFunction, AbstractSymbol, Basic, Indexed,
Symbol)
from devito.types.object import AbstractObject, LocalObject
Expand Down Expand Up @@ -1459,7 +1459,7 @@ def __init__(self, body, halo_scheme):
def __repr__(self):
fstrings = []
for f in self.functions:
loc_indices = set().union(*[self.halo_scheme.fmapper[f].loc_indices.values()])
loc_indices = OrderedSet(*(self.halo_scheme.fmapper[f].loc_indices.values()))
loc_indices_str = str(list(loc_indices)) if loc_indices else ""

fstrings.append(f"{f.name}{loc_indices_str}")
Expand Down
21 changes: 8 additions & 13 deletions devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from devito.ir.support import Forward, Scope
from devito.symbolics.manipulation import _uxreplace_registry
from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten,
frozendict, is_integer, filter_sorted)
frozendict, is_integer, filter_sorted, OrderedSet)
from devito.types import Grid

__all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch']
Expand All @@ -28,18 +28,16 @@ class HaloLabel(Tag):
STENCIL = HaloLabel('stencil')


class HaloSchemeEntry:
class HaloSchemeEntry(Reconstructable):

__rkwargs__ = ('loc_indices', 'loc_dirs', 'halos', 'dims')

def __init__(self, loc_indices, loc_dirs, halos, dims):
self.loc_indices = loc_indices
self.loc_dirs = loc_dirs
self.halos = halos
self.dims = dims

def __repr__(self):
return (f"HaloSchemeEntry(loc_indices={self.loc_indices}, "
f"loc_dirs={self.loc_dirs}, halos={self.halos}, dims={self.dims})")

def __eq__(self, other):
if not isinstance(other, HaloSchemeEntry):
return False
Expand All @@ -54,12 +52,9 @@ def __hash__(self):
frozenset(self.halos),
frozenset(self.dims)))

def rebuild(self, **kwargs):
loc_indices = kwargs.get('loc_indices', self.loc_indices)
loc_dirs = kwargs.get('loc_dirs', self.loc_dirs)
halos = kwargs.get('halos', self.halos)
dims = kwargs.get('dims', self.dims)
return HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)
def __repr__(self):
return (f"HaloSchemeEntry(loc_indices={self.loc_indices}, "
f"loc_dirs={self.loc_dirs}, halos={self.halos}, dims={self.dims})")


Halo = namedtuple('Halo', 'dim side')
Expand Down Expand Up @@ -129,7 +124,7 @@ def __init__(self, exprs, ispace):
def __repr__(self):
fstrings = []
for f in self.fmapper:
loc_indices = set().union(*[self._mapper[f].loc_indices.values()])
loc_indices = OrderedSet(*(self._mapper[f].loc_indices.values()))
loc_indices_str = str(list(loc_indices)) if loc_indices else ""

fstrings.append(f"{f.name}{loc_indices_str}")
Expand Down
18 changes: 9 additions & 9 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def _hoist_halospots(iet):
continue

for it in iters:
# If also merge-able we can start hoisting the latter
for dep in scopes[it].d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
break
Expand All @@ -120,17 +119,18 @@ def _hoist_halospots(iet):
# the loc_indices with known values
# TODO: Can I get this in a more elegant way?
for d in hse.loc_indices:
if hse.loc_indices[d].is_Symbol:
assert d in hse.loc_indices[d]._defines
root_min = hse.loc_indices[d].symbolic_min
new_min = root_min.subs(hse.loc_indices[d].root,
hse.loc_indices[d].root.symbolic_min)
md = hse.loc_indices[d]
if md.is_Symbol:
root = md.root
hse_min = md.symbolic_min
new_min = hse_min.subs(root, root.symbolic_min)
raw_loc_indices[d] = new_min
else:
assert d.symbolic_min in hse.loc_indices[d].free_symbols
raw_loc_indices[d] = hse.loc_indices[d]
# md is in form of an expression
assert d.symbolic_min in md.free_symbols
raw_loc_indices[d] = md

hse = hse.rebuild(loc_indices=frozendict(raw_loc_indices))
hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hse

hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
Expand Down

0 comments on commit c620588

Please sign in to comment.