Skip to content

Commit

Permalink
Merge pull request Cascella-Group-UiO#77 from hmcezar/multi-replica
Browse files Browse the repository at this point in the history
Multi replica simulations
  • Loading branch information
hmcezar authored Oct 5, 2023
2 parents 848a343 + 0e2145c commit 0be6e2e
Show file tree
Hide file tree
Showing 9 changed files with 302 additions and 139 deletions.
52 changes: 46 additions & 6 deletions hymd/configure_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def configure_runtime(args_in, comm):
Namespace containing command line arguments.
config : hymd.input_parser.Config
Parsed configuration object.
prng : np.random.Generator
Random number generator for this rank.
topol : dict
Topology dictionary.
intracomm : mpi4py.Comm
MPI communicator to use for rank communication within a replica.
intercomm : mpi4py.Comm
MPI communicator to use for rank communication between replicas.
"""
ap = argparse.ArgumentParser()

Expand Down Expand Up @@ -115,6 +123,9 @@ def configure_runtime(args_in, comm):
ap.add_argument(
"--destdir", default=".", help="Write output to specified directory"
)
ap.add_argument(
"--replica-dirs", type=str, nargs="+", default=[], help="Directories to store results for each replica"
)
ap.add_argument(
"--seed",
default=None,
Expand Down Expand Up @@ -143,13 +154,13 @@ def configure_runtime(args_in, comm):
type=extant_file,
help="Gmx-like topology file in toml format"
)
ap.add_argument("config", help="Config .py or .toml input configuration script")
ap.add_argument("config", type=extant_file, help="Config .py or .toml input configuration script")
ap.add_argument("input", help="input.hdf5")
args = ap.parse_args(args_in)

if comm.Get_rank() == 0:
os.makedirs(args.destdir, exist_ok=True)
comm.barrier()
# check if we have at least one rank per replica
if comm.Get_size() < len(args.replica_dirs) and comm.Get_rank() == 0:
raise ValueError("You should have at least one MPI rank per replica.")

# Safely define seeds
seeds = None
Expand All @@ -165,11 +176,40 @@ def configure_runtime(args_in, comm):
# Setup a PRNG for each rank
prng = np.random.default_rng(seeds[comm.Get_rank()])

# multi replica setup
n_replicas = len(args.replica_dirs)
if n_replicas > 1:
# split comm=COMM_WORLD in intracomm and intercomm
size = comm.Get_size()
rank = comm.Get_rank()

n_intra = int(size / n_replicas)
if (n_replicas * n_intra != size):
err_str = "Inconsistent number of ranks per replica"

if rank == 0:
raise AssertionError(err_str)
intracomm = comm.Split(int(rank / n_intra), rank)
intercomm = comm.Split(rank % n_intra, rank)

# assign directory to each rank
os.chdir(args.replica_dirs[int(rank / n_intra)])

else:
intracomm = comm
intercomm = None

if intracomm.Get_rank() == 0:
os.makedirs(args.destdir, exist_ok=True)

intracomm.barrier()

# Setup logger
Logger.setup(
default_level=logging.INFO,
log_file=f"{args.destdir}/{args.logfile}",
verbose=args.verbose,
comm=intracomm,
)

# print header info
Expand Down Expand Up @@ -229,7 +269,7 @@ def profile_atexit():
f"Unable to parse configuration file {args.config}"
f"\n\ntoml parse traceback:" + repr(ve)
)
return args, config, prng, topol
return args, config, prng, topol, intracomm, intercomm


def extant_file(x):
Expand All @@ -241,5 +281,5 @@ def extant_file(x):
# Argparse uses the ArgumentTypeError to give a rejection message like:
# error: argument input: x does not exist
raise argparse.ArgumentTypeError("{0} does not exist".format(x))
return x
return os.path.abspath(x)

14 changes: 1 addition & 13 deletions hymd/file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,8 @@ def store_static(
creator_group = h5md_group.create_group("creator")
creator_group.attrs["name"] = np.string_("Hylleraas MD")

# Get HyMD version. Also grab the user email from git config if we
# can find it.
# Get HyMD version
creator_group.attrs["version"] = np.string_(get_version())
try:
import git

try:
reader = repo.config_reader()
user_email = reader.get_value("user", "email")
author_group.attrs["email"] = np.string_(user_email)
except:
pass
except:
pass

h5md.particles_group = h5md.file.create_group("/particles")
h5md.all_particles = h5md.particles_group.create_group("all")
Expand Down
51 changes: 43 additions & 8 deletions hymd/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@

class MPIFilterRoot(logging.Filter):
"""Log output Filter wrapper class for the root MPI rank log"""

comm = None

def __init__(self, comm=MPI.COMM_WORLD):
self.comm = comm

def filter(self, record):
"""Log event message filter
Expand All @@ -20,9 +25,9 @@ def filter(self, record):
"""
if record.funcName == "<module>":
record.funcName = "main"
if MPI.COMM_WORLD.Get_rank() == 0:
record.rank = MPI.COMM_WORLD.Get_rank()
record.size = MPI.COMM_WORLD.Get_size()
if self.comm.Get_rank() == 0:
record.rank = self.comm.Get_rank()
record.size = self.comm.Get_size()
return True
else:
return False
Expand All @@ -31,6 +36,11 @@ def filter(self, record):
class MPIFilterAll(logging.Filter):
"""Log output Filter wrapper class for the all-MPI-ranks log"""

comm = None

def __init__(self, comm=MPI.COMM_WORLD):
self.comm = comm

def filter(self, record):
"""Log event message filter
Expand All @@ -41,11 +51,33 @@ def filter(self, record):
"""
if record.funcName == "<module>":
record.funcName = "main"
record.rank = MPI.COMM_WORLD.Get_rank()
record.size = MPI.COMM_WORLD.Get_size()
record.rank = self.comm.Get_rank()
record.size = self.comm.Get_size()
return True


class MPIFilterStdout(logging.Filter):
    """Log output Filter that lets only world rank 0 write to STDOUT.

    Unlike the communicator-aware filters, this one always consults
    ``MPI.COMM_WORLD``, and it labels every record it sees with a
    ``"replica 0 "`` prefix on the function name.
    """

    def filter(self, record):
        """Label the log record and decide whether it may be emitted.

        The record's ``funcName`` is rewritten on every rank (module-level
        calls are reported as ``main``) before the rank check is made.

        Parameters
        ----------
        record : logging.LogRecord
            LogRecord object corresponding to the log event.

        Returns
        -------
        bool
            ``True`` on rank 0 of ``MPI.COMM_WORLD`` (after attaching
            ``rank`` and ``size`` to the record), ``False`` on all other
            ranks.
        """
        caller = "main" if record.funcName == "<module>" else record.funcName
        record.funcName = "replica 0 " + caller

        world = MPI.COMM_WORLD
        world_rank = world.Get_rank()
        if world_rank != 0:
            # Non-root ranks never print to STDOUT.
            return False

        record.rank = world_rank
        record.size = world.Get_size()
        return True


class Logger:
"""Log output handler class
Expand Down Expand Up @@ -96,7 +128,7 @@ class Logger:
all_ranks = logging.getLogger("HyMD.all_ranks")

