Skip to content

Commit

Permalink
compiler: Further split tests
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Dec 16, 2024
1 parent f8182b7 commit 5a014ed
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 60 deletions.
2 changes: 1 addition & 1 deletion devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.ir.iet.nodes import HaloMixin
from devito.symbolics.manipulation import _uxreplace_registry
from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten,
frozendict, is_integer, filter_sorted, OrderedSet)
frozendict, is_integer, filter_sorted)
from devito.types import Grid

__all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch']
Expand Down
6 changes: 2 additions & 4 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ def optimize_halospots(iet, **kwargs):

def _drop_reduction_halospots(iet):
"""
Remove HaloSpots that:
* Would be used to compute Increments (in which case, a halo exchange
is actually unnecessary)
Remove HaloSpots that are used to compute Increments
(in which case, a halo exchange is actually unnecessary)
"""
mapper = defaultdict(set)

Expand Down
118 changes: 63 additions & 55 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,16 +1008,14 @@ def test_avoid_haloupdate_if_distr_but_sequential(self, mode):
calls = FindNodes(Call).visit(op)
assert len(calls) == 0

@pytest.mark.parallel(mode=1)
def test_issue_2448(self, mode):
@pytest.fixture
def setup(self):
shape = (2,)
so = 2
tn = 30

grid = Grid(shape=shape)

# Time related
tn = 30

# Velocity and pressure fields
v = TimeFunction(name='v', grid=grid, space_order=so)
tau = TimeFunction(name='tau', grid=grid, space_order=so)
Expand All @@ -1026,105 +1024,111 @@ def test_issue_2448(self, mode):
pde_v = v.dt - (tau.dx)
pde_tau = tau.dt - ((v.forward).dx)
u_v = Eq(v.forward, solve(pde_v, v.forward))

u_tau = Eq(tau.forward, solve(pde_tau, tau.forward))

# Test two variants of receiver interpolation
# Receiver
rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=tn)
rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1)

# The receiver 0
rec_term0 = rec.interpolate(expr=v)
return grid, v, tau, u_v, u_tau, rec

# The receiver 1
rec_term1 = rec.interpolate(expr=v.forward)
@pytest.mark.parallel(mode=1)
def test_issue_2448_I(self, mode, setup):
_, v, tau, u_v, u_tau, rec = setup

# Test receiver interpolation 0, here we have a halo exchange hoisted
op0 = Operator([u_v] + [u_tau] + rec_term0)
rec_term0 = rec.interpolate(expr=v)

calls = [i for i in FindNodes(Call).visit(op0)
if isinstance(i, HaloUpdateCall)]
op0 = Operator([u_v, u_tau, rec_term0])

# The correct we want
assert len(calls) == 3
calls = [i for i in FindNodes(Call).visit(op0) if isinstance(i, HaloUpdateCall)]

assert len(calls) == 3
assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1
assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2
assert calls[0].arguments[0] is v
assert calls[1].arguments[0] is tau
assert calls[2].arguments[0] is v

# Test receiver interpolation 1, here we should not have any halo exchange
# hoisted
op1 = Operator([u_v] + [u_tau] + rec_term1)
@pytest.mark.parallel(mode=1)
def test_issue_2448_II(self, mode, setup):
_, v, tau, u_v, u_tau, rec = setup

calls = [i for i in FindNodes(Call).visit(op1)
if isinstance(i, HaloUpdateCall)]
rec_term1 = rec.interpolate(expr=v.forward)

# The correct we want
assert len(calls) == 3
op1 = Operator([u_v, u_tau, rec_term1])

calls = [i for i in FindNodes(Call).visit(op1) if isinstance(i, HaloUpdateCall)]

assert len(calls) == 3
assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[0])) == 0
assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[1])) == 3
assert calls[0].arguments[0] is tau
assert calls[1].arguments[0] is v
assert calls[2].arguments[0] is v

# Further complicate/stree-test adding an artifical example
# with two hoisting opportunities
@pytest.mark.parallel(mode=1)
def test_issue_2448_III(self, mode, setup):
grid, v, tau, u_v, u_tau, rec = setup

# Velocity and pressure fields
v2 = TimeFunction(name='v2', grid=grid, space_order=so)
tau2 = TimeFunction(name='tau2', grid=grid, space_order=so)
# Additional velocity and pressure fields
v2 = TimeFunction(name='v2', grid=grid, space_order=2)
tau2 = TimeFunction(name='tau2', grid=grid, space_order=2)

