Skip to content

Commit

Permalink
tests: Add backward test and drop redundancies
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Dec 5, 2024
1 parent 01c07a8 commit 76f6412
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 54 deletions.
8 changes: 8 additions & 0 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,14 @@ def bounds(self, _min=None, _max=None):

return (_min, _max)

@property
def start(self):
"""The start value."""
if self.direction is Forward:
return self.dim.symbolic_min
else:
return self.dim.symbolic_max

@property
def step(self):
"""The step value."""
Expand Down
7 changes: 2 additions & 5 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,6 @@ def default_retval(cls):
the nodes of type ``child_types`` retrieved by the search. This behaviour
can be changed through this parameter. Accepted values are:
- 'immediate': only the closest matching ancestor is mapped.
- 'groupby': the matching ancestors are grouped together as a single key.
"""

def __init__(self, parent_type=None, child_types=None, mode=None):
Expand All @@ -886,7 +885,7 @@ def __init__(self, parent_type=None, child_types=None, mode=None):
assert issubclass(parent_type, Node)
self.parent_type = parent_type
self.child_types = as_tuple(child_types) or (Call, Expression)
assert mode in (None, 'immediate', 'groupby')
assert mode in (None, 'immediate')
self.mode = mode

def visit_object(self, o, ret=None, **kwargs):
Expand All @@ -903,9 +902,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
if parents is None:
parents = []
if isinstance(o, self.child_types):
if self.mode == 'groupby':
ret.setdefault(as_tuple(parents), []).append(o)
elif self.mode == 'immediate':
if self.mode == 'immediate':
if in_parent:
ret.setdefault(parents[-1], []).append(o)
else:
Expand Down
37 changes: 19 additions & 18 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,7 @@ def _hoist_invariant(iet):
hsmapper = {}
imapper = defaultdict(list)

iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet)

# Drop void `halo_scheme`s from the analysis
iter_mapper = {k: [hs for hs in v if not hs.halo_scheme.is_void]
for k, v in iter_mapper.items()}

iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None}

iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1}
iter_mapper = _filter_iter_mapper(iet)

for it, halo_spots in iter_mapper.items():
for hs0, hs1 in combinations(halo_spots, r=2):
Expand Down Expand Up @@ -121,9 +113,8 @@ def _hoist_invariant(iet):
for d in hse.loc_indices:
md = hse.loc_indices[d]
if md in it.uindices:
new_min = md.symbolic_min.subs(it.dim,
it.dim.symbolic_min)
raw_loc_indices[d] = new_min
md_sub = it.start
raw_loc_indices[d] = md.symbolic_min.subs(it.dim, md_sub)
else:
raw_loc_indices[d] = md

Expand Down Expand Up @@ -164,11 +155,7 @@ def _merge_halospots(iet):

mapper = {}

iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet)

iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None}

iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1}
iter_mapper = _filter_iter_mapper(iet)

for it, halo_spots in iter_mapper.items():
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])
Expand Down Expand Up @@ -362,6 +349,20 @@ def mpiize(graph, **kwargs):

# *** Utilities

def _filter_iter_mapper(iet):
"""
Given an IET, return a mapper from Iterations to the HaloSpots.
Additionally, filter out Iterations that are not of interest.
"""
iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet)
iter_mapper = {k: [hs for hs in v if not hs.halo_scheme.is_void]
for k, v in iter_mapper.items()}
iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None}
iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1}

return iter_mapper


def _make_cond_mapper(iet):
cond_mapper = MapHaloSpots().visit(iet)
return {hs: {i for i in v if i.is_Conditional and
Expand Down Expand Up @@ -395,4 +396,4 @@ def _rule1(dep, hs, loc_indices):
for d, v in loc_indices.items())


rules = [_rule0, _rule1]
rules = (_rule0, _rule1)
4 changes: 2 additions & 2 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
SparseTimeFunction, Dimension, SubDimension,
ConditionalDimension, DefaultDimension, Grid, Operator,
norm, grad, div, dimensions, switchconfig, configuration,
centered, first_derivative, solve, transpose, Abs, cos,
sin, sqrt, floor, Ge, Lt, Derivative, solve)
first_derivative, solve, transpose, Abs, cos,
sin, sqrt, floor, Ge, Lt, Derivative)
from devito.exceptions import InvalidArgument, InvalidOperator
from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes,
FindSymbols, ParallelIteration, retrieve_iteration_tree)
Expand Down
98 changes: 69 additions & 29 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from devito.tools import Bunch

from examples.seismic.acoustic import acoustic_setup
from examples.seismic import demo_model
from tests.test_dse import TestTTI


Expand Down Expand Up @@ -1031,9 +1030,8 @@ def test_issue_2448(self, mode):
u_tau = Eq(tau.forward, solve(pde_tau, tau.forward))

# Test two variants of receiver interpolation
nrec = 1
rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn)
rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec)
rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=tn)
rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1)

# The receiver 0
rec_term0 = rec.interpolate(expr=v)
Expand Down Expand Up @@ -1087,9 +1085,8 @@ def test_issue_2448(self, mode):
u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward))

# Test two variants of receiver interpolation
nrec = 1
rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=nrec, nt=tn)
rec2.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec)
rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=tn)
rec2.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1)

# The receiver 2
rec_term2 = rec2.interpolate(expr=v2)
Expand Down Expand Up @@ -1135,6 +1132,56 @@ def test_issue_2448(self, mode):
assert calls[3].arguments[0] is v2
assert calls[4].arguments[0] is v2

@pytest.mark.parallel(mode=1)
def test_issue_2448_backward(self, mode):
'''
Similar to test_issue_2448, but with backward instead of forward
so that the hoisted halo
'''
shape = (2,)
so = 2

grid = Grid(shape=shape)
t = grid.stepping_dim

tn = 7

# Velocity and pressure fields
v = TimeFunction(name='v', grid=grid, space_order=so)
v.data_with_halo[0, :] = 1.
v.data_with_halo[1, :] = 3.

tau = TimeFunction(name='tau', grid=grid, space_order=so)
tau.data_with_halo[:] = 1.

# First order elastic-like dependencies equations
pde_v = v.dt - (tau.dx)
pde_tau = tau.dt - ((v.backward).dx)

u_v = Eq(v.backward, solve(pde_v, v))
u_tau = Eq(tau.backward, solve(pde_tau, tau))

# Test two variants of receiver interpolation
nrec = 1
rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn)
rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec)

# Test receiver interpolation 0, here we have a halo exchange hoisted
op0 = Operator([u_v] + [u_tau] + rec.interpolate(expr=v))

calls = [i for i in FindNodes(Call).visit(op0)
if isinstance(i, HaloUpdateCall)]

# The correct we want
assert len(calls) == 3
assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1
assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2
assert calls[0].arguments[0] is v
assert calls[0].arguments[3].args[0] is t.symbolic_max
assert calls[1].arguments[0] is tau
assert calls[2].arguments[0] is v

@pytest.mark.parallel(mode=1)
def test_avoid_haloupdate_with_subdims(self, mode):
grid = Grid(shape=(4,))
Expand Down Expand Up @@ -2862,53 +2909,46 @@ class TestElastic:
def test_elastic_structure(self, mode):

so = 4
model = demo_model(preset='layers-elastic', nlayers=1,
shape=(301, 301), spacing=(10., 10.),
space_order=so)
grid = Grid(shape=(3, 3))

v = VectorTimeFunction(name='v', grid=model.grid, space_order=so)
tau = TensorTimeFunction(name='t', grid=model.grid, space_order=so)
v = VectorTimeFunction(name='v', grid=grid, space_order=so)
tau = TensorTimeFunction(name='t', grid=grid, space_order=so)

# The receiver
nrec = 301
rec = SparseTimeFunction(name="rec", grid=model.grid, npoint=nrec, nt=10)
rec.coordinates.data[:, 0] = np.linspace(0., model.domain_size[0], num=nrec)
rec.coordinates.data[:, -1] = 5.
damp = Function(name='damp', grid=grid)
l = Function(name='lam', grid=grid)
mu = Function(name='mu', grid=grid)
ro = Function(name='b', grid=grid)

# The receiver
rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=10)
rec_term = rec.interpolate(expr=v[0] + v[1])

# Now let's try and create the staggered updates
# Lame parameters
l, mu, ro = model.lam, model.mu, model.b

# First order elastic wave equation
pde_v = v.dt - ro * div(tau)
pde_tau = (tau.dt - l * diag(div(v.forward)) -
mu * (grad(v.forward) + grad(v.forward).transpose(inner=False)))

# Time update
u_v = Eq(v.forward, model.damp * solve(pde_v, v.forward))
u_t = Eq(tau.forward, model.damp * solve(pde_tau, tau.forward))
u_v = Eq(v.forward, damp * solve(pde_v, v.forward))
u_t = Eq(tau.forward, damp * solve(pde_tau, tau.forward))

op = Operator([u_v] + [u_t] + rec_term)

assert len(op._func_table) == 11

calls = [i for i in FindNodes(Call).visit(op) if isinstance(i, HaloUpdateCall)]

# The correct we want
assert len(calls) == 5

assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[3].body[0])) == 1
assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[3].body[1])) == 4
assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[3].body[2])) == 0
assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1].body[0])) == 1
assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1].body[1])) == 4
assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1].body[2])) == 0

assert calls[0].arguments[0] is v[0]
assert calls[0].arguments[1] is v[1]

assert calls[1].arguments[0] is tau[0, 0]
assert calls[2].arguments[0] is tau[0, 1]
assert calls[3].arguments[0] is tau[1, 1]

assert calls[4].arguments[0] is v[0]
assert calls[4].arguments[1] is v[1]

Expand Down

0 comments on commit 76f6412

Please sign in to comment.