@classmethod
def setup(cls, default_level=logging.INFO, log_file=None, verbose=False):
def setup(cls, default_level=logging.INFO, log_file=None, verbose=False, comm=MPI.COMM_WORLD):
"""Sets up the logger object.
If a :code:`log_file` path is provided, log event messages are output
Expand Down Expand Up @@ -127,8 +159,10 @@ def setup(cls, default_level=logging.INFO, log_file=None, verbose=False):
cls.rank0.setLevel(level)
cls.all_ranks.setLevel(level)

cls.rank0.addFilter(MPIFilterRoot())
cls.all_ranks.addFilter(MPIFilterAll())
root_filter = MPIFilterRoot(comm)
all_filter = MPIFilterAll(comm)
cls.rank0.addFilter(root_filter)
cls.all_ranks.addFilter(all_filter)

if (not log_file) and (not log_to_stdout):
return
Expand All @@ -145,6 +179,7 @@ def setup(cls, default_level=logging.INFO, log_file=None, verbose=False):
cls.stdout_handler.setLevel(level)
cls.stdout_handler.setStream(sys.stdout)
cls.stdout_handler.setFormatter(cls.formatter)
cls.stdout_handler.addFilter(MPIFilterStdout())

cls.rank0.addHandler(cls.stdout_handler)
cls.all_ranks.addHandler(cls.stdout_handler)
Expand Down
Loading

0 comments on commit 0be6e2e

Please sign in to comment.