Skip to content

Commit

Permalink
Merge pull request #2344 from devitocodes/allreduce
Browse files Browse the repository at this point in the history
compiler: Support for C-level MPI_Allreduce
  • Loading branch information
FabioLuporini authored May 31, 2024
2 parents daa1d85 + 35aea7c commit 10963f3
Show file tree
Hide file tree
Showing 18 changed files with 362 additions and 135 deletions.
60 changes: 33 additions & 27 deletions devito/builtins/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

import devito as dv
from devito.builtins.utils import MPIReduction
from devito.builtins.utils import make_retval


__all__ = ['norm', 'sumall', 'sum', 'inner', 'mmin', 'mmax']
Expand Down Expand Up @@ -44,15 +44,15 @@ def norm(f, order=2):
p, eqns = f.guard() if f.is_SparseFunction else (f, [])

dtype = accumulator_mapper[f.dtype]
n = make_retval(f.grid, dtype)
s = dv.types.Symbol(name='sum', dtype=dtype)

with MPIReduction(f, dtype=dtype) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, dv.Abs(Pow(p, order))), dv.Eq(mr.n[0], s)],
name='norm%d' % order)
op.apply(**kwargs)
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, dv.Abs(Pow(p, order))), dv.Eq(n[0], s)],
name='norm%d' % order)
op.apply(**kwargs)

v = np.power(mr.v, 1/order)
v = np.power(n.data[0], 1/order)

return f.dtype(v)

Expand Down Expand Up @@ -129,15 +129,15 @@ def sumall(f):
p, eqns = f.guard() if f.is_SparseFunction else (f, [])

dtype = accumulator_mapper[f.dtype]
n = make_retval(f.grid, dtype)
s = dv.types.Symbol(name='sum', dtype=dtype)

with MPIReduction(f, dtype=dtype) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, p), dv.Eq(mr.n[0], s)],
name='sum')
op.apply(**kwargs)
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, p), dv.Eq(n[0], s)],
name='sum')
op.apply(**kwargs)

return f.dtype(mr.v)
return f.dtype(n.data[0])


@dv.switchconfig(log_level='ERROR')
Expand Down Expand Up @@ -184,15 +184,15 @@ def inner(f, g):
rhs, eqns = f.guard(f*g) if f.is_SparseFunction else (f*g, [])

dtype = accumulator_mapper[f.dtype]
n = make_retval(f.grid or g.grid, dtype)
s = dv.types.Symbol(name='sum', dtype=dtype)

with MPIReduction(f, g, dtype=dtype) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, rhs), dv.Eq(mr.n[0], s)],
name='inner')
op.apply(**kwargs)
op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, rhs), dv.Eq(n[0], s)],
name='inner')
op.apply(**kwargs)

return f.dtype(mr.v)
return f.dtype(n.data[0])


@dv.switchconfig(log_level='ERROR')
Expand All @@ -208,11 +208,14 @@ def mmin(f):
if isinstance(f, dv.Constant):
return f.data
elif isinstance(f, dv.types.dense.DiscreteFunction):
with MPIReduction(f, op=dv.mpi.MPI.MIN) as mr:
mr.n.data[0] = np.min(f.data_ro_domain).item()
return mr.v.item()
v = np.min(f.data_ro_domain)
if f.grid is None or not dv.configuration['mpi']:
return v.item()
else:
comm = f.grid.distributor.comm
return comm.allreduce(v, dv.mpi.MPI.MIN).item()
else:
raise ValueError("Expected Function, not `%s`" % type(f))
raise ValueError("Expected Function, got `%s`" % type(f))


@dv.switchconfig(log_level='ERROR')
Expand All @@ -228,8 +231,11 @@ def mmax(f):
if isinstance(f, dv.Constant):
return f.data
elif isinstance(f, dv.types.dense.DiscreteFunction):
with MPIReduction(f, op=dv.mpi.MPI.MAX) as mr:
mr.n.data[0] = np.max(f.data_ro_domain).item()
return mr.v.item()
v = np.max(f.data_ro_domain)
if f.grid is None or not dv.configuration['mpi']:
return v.item()
else:
comm = f.grid.distributor.comm
return comm.allreduce(v, dv.mpi.MPI.MAX).item()
else:
raise ValueError("Expected Function, not `%s`" % type(f))
raise ValueError("Expected Function, got `%s`" % type(f))
52 changes: 13 additions & 39 deletions devito/builtins/utils.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,26 @@
from functools import wraps

