Skip to content

Commit

Permalink
Merge pull request #24 from tskisner/mmap
Browse files Browse the repository at this point in the history
Use Python SharedMemory as the backend
  • Loading branch information
tskisner authored Mar 17, 2024
2 parents 2ca8e38 + 95e19fa commit 02cd641
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 30 deletions.
43 changes: 19 additions & 24 deletions pshmem/shmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@
##

import sys
from multiprocessing import shared_memory

import numpy as np
import sysv_ipc

from .utils import mpi_data_type, random_shm_key
from .utils import (
mpi_data_type,
random_shm_key,
remove_shm_from_resource_tracker,
)

# Monkey patch resource_tracker. Remove once upstream CPython
# changes are merged.
remove_shm_from_resource_tracker()


class MPIShared(object):
Expand Down Expand Up @@ -149,7 +157,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
if self._rank == 0:
# Get a random 64bit integer between the supported range of keys
self._shm_index = random_shm_key()
# Name, just used for printing
# Name, used as global tag.
self._name = f"MPIShared_{self._shm_index}"
if self._comm is not None:
self._shm_index = self._comm.bcast(self._shm_index, root=0)
Expand Down Expand Up @@ -177,10 +185,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# First rank on each node creates the buffer
if self._noderank == 0:
try:
self._shmem = sysv_ipc.SharedMemory(
self._shm_index,
flags=sysv_ipc.IPC_CREX,
size=int(nbytes),
self._shmem = shared_memory.SharedMemory(
name=self._name, create=True, size=int(nbytes),
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
Expand All @@ -199,8 +205,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# Other ranks on the node attach
if self._noderank != 0:
try:
self._shmem = sysv_ipc.SharedMemory(
self._shm_index, flags=0, size=0
self._shmem = shared_memory.SharedMemory(
name=self._name, create=False, size=int(nbytes)
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
Expand All @@ -216,7 +222,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
self._flat = np.ndarray(
self._n,
dtype=self._dtype,
buffer=self._shmem,
buffer=self._shmem.buf,
)

# Initialize to zero.
Expand All @@ -230,19 +236,6 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
if self._nodecomm is not None:
self._nodecomm.barrier()

# Now the rank zero process will call remove() to mark the shared
# memory segment for removal. However, this will not actually
# be removed until all processes detach.
if self._noderank == 0:
try:
self._shmem.remove()
except sysv_ipc.ExistentialError:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed to remove shared memory"
msg += ": {}".format(e)
print(msg, flush=True)
raise

def __del__(self):
self.close()

Expand Down Expand Up @@ -370,7 +363,9 @@ def close(self):
del self._flat
if hasattr(self, "_shmem"):
if self._shmem is not None:
self._shmem.detach()
self._shmem.close()
if self._noderank == 0:
self._shmem.unlink()
del self._shmem
self._shmem = None
self._flat = None
Expand Down
7 changes: 6 additions & 1 deletion pshmem/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,14 @@ def test_zero(self):
# dims = (200, 1000000)
# dt = np.float64
# shm = MPIShared(dims, dt, self.comm)
# if self.comm is None or self.comm.rank == 0:
# temp = np.ones(dims, dtype=dt)
# else:
# temp = None
# shm.set(temp, fromrank=0)
# del temp
# import time
# time.sleep(60)
# shm.close()
# del shm
# return

Expand Down
32 changes: 28 additions & 4 deletions pshmem/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
##

import random
import sys
# Import for monkey patching resource tracker
from multiprocessing import resource_tracker

import numpy as np
import sysv_ipc


def mpi_data_type(comm, dt):
Expand Down Expand Up @@ -48,7 +50,7 @@ def mpi_data_type(comm, dt):


def random_shm_key():
"""Get a random 64bit integer in the range supported by shmget()
"""Get a random positive integer for using in shared memory naming.
The python random library is used, and seeded with the default source
(either system time or os.urandom).
Expand All @@ -57,8 +59,30 @@ def random_shm_key():
(int): The random integer.
"""
min_val = sysv_ipc.KEY_MIN
max_val = sysv_ipc.KEY_MAX
min_val = 0
max_val = sys.maxsize
# Seed with default source of randomness
random.seed(a=None)
return random.randint(min_val, max_val)


def remove_shm_from_resource_tracker():
    """Monkey-patch multiprocessing.resource_tracker to ignore SharedMemory.

    After this call, the module-level ``resource_tracker.register`` and
    ``resource_tracker.unregister`` functions do nothing for the resource
    type "shared_memory" and forward every other resource type to the
    tracker instance.  The "shared_memory" cleanup function is also removed
    from the table of cleanup functions, if present.

    More details at: https://bugs.python.org/issue38119

    Returns:
        None

    """

    def fix_register(name, rtype):
        # Shared memory is deliberately left untracked.
        if rtype == "shared_memory":
            return
        # The method is looked up on the tracker instance, so it is already
        # bound:  passing an extra "self" here would raise a NameError for
        # every other resource type (e.g. semaphores).
        return resource_tracker._resource_tracker.register(name, rtype)

    resource_tracker.register = fix_register

    def fix_unregister(name, rtype):
        # Mirror of fix_register for the unregister path.
        if rtype == "shared_memory":
            return
        return resource_tracker._resource_tracker.unregister(name, rtype)

    resource_tracker.unregister = fix_unregister

    # Drop the cleanup handler.  Using pop() keeps repeated calls harmless.
    resource_tracker._CLEANUP_FUNCS.pop("shared_memory", None)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def readme():
scripts=None,
license="BSD",
python_requires=">=3.8.0",
install_requires=["numpy", "sysv_ipc"],
install_requires=["numpy"],
extras_require={"mpi": ["mpi4py>=3.0"]},
cmdclass=versioneer.get_cmdclass(),
classifiers=[
Expand Down

0 comments on commit 02cd641

Please sign in to comment.