From 387da72bfab40cc18a98a2cb524464e2d89ad012 Mon Sep 17 00:00:00 2001 From: mloubout Date: Mon, 9 Sep 2024 17:54:08 -0400 Subject: [PATCH 1/6] compiler:nfix handling of modulo 1 --- devito/ir/clusters/algorithms.py | 33 +++++++++++++++++++++++++------- devito/mpi/halo_scheme.py | 2 +- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index b463593438..2be2fed855 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -16,7 +16,7 @@ from devito.mpi.halo_scheme import HaloScheme, HaloTouch from devito.mpi.reduction_scheme import DistReduce from devito.symbolics import (limits_mapper, retrieve_indexed, uxreplace, - xreplace_indices) + xreplace_indices, retrieve_dimensions) from devito.tools import (DefaultOrderedDict, Stamp, as_mapper, flatten, is_integer, split, timed_pass, toposort) from devito.types import Array, Eq, Symbol @@ -51,6 +51,9 @@ def clusterize(exprs, **kwargs): # Derive the necessary communications for distributed-memory parallelism clusters = communications(clusters) + # Substitute potential stepping simplifications + clusters = simplify_modulo(clusters) + return ClusterGroup(clusters) @@ -358,12 +361,8 @@ def rule(size, e): groups = as_mapper(mds, lambda d: d.modulo) for size, v in groups.items(): key = partial(rule, size) - if size == 1: - # Optimization -- avoid useless "% 1" ModuloDimensions - subs = {md.origin: 0 for md in v} - else: - subs = {md.origin: md for md in v} - sub_iterators[d].extend(v) + subs = {md.origin: md for md in v} + sub_iterators[d].extend(v) func = partial(xreplace_indices, mapper=subs, key=key) exprs = [e.apply(func) for e in exprs] @@ -376,6 +375,26 @@ def rule(size, e): return processed +@timed_pass() +def simplify_modulo(clusters): + """ + Simplify trivial modulo expressions such as %1 + """ + processed = [] + for c in clusters: + mds = {d for d in retrieve_dimensions(c.exprs, deep=True) if d.is_Modulo} + + if mds: + subs = {d: 0 for d in mds if d._modulo == 1} + func = partial(xreplace_indices, mapper=subs) + exprs = [e.apply(func) for e in c.exprs] + processed.append(c.rebuild(exprs=exprs)) + else: + processed.append(c) + + return processed + + @timed_pass(name='communications') def communications(clusters): """ diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 0062b5b32f..740ad87ed5 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -628,7 +628,7 @@ def _uxreplace_dispatch_haloscheme(hs0, rule): except KeyError: # E.g., `usave(cd, x, y)` and `usave.dx` in an # adjoint Operator - assert d0.is_Conditional + assert d0.is_Conditional or d1.is_Stepping loc_dirs[d1] = hse0.loc_dirs[d0.root] if len(loc_indices) != len(hse0.loc_indices): From f0619f51492e549c8b487a3da70775048b290b46 Mon Sep 17 00:00:00 2001 From: mloubout Date: Mon, 9 Sep 2024 18:12:07 -0400 Subject: [PATCH 2/6] test: add buffer1 halo test --- devito/ir/clusters/algorithms.py | 12 ++++++++---- devito/mpi/halo_scheme.py | 2 +- tests/test_mpi.py | 22 ++++++++++++++++++++-- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 2be2fed855..1491e846f1 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -362,7 +362,11 @@ def rule(size, e): for size, v in groups.items(): key = partial(rule, size) subs = {md.origin: md for md in v} - sub_iterators[d].extend(v) + # Modulo 1 will be removed later but need to be removed + # from the sub_iterator here to prevent bad dimension analysis + # E.g. it would prevent some outer loop parallelism + if size != 1: + sub_iterators[d].extend(v) func = partial(xreplace_indices, mapper=subs, key=key) exprs = [e.apply(func) for e in exprs] @@ -378,14 +382,14 @@ def rule(size, e): @timed_pass() def simplify_modulo(clusters): """ - Simplify trivial modulo expressions such as %1 + Simplify trivial modulo expressions such as %1 """ processed = [] for c in clusters: mds = {d for d in retrieve_dimensions(c.exprs, deep=True) if d.is_Modulo} + subs = {d: 0 for d in mds if d._modulo == 1} - if mds: - subs = {d: 0 for d in mds if d._modulo == 1} + if subs: func = partial(xreplace_indices, mapper=subs) exprs = [e.apply(func) for e in c.exprs] processed.append(c.rebuild(exprs=exprs)) diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 740ad87ed5..3fe8a31e21 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -628,7 +628,7 @@ def _uxreplace_dispatch_haloscheme(hs0, rule): except KeyError: # E.g., `usave(cd, x, y)` and `usave.dx` in an # adjoint Operator - assert d0.is_Conditional or d1.is_Stepping + assert d0.is_Conditional or d0.is_Stepping loc_dirs[d1] = hse0.loc_dirs[d0.root] if len(loc_indices) != len(hse0.loc_indices): diff --git a/tests/test_mpi.py b/tests/test_mpi.py index c8bad20aaa..6594418369 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -4,10 +4,10 @@ from conftest import _R, assert_blocking, assert_structure from devito import (Grid, Constant, Function, TimeFunction, SparseFunction, - SparseTimeFunction, Dimension, ConditionalDimension, + SparseTimeFunction, Dimension, ConditionalDimension, div, SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm, inner, configuration, switchconfig, generic_derivative, - PrecomputedSparseFunction, DefaultDimension) + PrecomputedSparseFunction, DefaultDimension, Buffer) from devito.arch.compiler import OneapiCompiler from devito.data import LEFT, RIGHT from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols, @@ -1639,6 +1639,24 @@ def test_enforce_haloupdate_if_unwritten_function(self, mode): calls = FindNodes(Call).visit(op) assert len(calls) == 2 # One for `v` and one for `usave` + @pytest.mark.parallel(mode=1) + def test_haloupdate_buffer1(self, mode): + grid = Grid(shape=(4, 4)) + x, y = grid.dimensions + + u = TimeFunction(name='u', grid=grid, time_order=1, save=Buffer(1)) + v = TimeFunction(name='v', grid=grid, time_order=1, save=Buffer(1)) + + eqns = [Eq(u.forward, div(v) + 1.), + Eq(v.forward, div(u.forward) + 1.)] + + op = Operator(eqns) + + calls = FindNodes(Call).visit(op) + # There should be two separate calls + # halo(v), eq_u, halo_u, eq(v) + assert len(calls) == 2 + class TestOperatorAdvanced: From f0365640520a18c6bc88c8597fa9031c479b0b95 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 10 Sep 2024 07:49:36 +0000 Subject: [PATCH 3/6] compiler: Simplify modulo-1 ModuloDimension removal --- devito/ir/clusters/algorithms.py | 25 +------------------------ devito/mpi/halo_scheme.py | 2 +- devito/passes/iet/misc.py | 8 ++++++-- tests/test_mpi.py | 4 ++++ 4 files changed, 12 insertions(+), 27 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 1491e846f1..e17f844cf4 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -16,7 +16,7 @@ from devito.mpi.halo_scheme import HaloScheme, HaloTouch from devito.mpi.reduction_scheme import DistReduce from devito.symbolics import (limits_mapper, retrieve_indexed, uxreplace, - xreplace_indices, retrieve_dimensions) + xreplace_indices) from devito.tools import (DefaultOrderedDict, Stamp, as_mapper, flatten, is_integer, split, timed_pass, toposort) from devito.types import Array, Eq, Symbol @@ -51,9 +51,6 @@ def clusterize(exprs, **kwargs): # Derive the necessary communications for distributed-memory parallelism clusters = communications(clusters) - # Substitute potential stepping simplifications - clusters = simplify_modulo(clusters) - return ClusterGroup(clusters) @@ -379,26 +376,6 @@ def rule(size, e): return processed -@timed_pass() -def simplify_modulo(clusters): - """ - Simplify trivial modulo expressions such as %1 - """ - processed = [] - for c in clusters: - mds = {d for d in retrieve_dimensions(c.exprs, deep=True) if d.is_Modulo} - subs = {d: 0 for d in mds if d._modulo == 1} - - if subs: - func = partial(xreplace_indices, mapper=subs) - exprs = [e.apply(func) for e in c.exprs] - processed.append(c.rebuild(exprs=exprs)) - else: - processed.append(c) - - return processed - - @timed_pass(name='communications') def communications(clusters): """ diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 3fe8a31e21..0062b5b32f 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -628,7 +628,7 @@ def _uxreplace_dispatch_haloscheme(hs0, rule): except KeyError: # E.g., `usave(cd, x, y)` and `usave.dx` in an # adjoint Operator - assert d0.is_Conditional or d0.is_Stepping + assert d0.is_Conditional loc_dirs[d1] = hse0.loc_dirs[d0.root] if len(loc_indices) != len(hse0.loc_indices): diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index f0b2b7f4f5..869d8d58d4 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -246,9 +246,13 @@ def remove_redundant_moddims(iet): if not mds: return iet - mapper = as_mapper(mds, key=lambda md: md.offset % md.modulo) + # Modulo-1 dimensions are always redundant as they can be replaced by 0 + degenerate, mds = split(mds, lambda md: md.modulo == 1) + subs = {md: sympy.S.Zero for md in degenerate} - subs = {} + # Group ModuloDimensions so that we can pick one and remove the others that + # would map to the same modulo value + mapper = as_mapper(mds, key=lambda md: md.offset % md.modulo) for k, v in mapper.items(): chosen = v.pop(0) subs.update({d: chosen for d in v}) diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 6594418369..b6332a476c 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -1657,6 +1657,10 @@ def test_haloupdate_buffer1(self, mode): # halo(v), eq_u, halo_u, eq(v) assert len(calls) == 2 + # Also ensure the compiler is doing its job removing unnecessary + # ModuloDimensions + assert len([i for i in FindSymbols('dimensions').visit(op) if i.is_Modulo]) == 0 + class TestOperatorAdvanced: From 9fe733b1d3b6e6ba05a8cd34d9240034b26af293 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 10 Sep 2024 08:37:29 +0000 Subject: [PATCH 4/6] compiler: Clean up removal of useless ModuloDimensions --- devito/ir/clusters/algorithms.py | 6 +---- devito/passes/iet/misc.py | 42 +++++++++++++++++--------------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index e17f844cf4..1a05f0842a 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -359,11 +359,7 @@ def rule(size, e): for size, v in groups.items(): key = partial(rule, size) subs = {md.origin: md for md in v} - # Modulo 1 will be removed later but need to be removed - # from the sub_iterator here to prevent bad dimension analysis - # E.g. it would prevent some outer loop parallelism - if size != 1: - sub_iterators[d].extend(v) + sub_iterators[d].extend(v) func = partial(xreplace_indices, mapper=subs, key=key) exprs = [e.apply(func) for e in exprs] diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 869d8d58d4..0b4ed39149 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -246,29 +246,31 @@ def remove_redundant_moddims(iet): if not mds: return iet - # Modulo-1 dimensions are always redundant as they can be replaced by 0 - degenerate, mds = split(mds, lambda md: md.modulo == 1) - subs = {md: sympy.S.Zero for md in degenerate} - - # Group ModuloDimensions so that we can pick one and remove the others that - # would map to the same modulo value - mapper = as_mapper(mds, key=lambda md: md.offset % md.modulo) - for k, v in mapper.items(): - chosen = v.pop(0) - subs.update({d: chosen for d in v}) - - body = Uxreplace(subs).visit(iet.body) - iet = iet._rebuild(body=body) - # ModuloDimensions are defined in Iteration headers, hence they must be - # removed from there too - subs = {} + # removed from there first of all + mapper = {} for n in FindNodes(Iteration).visit(iet): - if not set(n.uindices) & set(mds): - continue - subs[n] = n._rebuild(uindices=filter_ordered(n.uindices)) + candidates = [d for d in n.uindices if d in mds] + + degenerates, others = split(candidates, lambda d: d.modulo == 1) + subs = {d: sympy.S.Zero for d in degenerates} - iet = Transformer(subs, nested=True).visit(iet) + redundants = as_mapper(others, key=lambda d: d.offset % d.modulo) + for k, v in redundants.items(): + chosen = v.pop(0) + subs.update({d: chosen for d in v}) + + if subs: + # Expunge the ModuloDimensions from the Iteration header + uindices = [d for d in n.uindices if d not in subs] + iteration = n._rebuild(uindices=uindices) + + # Replace the ModuloDimensions in the Iteration body + iteration = Uxreplace(subs).visit(iteration) + + mapper[n] = iteration + + iet = Transformer(mapper, nested=True).visit(iet) return iet From f3149914b52f0b656cd9de382792c27891e91ed3 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 10 Sep 2024 07:55:05 -0400 Subject: [PATCH 5/6] tests: fix peculiar reduction test --- tests/test_gpu_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 3980b60cd2..70879b4ce7 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -161,7 +161,7 @@ def test_reduction_many_dims(self): op1 = Operator(eqns, opt=('advanced', {'mapify-reduce': True})) tree, = retrieve_iteration_tree(op0) - assert 'collapse(4) reduction(+:s)' in str(tree.root.pragmas[0]) + assert 'collapse(3) reduction(+:s)' in str(tree[1].pragmas[0]) tree, = retrieve_iteration_tree(op1) assert 'collapse(3) reduction(+:s)' in str(tree[1].pragmas[0]) From eab35794deab794a2a34ae5de072b388ac429c52 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 10 Sep 2024 13:14:39 +0000 Subject: [PATCH 6/6] mpi: Cast MPI={1,True,basic} to MPI=basic for homogeneity --- devito/__init__.py | 2 +- devito/mpi/routines.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/devito/__init__.py b/devito/__init__.py index ef08d206f2..c03324a39c 100644 --- a/devito/__init__.py +++ b/devito/__init__.py @@ -75,7 +75,7 @@ def reinit_compiler(val): deprecate='openmp') # MPI mode (0 => disabled, 1 == basic) -preprocessor = lambda i: bool(i) if isinstance(i, int) else i +preprocessor = lambda i: {0: False, 1: 'basic'}.get(i, i) configuration.add('mpi', 0, [0, 1] + list(mpi_registry), preprocessor=preprocessor, callback=reinit_compiler) diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index dc89e85164..8b4987c8bb 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -1002,7 +1002,6 @@ def _call_poke(self, poke): mpi_registry = { - True: BasicHaloExchangeBuilder, 'basic': BasicHaloExchangeBuilder, 'diag': DiagHaloExchangeBuilder, 'diag2': Diag2HaloExchangeBuilder,