import numpy as np

import devito as dv
from devito.symbolics import uxreplace
from devito.tools import as_tuple

__all__ = ['MPIReduction', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args']
__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args']


class MPIReduction:
def make_retval(grid, dtype):
"""
A context manager to build MPI-aware reduction Operators.
Devito does not support passing values by reference. This function
creates a dummy Function of size 1 to store the return value of a builtin
applied to `f`.
"""

def __init__(self, *functions, op=dv.mpi.MPI.SUM, dtype=None):
grids = {f.grid for f in functions}
if len(grids) == 0:
self.grid = None
elif len(grids) == 1:
self.grid = grids.pop()
else:
raise ValueError("Multiple Grids found")
if dtype is not None:
self.dtype = dtype
else:
dtype = {f.dtype for f in functions}
if len(dtype) == 1:
self.dtype = np.result_type(dtype.pop(), np.float32).type
else:
raise ValueError("Illegal mixed data types")
self.v = None
self.op = op

def __enter__(self):
i = dv.Dimension(name='mri',)
self.n = dv.Function(name='n', shape=(1,), dimensions=(i,),
grid=self.grid, dtype=self.dtype, space='host')
self.n.data[:] = 0
return self

def __exit__(self, exc_type, exc_value, traceback):
if self.grid is None or not dv.configuration['mpi']:
assert self.n.data.size == 1
self.v = self.n.data[0]
else:
comm = self.grid.distributor.comm
self.v = comm.allreduce(np.asarray(self.n.data), self.op)[0]
if grid is None:
raise ValueError("Expected Grid, got None")

i = dv.Dimension(name='mri',)
n = dv.Function(name='n', shape=(1,), dimensions=(i,), grid=grid,
dtype=dtype, space='host')
n.data[:] = 0
return n


def nbl_to_padsize(nbl, ndim):
Expand Down
19 changes: 10 additions & 9 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,20 @@ def _normalize_gpu_fit(cls, oo, **kwargs):
return as_tuple(cls.GPU_FIT)

@classmethod
def _rcompile_wrapper(cls, **kwargs0):
options = kwargs0['options']
def _rcompile_wrapper(cls, **kwargs):
def wrapper(expressions, mode='default', **options):

def wrapper(expressions, mode='default', **kwargs1):
if mode == 'host':
kwargs = {**{
par_disabled = kwargs['options']['par-disabled']
target = {
'platform': 'cpu64',
'language': 'C' if options['par-disabled'] else 'openmp',
'compiler': 'custom',
}, **kwargs1}
'language': 'C' if par_disabled else 'openmp',
'compiler': 'custom'
}
else:
kwargs = {**kwargs0, **kwargs1}
return rcompile(expressions, kwargs)
target = None

return rcompile(expressions, kwargs, options, target=target)

return wrapper

Expand Down
83 changes: 75 additions & 8 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
from devito.exceptions import InvalidOperator
from devito.finite_differences.elementary import Max, Min
from devito.ir.support import (Any, Backward, Forward, IterationSpace, erange,
pull_dims)
pull_dims, null_ispace)
from devito.ir.equations import OpMin, OpMax
from devito.ir.clusters.analysis import analyze
from devito.ir.clusters.cluster import Cluster, ClusterGroup
from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
from devito.mpi.reduction_scheme import DistReduce
from devito.symbolics import (limits_mapper, retrieve_indexed, uxreplace,
xreplace_indices)
from devito.tools import (DefaultOrderedDict, Stamp, as_mapper, flatten,
is_integer, timed_pass, toposort)
is_integer, split, timed_pass, toposort)
from devito.types import Array, Eq, Symbol
from devito.types.dimension import BOTTOM, ModuloDimension

Expand Down Expand Up @@ -48,7 +49,7 @@ def clusterize(exprs, **kwargs):
clusters = normalize(clusters, **kwargs)

# Derive the necessary communications for distributed-memory parallelism
clusters = Communications().process(clusters)
clusters = communications(clusters)

return ClusterGroup(clusters)

Expand Down Expand Up @@ -365,19 +366,29 @@ def rule(size, e):
return processed


class Communications(Queue):

@timed_pass(name='communications')
def communications(clusters):
    """
    Augment a sequence of Clusters with the special Clusters that represent
    the data communications required by distributed parallelism.

    The input is first handed to `HaloComms().process`, and the resulting
    sequence is then handed to `reduction_comms`; the output of the latter
    is what gets returned.
    """
    return reduction_comms(HaloComms().process(clusters))


