Skip to content

Commit

Permalink
mpi: fix halo exchange for non-mpi devito within mpi code
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Nov 6, 2023
1 parent 0f8f88b commit f2dfd6b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 6 deletions.
2 changes: 1 addition & 1 deletion devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def __init_finalize__(self, **kwargs):
self.ldflags = filter_ordered(self.ldflags + extrald)

def __lookup_cmds__(self):
self._base.__lookup_cmds__()
self._base.__lookup_cmds__(self)
self.CC = environ.get('CC', self.CC)
self.CXX = environ.get('CXX', self.CXX)
self.MPICC = environ.get('MPICC', self.MPICC)
Expand Down
8 changes: 4 additions & 4 deletions devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class Distributor(AbstractDistributor):
"""

def __init__(self, shape, dimensions, input_comm=None, topology=None):
super(Distributor, self).__init__(shape, dimensions)
super().__init__(shape, dimensions)

if configuration['mpi']:
# First time we enter here, we make sure MPI is initialized
Expand Down Expand Up @@ -426,7 +426,7 @@ class SparseDistributor(AbstractDistributor):
"""

def __init__(self, npoint, dimension, distributor):
super(SparseDistributor, self).__init__(npoint, dimension)
super().__init__(npoint, dimension)
self._distributor = distributor

# The dimension decomposition
Expand Down Expand Up @@ -523,7 +523,7 @@ def __init__(self, neighborhood):
self._entries = [i for i in neighborhood if isinstance(i, tuple)]

fields = [(''.join(j.name[0] for j in i), c_int) for i in self.entries]
super(MPINeighborhood, self).__init__('nb', 'neighborhood', fields)
super().__init__('nb', 'neighborhood', fields)

@property
def entries(self):
Expand Down Expand Up @@ -552,7 +552,7 @@ def _C_typedecl(self):
for i, j in groups])

def _arg_defaults(self):
values = super(MPINeighborhood, self)._arg_defaults()
values = super()._arg_defaults()
for name, i in zip(self.fields, self.entries):
setattr(values[self.name]._obj, name, self.neighborhood[i])
return values
Expand Down
3 changes: 2 additions & 1 deletion devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,8 @@ def _C_get_field(self, region, dim, side=None):

def _halo_exchange(self):
"""Perform the halo exchange with the neighboring processes."""
if not MPI.Is_initialized() or MPI.COMM_WORLD.size == 1:
if not MPI.Is_initialized() or MPI.COMM_WORLD.size == 1 or \
not configuration['mpi']:
# Nothing to do
return
if MPI.COMM_WORLD.size > 1 and self._distributor is None:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,21 @@ def test_local_indices(self, shape, expected):
assert all(i == slice(*j)
for i, j in zip(f.local_indices, expected[grid.distributor.myrank]))

@pytest.mark.parallel(mode=4)
@pytest.mark.parametrize('shape', [(1,), (2, 3), (4, 5, 6)])
def test_mpi4py_nodevmpi(self, shape):

with switchconfig(mpi=False):
# Mimic external mpi init
MPI.Init()
# Check that internal Function work correctly
grid = Grid(shape=shape)
f = Function(name="f", grid=grid, space_order=1)
assert f.data.shape == shape
assert f.data_with_halo.shape == tuple(s+2 for s in shape)
assert f.data._local.shape == shape
MPI.Finalize()


class TestSparseFunction(object):

Expand Down

0 comments on commit f2dfd6b

Please sign in to comment.