From fee66c89b97ec318eca0c8ff007937b2bb6d7cdb Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Tue, 5 Nov 2024 18:57:52 +0200 Subject: [PATCH] compiler: Inline halo_to_halo --- devito/mpi/halo_scheme.py | 5 ++- devito/passes/iet/mpi.py | 92 ++++++++++++++++++--------------------- tests/test_mpi.py | 35 ++++++--------- 3 files changed, 59 insertions(+), 73 deletions(-) diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index f6afa3a07e..cc9edbe416 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -162,7 +162,10 @@ def union(self, halo_schemes): Create a new HaloScheme from the union of a set of HaloSchemes. """ halo_schemes = [hs for hs in halo_schemes if hs is not None] - if not halo_schemes: + + if len(halo_schemes) == 1: + return halo_schemes[0] + elif not halo_schemes: return None fmapper = {} diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 2aff3fd5c9..323b390c7f 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -63,7 +63,7 @@ def _hoist_halospots(iet): Example: haloupd v[t0] for time for time - W v[t1]- R v[t0] W v[t1]- R v[t0] + W v[t1]- R v[t0] W v[t1]- R v[t0] haloupd v[t1] haloupd v[t1] R v[t1] R v[t1] haloupd v[t0] R v[t0] @@ -74,7 +74,7 @@ def _hoist_halospots(iet): # Precompute scopes to save time scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()} - cond_mapper = _create_cond_mapper(iet) + cond_mapper = _make_cond_mapper(iet) # Analysis hsmapper = {} @@ -95,13 +95,46 @@ def _hoist_halospots(iet): continue # If there are overlapping time accesses, skip - if any(i in hs0.halo_scheme.loc_values - for i in hs1.halo_scheme.loc_values): + if hs0.halo_scheme.loc_values.intersection(hs1.halo_scheme.loc_values): continue - # Compare hs0 to subsequent halo_spots, looking for optimization - # possibilities - _process_halo_to_halo(hs0, hs1, iters, scopes, hsmapper, imapper) + # Loop over the functions in the HaloSpots + for f, v in hs1.fmapper.items(): + # If no time accesses, skip + if not hs1.halo_scheme.fmapper[f].loc_indices: + continue + + # If the function is not in both HaloSpots, skip + if f not in hs0.functions: + continue + + for it in iters: + # If also merge-able we can start hoisting the latter + for dep in scopes[it].d_flow.project(f): + if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()): + break + else: + hse = hs1.halo_scheme.fmapper[f] + raw_loc_indices = {} + # Entering here means we can lift, and we need to update + # the loc_indices with known values + # TODO: Can I get this in a more elegant way? + for d in hse.loc_indices: + if hse.loc_indices[d].is_Symbol: + assert d in hse.loc_indices[d]._defines + root_min = hse.loc_indices[d].symbolic_min + new_min = root_min.subs(hse.loc_indices[d].root, + hse.loc_indices[d].root.symbolic_min) + raw_loc_indices[d] = new_min + else: + assert d.symbolic_min in hse.loc_indices[d].free_symbols + raw_loc_indices[d] = hse.loc_indices[d] + + hse = hse.rebuild(loc_indices=frozendict(raw_loc_indices)) + hs1.halo_scheme.fmapper[f] = hse + + hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f) + imapper[it].append(hs1.halo_scheme.project(f)) mapper = {i: HaloSpot(i._rebuild(), HaloScheme.union(hss)) for i, hss in imapper.items()} @@ -113,47 +146,6 @@ def _hoist_halospots(iet): return iet -def _process_halo_to_halo(hs0, hs1, iters, scopes, hsmapper, imapper): - - # Loop over the functions in the HaloSpots - for f, v in hs1.fmapper.items(): - # If no time accesses, skip - if not hs1.halo_scheme.fmapper[f].loc_indices: - continue - - # If the function is not in both HaloSpots, skip - if (*hs0.functions, *hs1.functions).count(f) < 2: - continue - - for iter in iters: - # If also merge-able we can start hoisting the latter - for dep in scopes[iter].d_flow.project(f): - if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()): - break - else: - hse = hs1.halo_scheme.fmapper[f] - raw_loc_indices = {} - # Entering here means we can lift, and we need to update - # the loc_indices with known values - # TODO: Can I get this in a more elegant way? - for d in hse.loc_indices: - if hse.loc_indices[d].is_Symbol: - assert d in hse.loc_indices[d]._defines - root_min = hse.loc_indices[d].symbolic_min - new_min = root_min.subs(hse.loc_indices[d].root, - hse.loc_indices[d].root.symbolic_min) - raw_loc_indices[d] = new_min - else: - assert d.symbolic_min in hse.loc_indices[d].free_symbols - raw_loc_indices[d] = hse.loc_indices[d] - - hse = hse.rebuild(loc_indices=frozendict(raw_loc_indices)) - hs1.halo_scheme.fmapper[f] = hse - - hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f) - imapper[iter].append(hs1.halo_scheme.project(f)) - - def _merge_halospots(iet): """ Merge HaloSpots on the same Iteration tree level where all data dependencies @@ -161,7 +153,7 @@ def _merge_halospots(iet): """ # Analysis - cond_mapper = _create_cond_mapper(iet) + cond_mapper = _make_cond_mapper(iet) mapper = {} for iter, halo_spots in MapNodes(Iteration, HaloSpot, 'immediate').visit(iet).items(): @@ -363,7 +355,7 @@ def mpiize(graph, **kwargs): # Utility functions to avoid code duplication -def _create_cond_mapper(iet): +def _make_cond_mapper(iet): cond_mapper = MapHaloSpots().visit(iet) return {hs: {i for i in v if i.is_Conditional and not isinstance(i.condition, GuardFactorEq)} diff --git a/tests/test_mpi.py b/tests/test_mpi.py index e51a8f06ee..c02d122731 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -22,7 +22,7 @@ from devito.types.dimension import SpaceDimension from examples.seismic.acoustic import acoustic_setup -from examples.seismic import Receiver, TimeAxis, demo_model +from examples.seismic import demo_model from tests.test_dse import TestTTI @@ -1013,23 +1013,20 @@ def test_avoid_haloupdate_if_distr_but_sequential(self, mode): @pytest.mark.parallel(mode=1) def test_issue_2448(self, mode): - extent = (10., ) - shape = (2, ) + extent = (10.,) + shape = (2,) so = 2 - to = 1 x = SpaceDimension(name='x', spacing=Constant(name='h_x', value=extent[0]/(shape[0]-1))) - grid = Grid(extent=extent, shape=shape, dimensions=(x, )) + grid = Grid(extent=extent, shape=shape, dimensions=(x,)) # Time related - t0, tn = 0., 30. - dt = (10. / np.sqrt(2.)) / 6. - time_range = TimeAxis(start=t0, stop=tn, step=dt) + tn = 30 # Velocity and pressure fields - v = TimeFunction(name='v', grid=grid, space_order=so, time_order=to) - tau = TimeFunction(name='tau', grid=grid, space_order=so, time_order=to) + v = TimeFunction(name='v', grid=grid, space_order=so) + tau = TimeFunction(name='tau', grid=grid, space_order=so) # First order elastic-like dependencies equations pde_v = v.dt - (tau.dx) @@ -1040,7 +1037,7 @@ def test_issue_2448(self, mode): # Test two variants of receiver interpolation nrec = 1 - rec = Receiver(name="rec", grid=grid, npoint=nrec, time_range=time_range) + rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn) rec.coordinates.data[:, 0] = np.linspace(0., extent[0], num=nrec) # The receiver 0 @@ -2811,18 +2808,12 @@ def test_elastic_structure(self, mode): shape=(301, 301), spacing=(10., 10.), space_order=so) - t0, tn = 0., 2000. - dt = model.critical_dt - time_range = TimeAxis(start=t0, stop=tn, step=dt) - - x, z = model.grid.dimensions - - v = VectorTimeFunction(name='v', grid=model.grid, space_order=so, time_order=1) - tau = TensorTimeFunction(name='t', grid=model.grid, space_order=so, time_order=1) + v = VectorTimeFunction(name='v', grid=model.grid, space_order=so) + tau = TensorTimeFunction(name='t', grid=model.grid, space_order=so) # The receiver nrec = 301 - rec = Receiver(name="rec", grid=model.grid, npoint=nrec, time_range=time_range) + rec = SparseTimeFunction(name="rec", grid=model.grid, npoint=nrec, nt=10) rec.coordinates.data[:, 0] = np.linspace(0., model.domain_size[0], num=nrec) rec.coordinates.data[:, -1] = 5. @@ -2864,9 +2855,9 @@ def test_elastic_structure(self, mode): assert calls[4].arguments[1] is v[1] -class TestTTI_w_MPI: +class TestTTIwMPI: - @pytest.mark.parallel(mode=[(1)]) + @pytest.mark.parallel(mode=1) def test_halo_structure(self, mode): mytest = TestTTI()