class HaloComms(Queue):

"""
Inject Clusters representing halo exchanges for distributed-memory parallelism.
"""

_q_guards_in_key = True
_q_properties_in_key = True

B = Symbol(name='⊥')

@timed_pass(name='communications')
def process(self, clusters):
return self._process_fatd(clusters, 1, seen=set())

Expand Down Expand Up @@ -432,6 +443,57 @@ def callback(self, clusters, prefix, seen=None):
return processed


def reduction_comms(clusters):
    """
    Inject Clusters representing global distributed reductions.

    The input Clusters are visited in order. An expression gives rise to a
    `DistReduce` when all of the following hold:

        * it carries an `operation` and its Cluster is not sparse;
        * its Cluster has a grid;
        * its left-hand side does not depend on all of the Cluster's
          iteration Dimensions, that is, some Dimensions are reduced;
        * at least one reduced Dimension shares something, through
          `_defines`, with the Cluster's `dist_dimensions`.

    A detected `DistReduce` is not emitted straight away. It is queued and
    emitted right before the first subsequent Cluster `c` such that
    `dr.ispace.is_subset(c.ispace)`, inside a Cluster obtained through
    `c.rebuild(exprs=...)`. Whatever is still queued once all Clusters have
    been visited is emitted at the very end, in a Cluster with `null_ispace`.

    Parameters
    ----------
    clusters : iterable of Cluster
        The Clusters to be visited.

    Returns
    -------
    list
        The input Clusters, in their original order, interleaved with the
        Clusters carrying the `Eq(dr.var, dr)` expressions.
    """

    def as_exprs(reductions):
        # One `var = DistReduce(var)` assignment per scheduled reduction
        return [Eq(dr.var, dr) for dr in reductions]

    scheduled = []
    pending = []
    for cluster in clusters:
        # Among the reductions detected so far, emit the ones whose
        # IterationSpace is compatible with that of `cluster`; the others
        # keep waiting
        ready, pending = split(
            pending, lambda dr: dr.ispace.is_subset(cluster.ispace)
        )
        if ready:
            exprs = as_exprs(ready)
            scheduled.append(cluster.rebuild(exprs=exprs))

        # Look for new global distributed reductions within `cluster`
        for expr in cluster.exprs:
            op = expr.operation
            if op is None or cluster.is_sparse:
                continue

            target = expr.lhs
            grid = cluster.grid
            if grid is None:
                continue

            def in_target(d):
                # Does the reduction target depend on `d`?
                return d in target.free_symbols

            # If no iteration Dimension gets dropped, then the
            # Inc/Max/Min/... is not really performing a reduction
            projected = cluster.ispace.project(in_target)
            if projected.itdims == cluster.ispace.itdims:
                continue

            # The Dimensions being reduced
            reduced = set(cluster.ispace.itdims) - set(projected.itdims)

            # Only the reduced Dimensions related to the distributed ones
            # call for a global distributed reduction
            distributed = {
                d for d in reduced if d._defines & cluster.dist_dimensions
            }
            if not distributed:
                continue

            # The IterationSpace within which the global distributed
            # reduction has to be performed
            scope = cluster.ispace.prefix(in_target)

            pending.append(DistReduce(target, op=op, grid=grid, ispace=scope))

        scheduled.append(cluster)

    # Reductions that could not be placed earlier go at the very end
    if pending:
        exprs = as_exprs(pending)
        scheduled.append(Cluster(exprs=exprs, ispace=null_ispace))

    return scheduled


def normalize(clusters, **kwargs):
options = kwargs['options']
sregistry = kwargs['sregistry']
Expand Down Expand Up @@ -562,7 +624,12 @@ def _normalize_reductions_dense(cluster, sregistry, mapper):
# because the Function might be padded, and reduction operations
# require, in general, the data values to be contiguous
name = sregistry.make_name()
a = mapper[rhs] = Array(name=name, dtype=e.dtype, dimensions=dims)
try:
grid = cluster.grid
except ValueError:
grid = None
a = mapper[rhs] = Array(name=name, dtype=e.dtype, dimensions=dims,
grid=grid)

processed.extend([Eq(a.indexify(), rhs),
e.func(lhs, a.indexify())])
Expand Down
Loading

0 comments on commit 10963f3

Please sign in to comment.