# First order elastic-like dependencies equations
pde_v2 = v2.dt - (tau2.dx)
pde_tau2 = tau2.dt - ((v2.forward).dx)
u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward))

u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward))

# Test two variants of receiver interpolation
rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=tn)
rec2.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1)
# Receiver
rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30)
rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1)

# The receiver 2
rec_term0 = rec.interpolate(expr=v)
rec_term2 = rec2.interpolate(expr=v2)

# The receiver 3
rec_term3 = rec2.interpolate(expr=v2.forward)

# Test receiver interpolation 0, here we have a halo exchange hoisted
op2 = Operator([u_v] + [u_v2] + [u_tau] + [u_tau2] + rec_term0 + rec_term2)
op2 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term2])

calls = [i for i in FindNodes(Call).visit(op2)
if isinstance(i, HaloUpdateCall)]
calls = [i for i in FindNodes(Call).visit(op2) if isinstance(i, HaloUpdateCall)]

# The correct we want
assert len(calls) == 5

assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[0])) == 2
assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[1])) == 3

assert calls[0].arguments[0] is v
assert calls[1].arguments[0] is v2
assert calls[2].arguments[0] is tau
assert calls[2].arguments[1] is tau2
assert calls[3].arguments[0] is v
assert calls[4].arguments[0] is v2

# Test receiver interpolation 0, here we have a halo exchange hoisted
op3 = Operator([u_v] + [u_v2] + [u_tau] + [u_tau2] + rec_term0 + rec_term3)
@pytest.mark.parallel(mode=1)
def test_issue_2448_IV(self, mode, setup):
grid, v, tau, u_v, u_tau, rec = setup

calls = [i for i in FindNodes(Call).visit(op3)
if isinstance(i, HaloUpdateCall)]
# Additional velocity and pressure fields
v2 = TimeFunction(name='v2', grid=grid, space_order=2)
tau2 = TimeFunction(name='tau2', grid=grid, space_order=2)

# The correct we want
assert len(calls) == 5
# First order elastic-like dependencies equations
pde_v2 = v2.dt - (tau2.dx)
pde_tau2 = tau2.dt - ((v2.forward).dx)
u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward))
u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward))

# Receiver
rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30)
rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1)

rec_term0 = rec.interpolate(expr=v)
rec_term3 = rec2.interpolate(expr=v2.forward)

op3 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term3])

calls = [i for i in FindNodes(Call).visit(op3) if isinstance(i, HaloUpdateCall)]

assert len(calls) == 5
assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[0])) == 1
assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[1])) == 4

assert calls[0].arguments[0] is v
assert calls[1].arguments[0] is tau
assert calls[1].arguments[1] is tau2
Expand All @@ -1136,8 +1140,7 @@ def test_issue_2448(self, mode):
def test_issue_2448_backward(self, mode):
'''
Similar to test_issue_2448, but with backward instead of forward
so that the hoisted halo
so that the hoisted halo has different starting point
'''
shape = (2,)
so = 2
Expand Down Expand Up @@ -1397,7 +1400,7 @@ def test_avoid_fullmode_if_crossloop_dep(self, mode):
assert np.all(f.data[:] == 2.)

@pytest.mark.parallel(mode=2)
def test_avoid_haloudate_if_flowdep_along_other_dim(self, mode):
def test_avoid_halopudate_if_flowdep_along_other_dim(self, mode):
grid = Grid(shape=(10,))
x = grid.dimensions[0]
t = grid.stepping_dim
Expand Down Expand Up @@ -1535,6 +1538,11 @@ def test_merge_haloupdate_if_diff_locindices_v1(self, mode):

calls = FindNodes(Call).visit(op)
assert len(calls) == 2
assert calls[0].arguments[3].args[0] is t.symbolic_min

assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[0])) == 1
assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1])) == 1
assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[2])) == 0

op.apply(time_M=1)
glb_pos_map = f.grid.distributor.glb_pos_map
Expand Down Expand Up @@ -2953,7 +2961,7 @@ def test_elastic_structure(self, mode):
assert calls[4].arguments[1] is v[1]


class TestTTIwMPI:
class TestTTIOp:

@pytest.mark.parallel(mode=1)
def test_halo_structure(self, mode):
Expand Down

0 comments on commit 5a014ed

Please sign in to comment.