From 76f6412334a026af56b2584cd83868ca19e26ab7 Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Mon, 18 Nov 2024 14:49:15 +0200 Subject: [PATCH] tests: Add backward test and drop redundancies --- devito/ir/iet/nodes.py | 8 ++++ devito/ir/iet/visitors.py | 7 +-- devito/passes/iet/mpi.py | 37 ++++++++------- tests/test_dse.py | 4 +- tests/test_mpi.py | 98 +++++++++++++++++++++++++++------------ 5 files changed, 100 insertions(+), 54 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index f03112464e..76876d687a 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -617,6 +617,14 @@ def bounds(self, _min=None, _max=None): return (_min, _max) + @property + def start(self): + """The start value.""" + if self.direction is Forward: + return self.dim.symbolic_min + else: + return self.dim.symbolic_max + @property def step(self): """The step value.""" diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 505fe2e001..9b068a7d25 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -873,7 +873,6 @@ def default_retval(cls): the nodes of type ``child_types`` retrieved by the search. This behaviour can be changed through this parameter. Accepted values are: - 'immediate': only the closest matching ancestor is mapped. - - 'groupby': the matching ancestors are grouped together as a single key. """ def __init__(self, parent_type=None, child_types=None, mode=None): @@ -886,7 +885,7 @@ def __init__(self, parent_type=None, child_types=None, mode=None): assert issubclass(parent_type, Node) self.parent_type = parent_type self.child_types = as_tuple(child_types) or (Call, Expression) - assert mode in (None, 'immediate', 'groupby') + assert mode in (None, 'immediate') self.mode = mode def visit_object(self, o, ret=None, **kwargs): @@ -903,9 +902,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False): if parents is None: parents = [] if isinstance(o, self.child_types): - if self.mode == 'groupby': - ret.setdefault(as_tuple(parents), []).append(o) - elif self.mode == 'immediate': + if self.mode == 'immediate': if in_parent: ret.setdefault(parents[-1], []).append(o) else: diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index c352a86c3a..a3cef2ff8b 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -82,15 +82,7 @@ def _hoist_invariant(iet): hsmapper = {} imapper = defaultdict(list) - iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet) - - # Drop void `halo_scheme`s from the analysis - iter_mapper = {k: [hs for hs in v if not hs.halo_scheme.is_void] - for k, v in iter_mapper.items()} - - iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None} - - iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1} + iter_mapper = _filter_iter_mapper(iet) for it, halo_spots in iter_mapper.items(): for hs0, hs1 in combinations(halo_spots, r=2): @@ -121,9 +113,8 @@ def _hoist_invariant(iet): for d in hse.loc_indices: md = hse.loc_indices[d] if md in it.uindices: - new_min = md.symbolic_min.subs(it.dim, - it.dim.symbolic_min) - raw_loc_indices[d] = new_min + md_sub = it.start + raw_loc_indices[d] = md.symbolic_min.subs(it.dim, md_sub) else: raw_loc_indices[d] = md @@ -164,11 +155,7 @@ def _merge_halospots(iet): mapper = {} - iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet) - - iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None} - - iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1} + iter_mapper = _filter_iter_mapper(iet) for it, halo_spots in iter_mapper.items(): scope = Scope([e.expr for e in FindNodes(Expression).visit(it)]) @@ -362,6 +349,20 @@ def mpiize(graph, **kwargs): # *** Utilities +def _filter_iter_mapper(iet): + """ + Given an IET, return a mapper from Iterations to the HaloSpots. + Additionally, filter out Iterations that are not of interest. + """ + iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet) + iter_mapper = {k: [hs for hs in v if not hs.halo_scheme.is_void] + for k, v in iter_mapper.items()} + iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None} + iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1} + + return iter_mapper + + def _make_cond_mapper(iet): cond_mapper = MapHaloSpots().visit(iet) return {hs: {i for i in v if i.is_Conditional and @@ -395,4 +396,4 @@ def _rule1(dep, hs, loc_indices): for d, v in loc_indices.items()) -rules = [_rule0, _rule1] +rules = (_rule0, _rule1) diff --git a/tests/test_dse.py b/tests/test_dse.py index c82848a28a..989e22929a 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -11,8 +11,8 @@ SparseTimeFunction, Dimension, SubDimension, ConditionalDimension, DefaultDimension, Grid, Operator, norm, grad, div, dimensions, switchconfig, configuration, - centered, first_derivative, solve, transpose, Abs, cos, - sin, sqrt, floor, Ge, Lt, Derivative, solve) + first_derivative, solve, transpose, Abs, cos, + sin, sqrt, floor, Ge, Lt, Derivative) from devito.exceptions import InvalidArgument, InvalidOperator from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes, FindSymbols, ParallelIteration, retrieve_iteration_tree) diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 105c2b4b08..53b34b883f 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -20,7 +20,6 @@ from devito.tools import Bunch from examples.seismic.acoustic import acoustic_setup -from examples.seismic import demo_model from tests.test_dse import TestTTI @@ -1031,9 +1030,8 @@ def test_issue_2448(self, mode): u_tau = Eq(tau.forward, solve(pde_tau, tau.forward)) # Test two variants of receiver interpolation - nrec = 1 - rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn) - rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec) + rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=tn) + rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1) # The receiver 0 rec_term0 = rec.interpolate(expr=v) @@ -1087,9 +1085,8 @@ def test_issue_2448(self, mode): u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward)) # Test two variants of receiver interpolation - nrec = 1 - rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=nrec, nt=tn) - rec2.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec) + rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=tn) + rec2.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1) # The receiver 2 rec_term2 = rec2.interpolate(expr=v2) @@ -1135,6 +1132,56 @@ def test_issue_2448(self, mode): assert calls[3].arguments[0] is v2 assert calls[4].arguments[0] is v2 + @pytest.mark.parallel(mode=1) + def test_issue_2448_backward(self, mode): + ''' + Similar to test_issue_2448, but with backward instead of forward + so that the hoisted halo + + ''' + shape = (2,) + so = 2 + + grid = Grid(shape=shape) + t = grid.stepping_dim + + tn = 7 + + # Velocity and pressure fields + v = TimeFunction(name='v', grid=grid, space_order=so) + v.data_with_halo[0, :] = 1. + v.data_with_halo[1, :] = 3. + + tau = TimeFunction(name='tau', grid=grid, space_order=so) + tau.data_with_halo[:] = 1. + + # First order elastic-like dependencies equations + pde_v = v.dt - (tau.dx) + pde_tau = tau.dt - ((v.backward).dx) + + u_v = Eq(v.backward, solve(pde_v, v)) + u_tau = Eq(tau.backward, solve(pde_tau, tau)) + + # Test two variants of receiver interpolation + nrec = 1 + rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn) + rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec) + + # Test receiver interpolation 0, here we have a halo exchange hoisted + op0 = Operator([u_v] + [u_tau] + rec.interpolate(expr=v)) + + calls = [i for i in FindNodes(Call).visit(op0) + if isinstance(i, HaloUpdateCall)] + + # The correct we want + assert len(calls) == 3 + assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2 + assert calls[0].arguments[0] is v + assert calls[0].arguments[3].args[0] is t.symbolic_max + assert calls[1].arguments[0] is tau + assert calls[2].arguments[0] is v + @pytest.mark.parallel(mode=1) def test_avoid_haloupdate_with_subdims(self, mode): grid = Grid(shape=(4,)) @@ -2862,35 +2909,30 @@ class TestElastic: def test_elastic_structure(self, mode): so = 4 - model = demo_model(preset='layers-elastic', nlayers=1, - shape=(301, 301), spacing=(10., 10.), - space_order=so) + grid = Grid(shape=(3, 3)) - v = VectorTimeFunction(name='v', grid=model.grid, space_order=so) - tau = TensorTimeFunction(name='t', grid=model.grid, space_order=so) + v = VectorTimeFunction(name='v', grid=grid, space_order=so) + tau = TensorTimeFunction(name='t', grid=grid, space_order=so) - # The receiver - nrec = 301 - rec = SparseTimeFunction(name="rec", grid=model.grid, npoint=nrec, nt=10) - rec.coordinates.data[:, 0] = np.linspace(0., model.domain_size[0], num=nrec) - rec.coordinates.data[:, -1] = 5. + damp = Function(name='damp', grid=grid) + l = Function(name='lam', grid=grid) + mu = Function(name='mu', grid=grid) + ro = Function(name='b', grid=grid) + # The receiver + rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=10) rec_term = rec.interpolate(expr=v[0] + v[1]) - # Now let's try and create the staggered updates - # Lame parameters - l, mu, ro = model.lam, model.mu, model.b - # First order elastic wave equation pde_v = v.dt - ro * div(tau) pde_tau = (tau.dt - l * diag(div(v.forward)) - mu * (grad(v.forward) + grad(v.forward).transpose(inner=False))) + # Time update - u_v = Eq(v.forward, model.damp * solve(pde_v, v.forward)) - u_t = Eq(tau.forward, model.damp * solve(pde_tau, tau.forward)) + u_v = Eq(v.forward, damp * solve(pde_v, v.forward)) + u_t = Eq(tau.forward, damp * solve(pde_tau, tau.forward)) op = Operator([u_v] + [u_t] + rec_term) - assert len(op._func_table) == 11 calls = [i for i in FindNodes(Call).visit(op) if isinstance(i, HaloUpdateCall)] @@ -2898,17 +2940,15 @@ def test_elastic_structure(self, mode): # The correct we want assert len(calls) == 5 - assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[3].body[0])) == 1 - assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[3].body[1])) == 4 - assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[3].body[2])) == 0 + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1].body[1])) == 4 + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1].body[2])) == 0 assert calls[0].arguments[0] is v[0] assert calls[0].arguments[1] is v[1] - assert calls[1].arguments[0] is tau[0, 0] assert calls[2].arguments[0] is tau[0, 1] assert calls[3].arguments[0] is tau[1, 1] - assert calls[4].arguments[0] is v[0] assert calls[4].arguments[1] is v[1]