Skip to content

Commit

Permalink
compiler: Fix misc review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Dec 9, 2024
1 parent 4567885 commit 2fbed4d
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 222 deletions.
20 changes: 5 additions & 15 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Forward, WithLock, PrefetchUpdate, detect_io)
from devito.symbolics import ListInitializer, CallFromPointer, ccode
from devito.tools import (Signer, as_tuple, filter_ordered, filter_sorted, flatten,
ctypes_to_cstr, OrderedSet)
ctypes_to_cstr)
from devito.types.basic import (AbstractFunction, AbstractSymbol, Basic, Indexed,
Symbol)
from devito.types.object import AbstractObject, LocalObject
Expand Down Expand Up @@ -1438,20 +1438,7 @@ def DummyExpr(*args, init=False):

# Nodes required for distributed-memory halo exchange

class HaloMixin:

def __repr__(self):
fstrings = []
for f in self.fmapper.keys():
loc_indices = OrderedSet(*(self.fmapper[f].loc_indices.values()))
loc_indices_str = str(list(loc_indices)) if loc_indices else ""
fstrings.append("%s%s" % (f.name, loc_indices_str))

functions = ",".join(fstrings)
return "<%s(%s)>" % (self.__class__.__name__, functions)


class HaloSpot(HaloMixin, Node):
class HaloSpot(Node):

"""
A halo exchange operation (e.g., send, recv, wait, ...) required to
Expand Down Expand Up @@ -1508,6 +1495,9 @@ def body(self):
def functions(self):
return tuple(self.fmapper)

def __repr__(self):
funcs = self.halo_scheme.__reprfuncs__()
return "<%s(%s)>" % (self.__class__.__name__, funcs)

# Utility classes

Expand Down
29 changes: 20 additions & 9 deletions devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from devito import configuration
from devito.data import CORE, OWNED, LEFT, CENTER, RIGHT
from devito.ir.support import Forward, Scope
from devito.ir.iet.nodes import HaloMixin
from devito.symbolics.manipulation import _uxreplace_registry
from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten,
frozendict, is_integer, filter_sorted)
frozendict, is_integer, filter_sorted, OrderedSet)
from devito.types import Grid

__all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch']
Expand All @@ -36,8 +35,8 @@ class HaloSchemeEntry(Reconstructable):
def __init__(self, loc_indices, loc_dirs, halos, dims):
self.loc_indices = frozendict(loc_indices)
self.loc_dirs = frozendict(loc_dirs)
self.halos = halos
self.dims = dims
self.halos = frozenset(halos)
self.dims = frozenset(dims)

def __eq__(self, other):
if not isinstance(other, HaloSchemeEntry):
Expand All @@ -48,10 +47,10 @@ def __eq__(self, other):
self.dims == other.dims)

def __hash__(self):
return hash((frozenset(self.loc_indices.items()),
frozenset(self.loc_dirs.items()),
frozenset(self.halos),
frozenset(self.dims)))
return hash((tuple(self.loc_indices.items()),
tuple(self.loc_dirs.items()),
self.halos,
self.dims))

def __repr__(self):
return (f"HaloSchemeEntry(loc_indices={self.loc_indices}, "
Expand All @@ -63,7 +62,7 @@ def __repr__(self):
OMapper = namedtuple('OMapper', 'core owned')


class HaloScheme(HaloMixin):
class HaloScheme():

"""
A HaloScheme describes a set of halo exchanges through a mapper:
Expand Down Expand Up @@ -121,6 +120,18 @@ def __init__(self, exprs, ispace):
self._honored[i.root] = frozenset([(ltk, rtk)])
self._honored = frozendict(self._honored)

def __reprfuncs__(self):
fstrings = []
for f in self.fmapper.keys():
loc_indices = OrderedSet(*(self.fmapper[f].loc_indices.values()))
loc_indices_str = str(list(loc_indices)) if loc_indices else ""
fstrings.append("%s%s" % (f.name, loc_indices_str))

return ",".join(fstrings)

def __repr__(self):
return "<%s(%s)>" % (self.__class__.__name__, self.__reprfuncs__())

def __eq__(self, other):
return (isinstance(other, HaloScheme) and
self._mapper == other._mapper and
Expand Down
30 changes: 18 additions & 12 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.mpi.reduction_scheme import DistReduce
from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder
from devito.passes.iet.engine import iet_pass
from devito.tools import generator, frozendict
from devito.tools import generator

__all__ = ['mpiize']

Expand Down Expand Up @@ -94,7 +94,6 @@ def _hoist_invariant(iet):
continue

for f, v in hs1.fmapper.items():

if f not in hs0.functions:
continue

Expand All @@ -114,7 +113,7 @@ def _hoist_invariant(iet):
else:
raw_loc_indices[d] = v

hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices))
hse = hse._rebuild(loc_indices=raw_loc_indices)
hs1.halo_scheme.fmapper[f] = hse

hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
Expand Down Expand Up @@ -348,20 +347,27 @@ def _filter_iter_mapper(iet):
Given an IET, return a mapper from Iterations to the HaloSpots.
Additionally, filter out Iterations that are not of interest.
"""
iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet)
iter_mapper = {k: [hs for hs in v if not hs.halo_scheme.is_void]
for k, v in iter_mapper.items()}
iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None}
iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1}
iter_mapper = {}
for k, v in MapNodes(Iteration, HaloSpot, 'immediate').visit(iet).items():
filtered_hs = [hs for hs in v if not hs.halo_scheme.is_void]
if k is not None and len(filtered_hs) > 1:
iter_mapper[k] = filtered_hs

return iter_mapper


def _make_cond_mapper(iet):
cond_mapper = MapHaloSpots().visit(iet)
return {hs: {i for i in v if i.is_Conditional and
not isinstance(i.condition, GuardFactorEq)}
for hs, v in cond_mapper.items()}

cond_mapper = {}
for hs, v in MapHaloSpots().visit(iet).items():
conditionals = set()
for i in v:
if i.is_Conditional and not isinstance(i.condition, GuardFactorEq):
conditionals.add(i)

cond_mapper[hs] = conditionals

return cond_mapper


def _check_control_flow(hs0, hs1, cond_mapper):
Expand Down
2 changes: 1 addition & 1 deletion examples/seismic/tutorials/09_viscoelastic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@
"source": [
"# References\n",
"\n",
"[1] Johan O. A. Roberston, *et.al.* (1994). \"Viscoelatic finite-difference modeling\" GEOPHYSICS, 59(9), 1444-1456.\n",
"[1] Johan O. A. Roberston, *et.al.* (1994). \"Viscoelastic finite-difference modeling\" GEOPHYSICS, 59(9), 1444-1456.\n",
"\n",
"\n",
"[2] https://janth.home.xs4all.nl/Software/fdelmodcManual.pdf"
Expand Down
2 changes: 1 addition & 1 deletion examples/seismic/viscoelastic/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ def ForwardOperator(model, geometry, space_order=4, save=False, **kwargs):

# Substitute spacing terms to reduce flops
return Operator([u_v, u_r, u_t] + src_rec_expr, subs=model.spacing_map,
name='ViscoElForward', **kwargs)
name='ViscoIsoElasticForward', **kwargs)
Loading

0 comments on commit 2fbed4d

Please sign in to comment.