diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 9bcc3460f6..57becab5ff 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -617,6 +617,14 @@ def bounds(self, _min=None, _max=None): return (_min, _max) + @property + def start(self): + """The start value.""" + if self.direction is Forward: + return self.dim.symbolic_min + else: + return self.dim.symbolic_max + @property def step(self): """The step value.""" diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 505fe2e001..9b068a7d25 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -873,7 +873,6 @@ def default_retval(cls): the nodes of type ``child_types`` retrieved by the search. This behaviour can be changed through this parameter. Accepted values are: - 'immediate': only the closest matching ancestor is mapped. - - 'groupby': the matching ancestors are grouped together as a single key. """ def __init__(self, parent_type=None, child_types=None, mode=None): @@ -886,7 +885,7 @@ def __init__(self, parent_type=None, child_types=None, mode=None): assert issubclass(parent_type, Node) self.parent_type = parent_type self.child_types = as_tuple(child_types) or (Call, Expression) - assert mode in (None, 'immediate', 'groupby') + assert mode in (None, 'immediate') self.mode = mode def visit_object(self, o, ret=None, **kwargs): @@ -903,9 +902,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False): if parents is None: parents = [] if isinstance(o, self.child_types): - if self.mode == 'groupby': - ret.setdefault(as_tuple(parents), []).append(o) - elif self.mode == 'immediate': + if self.mode == 'immediate': if in_parent: ret.setdefault(parents[-1], []).append(o) else: diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index bb558c8d58..6f59b442e8 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -25,7 +25,7 @@ class IndexMode(Tag): REGULAR = IndexMode('regular') IRREGULAR = IndexMode('irregular') -# Symbols to create mock data depdendencies +# Symbols to create mock data dependencies mocksym0 = Symbol(name='__⋈_0__') mocksym1 = Symbol(name='__⋈_1__') diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 7074b4c5d5..09cf98ac6a 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -11,7 +11,7 @@ from devito.ir.support import Forward, Scope from devito.symbolics.manipulation import _uxreplace_registry from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten, - frozendict, is_integer, filter_sorted) + frozendict, is_integer, filter_sorted, EnrichedTuple) from devito.types import Grid __all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch'] @@ -28,7 +28,22 @@ class HaloLabel(Tag): STENCIL = HaloLabel('stencil') -HaloSchemeEntry = namedtuple('HaloSchemeEntry', 'loc_indices loc_dirs halos dims') +class HaloSchemeEntry(EnrichedTuple): + + __rargs__ = ('loc_indices', 'loc_dirs', 'halos', 'dims') + + def __init__(self, loc_indices, loc_dirs, halos, dims, getters=None): + self.loc_indices = frozendict(loc_indices) + self.loc_dirs = frozendict(loc_dirs) + self.halos = frozenset(halos) + self.dims = frozenset(dims) + + def __hash__(self): + return hash((self.loc_indices, + self.loc_dirs, + self.halos, + self.dims)) + Halo = namedtuple('Halo', 'dim side') @@ -121,7 +136,10 @@ def union(self, halo_schemes): Create a new HaloScheme from the union of a set of HaloSchemes. """ halo_schemes = [hs for hs in halo_schemes if hs is not None] - if not halo_schemes: + + if len(halo_schemes) == 1: + return halo_schemes[0] + elif not halo_schemes: return None fmapper = {} @@ -365,6 +383,10 @@ def distributed_aindices(self): def loc_indices(self): return set().union(*[i.loc_indices.keys() for i in self.fmapper.values()]) + @cached_property + def loc_values(self): + return set().union(*[i.loc_indices.values() for i in self.fmapper.values()]) + @cached_property def arguments(self): return self.dimensions | set(flatten(self.honored.values())) @@ -503,8 +525,6 @@ def classify(exprs, ispace): loc_indices, loc_dirs = process_loc_indices(raw_loc_indices, ispace.directions) - halos = frozenset(halos) - dims = frozenset(dims) mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims) @@ -556,7 +576,7 @@ def process_loc_indices(raw_loc_indices, directions): known = set().union(*[i._defines for i in loc_indices]) loc_dirs = {d: v for d, v in directions.items() if d in known} - return frozendict(loc_indices), frozendict(loc_dirs) + return loc_indices, loc_dirs class HaloTouch(sympy.Function, Reconstructable): @@ -634,9 +654,7 @@ def _uxreplace_dispatch_haloscheme(hs0, rule): # Nope, let's try with the next Indexed, if any continue - hse = HaloSchemeEntry(frozendict(loc_indices), - frozendict(loc_dirs), - hse0.halos, hse0.dims) + hse = hse0._rebuild(loc_indices=loc_indices, loc_dirs=loc_dirs) else: continue diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 9f49440709..930cc108b2 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -1,6 +1,7 @@ from collections import defaultdict from sympy import S +from itertools import combinations from devito.ir.iet import (Call, Expression, HaloSpot, Iteration, FindNodes, MapNodes, MapHaloSpots, Transformer, @@ -19,11 +20,11 @@ @iet_pass def optimize_halospots(iet, **kwargs): """ - Optimize the HaloSpots in ``iet``. HaloSpots may be dropped, merged and moved - around in order to improve the halo exchange performance. + Optimize the HaloSpots in ``iet``. HaloSpots may be dropped, hoisted, + merged and moved around in order to improve the halo exchange performance. """ - iet = _drop_halospots(iet) - iet = _hoist_halospots(iet) + iet = _drop_reduction_halospots(iet) + iet = _hoist_invariant(iet) iet = _merge_halospots(iet) iet = _drop_if_unwritten(iet, **kwargs) iet = _mark_overlappable(iet) @@ -31,12 +32,10 @@ def optimize_halospots(iet, **kwargs): return iet, {} -def _drop_halospots(iet): +def _drop_reduction_halospots(iet): """ - Remove HaloSpots that: - - * Would be used to compute Increments (in which case, a halo exchange - is actually unnecessary) + Remove HaloSpots that are used to compute Increments + (in which case, a halo exchange is actually unnecessary) """ mapper = defaultdict(set) @@ -48,35 +47,29 @@ def _drop_halospots(iet): mapper[hs].add(f) # Transform the IET introducing the "reduced" HaloSpots - subs = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(mapper[hs])) - for hs in FindNodes(HaloSpot).visit(iet)} - iet = Transformer(subs, nested=True).visit(iet) + mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(mapper[hs])) + for hs in FindNodes(HaloSpot).visit(iet)} + iet = Transformer(mapper, nested=True).visit(iet) return iet -def _hoist_halospots(iet): +def _hoist_invariant(iet): """ Hoist HaloSpots from inner to outer Iterations where all data dependencies - would be honored. - """ - - # Hoisting rules -- if the retval is True, then it means the input `dep` is not - # a stopper to halo hoisting + would be honored. This pass avoids redundant halo exchanges when the same + data is redundantly exchanged within the same Iteration tree level. + + Example: + haloupd v[t0] + for time for time + W v[t1]- R v[t0] W v[t1]- R v[t0] + haloupd v[t1] haloupd v[t1] + R v[t1] R v[t1] + haloupd v[t0] R v[t0] + R v[t0] - def rule0(dep, candidates, loc_dims): - # E.g., `dep=W -> R` and `candidates=({time}, {x})` => False - # E.g., `dep=W -> R`, `dep.cause={t,time}` and - # `candidates=({x},)` => True - return (all(i & set(dep.distance_mapper) for i in candidates) and - not any(i & dep.cause for i in candidates) and - not any(i & loc_dims for i in candidates)) - - def rule1(dep, candidates, loc_dims): - # A reduction isn't a stopper to hoisting - return dep.write is not None and dep.write.is_reduction - - rules = [rule0, rule1] + """ # Precompute scopes to save time scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()} @@ -84,106 +77,86 @@ def rule1(dep, candidates, loc_dims): # Analysis hsmapper = {} imapper = defaultdict(list) - for iters, halo_spots in MapNodes(Iteration, HaloSpot, 'groupby').visit(iet).items(): - for hs in halo_spots: - hsmapper[hs] = hs.halo_scheme - - for f, v in hs.fmapper.items(): - loc_dims = frozenset().union([q for d in v.loc_indices - for q in d._defines]) - - for n, i in enumerate(iters): - if i not in scopes: - continue - - candidates = [i.dim._defines for i in iters[n:]] - - all_candidates = set().union(*candidates) - reads = scopes[i].getreads(f) - if any(set(a.ispace.dimensions) & all_candidates - for a in reads): - continue - - for dep in scopes[i].d_flow.project(f): - if not any(r(dep, candidates, loc_dims) for r in rules): - break - else: - hsmapper[hs] = hsmapper[hs].drop(f) - imapper[i].append(hs.halo_scheme.project(f)) + + cond_mapper = _make_cond_mapper(iet) + iter_mapper = _filter_iter_mapper(iet) + + for it, halo_spots in iter_mapper.items(): + for hs0, hs1 in combinations(halo_spots, r=2): + + if _check_control_flow(hs0, hs1, cond_mapper): + continue + + # If there are overlapping loc_indices, skip + hs0_mdims = hs0.halo_scheme.loc_values + hs1_mdims = hs1.halo_scheme.loc_values + if hs0_mdims.intersection(hs1_mdims): + continue + + for f, v in hs1.fmapper.items(): + if f not in hs0.functions: + continue + + for dep in scopes[it].d_flow.project(f): + if not any(r(dep, hs1, v.loc_indices) for r in rules): break + else: + # `hs1`` can be hoisted out of `it`, but we need to infer valid + # loc_indices + hse = hs1.halo_scheme.fmapper[f] + loc_indices = {} + + for d, v in hse.loc_indices.items(): + if v in it.uindices: + loc_indices[d] = v.symbolic_min.subs(it.dim, it.start) + else: + loc_indices[d] = v + + hse = hse._rebuild(loc_indices=loc_indices) + hs1.halo_scheme.fmapper[f] = hse + + hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f) + imapper[it].append(hs1.halo_scheme.project(f)) - # Post-process analysis mapper = {i: HaloSpot(i._rebuild(), HaloScheme.union(hss)) for i, hss in imapper.items()} mapper.update({i: i.body if hs.is_void else i._rebuild(halo_scheme=hs) for i, hs in hsmapper.items()}) - # Transform the IET hoisting/dropping HaloSpots as according to the analysis iet = Transformer(mapper, nested=True).visit(iet) - - # Clean up: de-nest HaloSpots if necessary - mapper = {} - for hs in FindNodes(HaloSpot).visit(iet): - if hs.body.is_HaloSpot: - halo_scheme = HaloScheme.union([hs.halo_scheme, hs.body.halo_scheme]) - mapper[hs] = hs._rebuild(halo_scheme=halo_scheme, body=hs.body.body) - iet = Transformer(mapper, nested=True).visit(iet) - return iet def _merge_halospots(iet): """ Merge HaloSpots on the same Iteration tree level where all data dependencies - would be honored. - """ - - # Merge rules -- if the retval is True, then it means the input `dep` is not - # a stopper to halo merging + would be honored. Avoids redundant halo exchanges when the same data is + redundantly exchanged within the same Iteration tree level as well as to initiate + multiple halo exchanges at once. - def rule0(dep, hs, loc_indices): - # E.g., `dep=W -> R` => True - return not any(d in hs.dimensions or dep.distance_mapper[d] is S.Infinity - for d in dep.cause) + Example: - def rule1(dep, hs, loc_indices): - # TODO This is apparently never hit, but feeling uncomfortable to remove it - return (dep.is_regular and - dep.read is not None and - all(not any(dep.read.touched_halo(d.root)) for d in dep.cause)) + for time for time + haloupd v[t0] haloupd v[t0], h + W v[t1]- R v[t0] W v[t1]- R v[t0] + haloupd v[t0], h + W g[t1]- R v[t0], h W g[t1]- R v[t0], h - def rule2(dep, hs, loc_indices): - # E.g., `dep=W -> R` and `loc_indices={t: t0}` => True - return any(dep.distance_mapper[d] == 0 and dep.source[d] is not v - for d, v in loc_indices.items()) - - rules = [rule0, rule1, rule2] + """ # Analysis - cond_mapper = MapHaloSpots().visit(iet) - cond_mapper = {hs: {i for i in v if i.is_Conditional and - not isinstance(i.condition, GuardFactorEq)} - for hs, v in cond_mapper.items()} - - iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet) - mapper = {} - for i, halo_spots in iter_mapper.items(): - if i is None or len(halo_spots) <= 1: - continue + cond_mapper = _make_cond_mapper(iet) + iter_mapper = _filter_iter_mapper(iet) - scope = Scope([e.expr for e in FindNodes(Expression).visit(i)]) + for it, halo_spots in iter_mapper.items(): + scope = Scope([e.expr for e in FindNodes(Expression).visit(it)]) hs0 = halo_spots[0] - mapper[hs0] = hs0.halo_scheme for hs1 in halo_spots[1:]: - mapper[hs1] = hs1.halo_scheme - # If there are Conditionals involved, both `hs0` and `hs1` must be - # within the same Conditional, otherwise we would break the control - # flow semantics - if cond_mapper.get(hs0) != cond_mapper.get(hs1): + if _check_control_flow(hs0, hs1, cond_mapper): continue for f, v in hs1.fmapper.items(): @@ -191,14 +164,10 @@ def rule2(dep, hs, loc_indices): if not any(r(dep, hs1, v.loc_indices) for r in rules): break else: - try: - hs = hs1.halo_scheme.project(f) - mapper[hs0] = HaloScheme.union([mapper[hs0], hs]) - mapper[hs1] = mapper[hs1].drop(f) - except ValueError: - # `hs1.loc_indices= 1: + iter_mapper[k] = filtered_hs + + return iter_mapper + + +def _make_cond_mapper(iet): + """ + Return a mapper from HaloSpots to the Conditionals that contain them. + """ + cond_mapper = {} + for hs, v in MapHaloSpots().visit(iet).items(): + conditionals = {i for i in v if i.is_Conditional and + not isinstance(i.condition, GuardFactorEq)} + cond_mapper[hs] = conditionals + + return cond_mapper + + +def _check_control_flow(hs0, hs1, cond_mapper): + """ + If there are Conditionals involved, both `hs0` and `hs1` must be + within the same Conditional, otherwise we would break control flow + """ + cond0 = cond_mapper.get(hs0) + cond1 = cond_mapper.get(hs1) + + return cond0 != cond1 + + +# Code motion rules -- if the retval is True, then it means the input `dep` is not +# a stopper to moving the HaloSpot `hs` around + +def _rule0(dep, hs, loc_indices): + # E.g., `dep=W -> R` => True + return not any(d in hs.dimensions or dep.distance_mapper[d] is S.Infinity + for d in dep.cause) + + +def _rule1(dep, hs, loc_indices): + # E.g., `dep=W -> R` and `loc_indices={t: t0}` => True + return any(dep.distance_mapper[d] == 0 and dep.source[d] is not v + for d, v in loc_indices.items()) + + +rules = (_rule0, _rule1) diff --git a/examples/mpi/overview.ipynb b/examples/mpi/overview.ipynb index df57022cbe..7e8926aa1f 100644 --- a/examples/mpi/overview.ipynb +++ b/examples/mpi/overview.ipynb @@ -107,7 +107,10 @@ "source": [ "%%px\n", "from devito import configuration\n", - "configuration['mpi'] = True" + "configuration['mpi'] = True\n", + "\n", + "# Feel free to change the log level, and see more detailed logging\n", + "configuration['log-level'] = 'INFO'" ] }, { diff --git a/examples/seismic/abc_methods/03_pml.ipynb b/examples/seismic/abc_methods/03_pml.ipynb index d3ca9929b1..b0ab62839c 100644 --- a/examples/seismic/abc_methods/03_pml.ipynb +++ b/examples/seismic/abc_methods/03_pml.ipynb @@ -179,10 +179,8 @@ "# NBVAL_IGNORE_OUTPUT\n", "\n", "%matplotlib inline\n", - "from examples.seismic import TimeAxis\n", - "from examples.seismic import RickerSource\n", - "from examples.seismic import Receiver\n", - "from devito import SubDomain, Grid, NODE, TimeFunction, Function, Eq, solve, Operator" + "from devito import SubDomain, Grid, NODE, TimeFunction, Function, Eq, solve, Operator\n", + "from examples.seismic import TimeAxis, RickerSource, Receiver" ] }, { @@ -205,7 +203,7 @@ "compx = x1-x0\n", "z0 = 0.\n", "z1 = 1000.\n", - "compz = z1-z0;\n", + "compz = z1-z0\n", "hxv = (x1-x0)/(nptx-1)\n", "hzv = (z1-z0)/(nptz-1)" ] diff --git a/examples/seismic/tutorials/09_viscoelastic.ipynb b/examples/seismic/tutorials/09_viscoelastic.ipynb index 5d7cabdd75..0d1d524613 100644 --- a/examples/seismic/tutorials/09_viscoelastic.ipynb +++ b/examples/seismic/tutorials/09_viscoelastic.ipynb @@ -457,7 +457,7 @@ "source": [ "# References\n", "\n", - "[1] Johan O. A. Roberston, *et.al.* (1994). \"Viscoelatic finite-difference modeling\" GEOPHYSICS, 59(9), 1444-1456.\n", + "[1] Johan O. A. Roberston, *et.al.* (1994). \"Viscoelastic finite-difference modeling\" GEOPHYSICS, 59(9), 1444-1456.\n", "\n", "\n", "[2] https://janth.home.xs4all.nl/Software/fdelmodcManual.pdf" diff --git a/examples/seismic/viscoacoustic/operators.py b/examples/seismic/viscoacoustic/operators.py index 1253b61686..d237d43ea6 100755 --- a/examples/seismic/viscoacoustic/operators.py +++ b/examples/seismic/viscoacoustic/operators.py @@ -524,7 +524,7 @@ def ForwardOperator(model, geometry, space_order=4, kernel='sls', time_order=2, # Substitute spacing terms to reduce flops return Operator(eqn + src_term + rec_term, subs=model.spacing_map, - name='Forward', **kwargs) + name='ViscoIsoAcousticForward', **kwargs) def AdjointOperator(model, geometry, space_order=4, kernel='SLS', time_order=2, **kwargs): diff --git a/examples/seismic/viscoelastic/operators.py b/examples/seismic/viscoelastic/operators.py index bc2f5642a1..9e0269665a 100644 --- a/examples/seismic/viscoelastic/operators.py +++ b/examples/seismic/viscoelastic/operators.py @@ -64,4 +64,4 @@ def ForwardOperator(model, geometry, space_order=4, save=False, **kwargs): # Substitute spacing terms to reduce flops return Operator([u_v, u_r, u_t] + src_rec_expr, subs=model.spacing_map, - name='Forward', **kwargs) + name='ViscoIsoElasticForward', **kwargs) diff --git a/tests/test_dse.py b/tests/test_dse.py index 33f93bf5e5..989e22929a 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -11,7 +11,7 @@ SparseTimeFunction, Dimension, SubDimension, ConditionalDimension, DefaultDimension, Grid, Operator, norm, grad, div, dimensions, switchconfig, configuration, - centered, first_derivative, solve, transpose, Abs, cos, + first_derivative, solve, transpose, Abs, cos, sin, sqrt, floor, Ge, Lt, Derivative) from devito.exceptions import InvalidArgument, InvalidOperator from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes, diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 74e09575c1..9d2032c472 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -4,7 +4,8 @@ from conftest import _R, assert_blocking, assert_structure from devito import (Grid, Constant, Function, TimeFunction, SparseFunction, - SparseTimeFunction, Dimension, ConditionalDimension, div, + SparseTimeFunction, VectorTimeFunction, TensorTimeFunction, + Dimension, ConditionalDimension, div, solve, diag, grad, SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm, inner, configuration, switchconfig, generic_derivative, PrecomputedSparseFunction, DefaultDimension, Buffer) @@ -19,6 +20,7 @@ from devito.tools import Bunch from examples.seismic.acoustic import acoustic_setup +from tests.test_dse import TestTTI class TestDistributor: @@ -980,6 +982,7 @@ def test_avoid_redundant_haloupdate_cond(self, mode): calls = FindNodes(Call).visit(op) assert len(calls) == 1 + assert calls[0].functions[0] is f @pytest.mark.parallel(mode=1) def test_avoid_haloupdate_if_distr_but_sequential(self, mode): @@ -1220,7 +1223,7 @@ def test_avoid_fullmode_if_crossloop_dep(self, mode): assert np.all(f.data[:] == 2.) @pytest.mark.parallel(mode=2) - def test_avoid_haloudate_if_flowdep_along_other_dim(self, mode): + def test_avoid_haloupdate_if_flowdep_along_other_dim(self, mode): grid = Grid(shape=(10,)) x = grid.dimensions[0] t = grid.stepping_dim @@ -1296,7 +1299,7 @@ def test_unmerge_haloupdate_if_no_locindices(self, mode): assert np.allclose(g.data_ro_domain[0, 5:], [16., 16., 14., 13., 6.], rtol=R) @pytest.mark.parallel(mode=1) - def test_merge_haloupdate_if_diff_locindices_v0(self, mode): + def test_merge_haloupdate_if_diff_locindices(self, mode): grid = Grid(shape=(101, 101)) x, y = grid.dimensions t = grid.stepping_dim @@ -1317,11 +1320,12 @@ def test_merge_haloupdate_if_diff_locindices_v0(self, mode): op.cfunction @pytest.mark.parallel(mode=2) - def test_merge_haloupdate_if_diff_locindices_v1(self, mode): + def test_merge_and_hoist_haloupdate_if_diff_locindices(self, mode): """ This test is a revisited, more complex version of - `test_merge_haloupdate_if_diff_locindices_v0`. And in addition to - checking the generated code, it also checks the numerical output. + `test_merge_haloupdate_if_diff_locindices`, also checking hoisting. + And in addition to checking the generated code, + it also checks the numerical output. In the Operator there are three Eqs: @@ -1333,8 +1337,10 @@ def test_merge_haloupdate_if_diff_locindices_v1(self, mode): * the second and third Eqs cannot be fused in the same loop - In the IET we end up with *one* HaloSpots, placed right before the - second Eq. The third Eq will seamlessy find its halo up-to-date. + In the IET we end up with *two* HaloSpots, one placed before the + time loop, and one placed before the second Eq. The third Eq, + reading from f[t0], will seamlessy find its halo up-to-date, + due to the f[t1] being updated in the previous time iteration. """ grid = Grid(shape=(10,)) x = grid.dimensions[0] @@ -1357,9 +1363,15 @@ def test_merge_haloupdate_if_diff_locindices_v1(self, mode): op = Operator(eqns) calls = FindNodes(Call).visit(op) - assert len(calls) == 1 + assert len(calls) == 2 + assert calls[0].arguments[3].args[0] is t.symbolic_min + + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[2])) == 0 op.apply(time_M=1) + glb_pos_map = f.grid.distributor.glb_pos_map R = 1e-07 # Can't use np.all due to rounding error at the tails if LEFT in glb_pos_map[x]: @@ -2390,7 +2402,6 @@ def test_staggering(self, mode): op = Operator(eqns) op(time_M=2) - # Expected norms computed "manually" from sequential runs assert np.isclose(norm(ux), 7003.098, rtol=1.e-4) assert np.isclose(norm(uxx), 78902.21, rtol=1.e-4) @@ -2727,6 +2738,245 @@ def test_adjoint_F_no_omp(self, mode): self.run_adjoint_F(3) +class TestElasticLike: + + @pytest.mark.parallel(mode=[(1, 'diag')]) + def test_elastic_structure(self, mode): + + so = 4 + grid = Grid(shape=(3, 3)) + + v = VectorTimeFunction(name='v', grid=grid, space_order=so) + tau = TensorTimeFunction(name='t', grid=grid, space_order=so) + + damp = Function(name='damp', grid=grid) + l = Function(name='lam', grid=grid) + mu = Function(name='mu', grid=grid) + ro = Function(name='b', grid=grid) + + rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=10) + rec_term = rec.interpolate(expr=v[0] + v[1]) + + # First order elastic wave equation + pde_v = v.dt - ro * div(tau) + pde_tau = (tau.dt - l * diag(div(v.forward)) - + mu * (grad(v.forward) + grad(v.forward).transpose(inner=False))) + + # Time update + u_v = Eq(v.forward, damp * solve(pde_v, v.forward)) + u_t = Eq(tau.forward, damp * solve(pde_tau, tau.forward)) + + op = Operator([u_v] + [u_t] + rec_term) + assert len(op._func_table) == 11 + + calls = [i for i in FindNodes(Call).visit(op) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 5 + + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1].body[1])) == 4 + assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1].body[2])) == 0 + + assert calls[0].arguments[0] is v[0] + assert calls[0].arguments[1] is v[1] + assert calls[1].arguments[0] is tau[0, 0] + assert calls[2].arguments[0] is tau[0, 1] + assert calls[3].arguments[0] is tau[1, 1] + assert calls[4].arguments[0] is v[0] + assert calls[4].arguments[1] is v[1] + + @pytest.fixture + def setup(self): + """ + This fixture sets up the grid, fields, elastic-like + equations and receivers for test_issue_2448_*. + """ + shape = (2,) + so = 2 + tn = 30 + + grid = Grid(shape=shape) + + # Velocity and pressure fields + v = TimeFunction(name='v', grid=grid, space_order=so) + tau = TimeFunction(name='tau', grid=grid, space_order=so) + + # First order elastic-like dependencies equations + pde_v = v.dt - (tau.dx) + pde_tau = tau.dt - ((v.forward).dx) + u_v = Eq(v.forward, solve(pde_v, v.forward)) + u_tau = Eq(tau.forward, solve(pde_tau, tau.forward)) + + rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=tn) + rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1) + + return grid, v, tau, u_v, u_tau, rec + + @pytest.mark.parallel(mode=1) + def test_issue_2448_v0(self, mode, setup): + _, v, tau, u_v, u_tau, rec = setup + + rec_term0 = rec.interpolate(expr=v) + + op0 = Operator([u_v, u_tau, rec_term0]) + + calls = [i for i in FindNodes(Call).visit(op0) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 3 + assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2 + assert calls[0].arguments[0] is v + assert calls[1].arguments[0] is tau + assert calls[2].arguments[0] is v + + @pytest.mark.parallel(mode=1) + def test_issue_2448_v1(self, mode, setup): + _, v, tau, u_v, u_tau, rec = setup + + rec_term1 = rec.interpolate(expr=v.forward) + + op1 = Operator([u_v, u_tau, rec_term1]) + + calls = [i for i in FindNodes(Call).visit(op1) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 3 + assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[0])) == 0 + assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[1])) == 3 + assert calls[0].arguments[0] is tau + assert calls[1].arguments[0] is v + assert calls[2].arguments[0] is v + + @pytest.mark.parallel(mode=1) + def test_issue_2448_v2(self, mode, setup): + grid, v, tau, u_v, u_tau, rec = setup + + # Additional velocity and pressure fields + v2 = TimeFunction(name='v2', grid=grid, space_order=2) + tau2 = TimeFunction(name='tau2', grid=grid, space_order=2) + + # First order elastic-like dependencies equations + pde_v2 = v2.dt - (tau2.dx) + pde_tau2 = tau2.dt - ((v2.forward).dx) + u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward)) + u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward)) + + rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30) + rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1) + + rec_term0 = rec.interpolate(expr=v) + rec_term2 = rec2.interpolate(expr=v2) + + op2 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term2]) + + calls = [i for i in FindNodes(Call).visit(op2) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 5 + assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[0])) == 2 + assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[1])) == 3 + assert calls[0].arguments[0] is v + assert calls[1].arguments[0] is v2 + assert calls[2].arguments[0] is tau + assert calls[2].arguments[1] is tau2 + assert calls[3].arguments[0] is v + assert calls[4].arguments[0] is v2 + + @pytest.mark.parallel(mode=1) + def test_issue_2448_v3(self, mode, setup): + grid, v, tau, u_v, u_tau, rec = setup + + # Additional velocity and pressure fields + v2 = TimeFunction(name='v2', grid=grid, space_order=2) + tau2 = TimeFunction(name='tau2', grid=grid, space_order=2) + + # First order elastic-like dependencies equations + pde_v2 = v2.dt - (tau2.dx) + pde_tau2 = tau2.dt - ((v2.forward).dx) + u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward)) + u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward)) + + rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30) + rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1) + + rec_term0 = rec.interpolate(expr=v) + rec_term3 = rec2.interpolate(expr=v2.forward) + + op3 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term3]) + + calls = [i for i in FindNodes(Call).visit(op3) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 5 + assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[1])) == 4 + assert calls[0].arguments[0] is v + assert calls[1].arguments[0] is tau + assert calls[1].arguments[1] is tau2 + assert calls[2].arguments[0] is v + assert calls[3].arguments[0] is v2 + assert calls[4].arguments[0] is v2 + + @pytest.mark.parallel(mode=1) + def test_issue_2448_backward(self, mode): + ''' + Similar to test_issue_2448, but with backward instead of forward + so that the hoisted halo has different starting point + ''' + shape = (2,) + so = 2 + + grid = Grid(shape=shape) + t = grid.stepping_dim + + tn = 7 + + # Velocity and pressure fields + v = TimeFunction(name='v', grid=grid, space_order=so) + v.data_with_halo[0, :] = 1. + v.data_with_halo[1, :] = 3. + + tau = TimeFunction(name='tau', grid=grid, space_order=so) + tau.data_with_halo[:] = 1. + + # First order elastic-like dependencies equations + pde_v = v.dt - (tau.dx) + pde_tau = tau.dt - ((v.backward).dx) + + u_v = Eq(v.backward, solve(pde_v, v)) + u_tau = Eq(tau.backward, solve(pde_tau, tau)) + + # Test two variants of receiver interpolation + nrec = 1 + rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn) + rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec) + + # Test receiver interpolation 0, here we have a halo exchange hoisted + op0 = Operator([u_v] + [u_tau] + rec.interpolate(expr=v)) + + calls = [i for i in FindNodes(Call).visit(op0) + if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 3 + assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2 + assert calls[0].arguments[0] is v + assert calls[0].arguments[3].args[0] is t.symbolic_max + assert calls[1].arguments[0] is tau + assert calls[2].arguments[0] is v + + +class TestTTIOp: + + @pytest.mark.parallel(mode=1) + def test_halo_structure(self, mode): + solver = TestTTI().tti_operator(opt='advanced', space_order=8) + op = solver.op_fwd(save=False) + + calls = [i for i in FindNodes(Call).visit(op) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 1 + assert calls[0].functions[0].name == 'u' + assert calls[0].functions[1].name == 'v' + + if __name__ == "__main__": # configuration['mpi'] = 'overlap' # TestDecomposition().test_reshape_left_right()