diff --git a/hymd/configure_runtime.py b/hymd/configure_runtime.py index 0221603f..c374c87e 100644 --- a/hymd/configure_runtime.py +++ b/hymd/configure_runtime.py @@ -29,6 +29,14 @@ def configure_runtime(args_in, comm): Namespace containing command line arguments. config : hymd.input_parser.Config Parsed configuration object. + prng : np.random.Generator + Random number generator for this rank. + topol : dict + Topology dictionary. + intracomm : mpi4py.Comm + MPI communicator to use for rank communication within a replica. + intercomm : mpi4py.Comm + MPI communicator to use for rank communication between replicas. """ ap = argparse.ArgumentParser() @@ -151,11 +159,11 @@ def configure_runtime(args_in, comm): args = ap.parse_args(args_in) # check if we have at least one rank per replica - if comm.Get_size() < len(args.replica_dirs): + if comm.Get_size() < len(args.replica_dirs) and comm.Get_rank() == 0: raise ValueError("You should have at least one MPI rank per replica.") # block destdir with replicas - if (len(args.replica_dirs) > 0) and args.destdir != ".": + if (len(args.replica_dirs) > 0) and args.destdir != "." and comm.Get_rank() == 0: raise ValueError("You should not specify a destination directory when using replicas.") if comm.Get_rank() == 0: @@ -177,11 +185,35 @@ def configure_runtime(args_in, comm): # Setup a PRNG for each rank prng = np.random.default_rng(seeds[comm.Get_rank()]) + # multi replica setup + n_replicas = len(args.replica_dirs) + if n_replicas > 1: + # split comm=COMM_WORLD in intracomm and intercomm + size = comm.Get_size() + rank = comm.Get_rank() + + n_intra = int(size / n_replicas) + if (n_replicas * n_intra != size): + err_str = "Inconsistent number of ranks per replica" + + if rank == 0: + raise AssertionError(err_str) + intracomm = comm.Split(int(rank / n_intra), rank) + intercomm = comm.Split(rank % n_intra, rank) + + # assign directory to each rank + os.chdir(args.replica_dirs[int(rank / n_intra)]) + + else: + intracomm = comm + intercomm = None + # Setup logger Logger.setup( default_level=logging.INFO, log_file=f"{args.destdir}/{args.logfile}", verbose=args.verbose, + comm=intracomm, ) # print header info @@ -241,7 +273,7 @@ def profile_atexit(): f"Unable to parse configuration file {args.config}" f"\n\ntoml parse traceback:" + repr(ve) ) - return args, config, prng, topol + return args, config, prng, topol, intracomm, intercomm def extant_file(x): diff --git a/hymd/logger.py b/hymd/logger.py index 6f2c1d78..822b757e 100644 --- a/hymd/logger.py +++ b/hymd/logger.py @@ -9,6 +9,11 @@ class MPIFilterRoot(logging.Filter): """Log output Filter wrapper class for the root MPI rank log""" + + comm = None + + def __init__(self, comm=MPI.COMM_WORLD): + self.comm = comm def filter(self, record): """Log event message filter @@ -20,9 +25,9 @@ def filter(self, record): """ if record.funcName == "": record.funcName = "main" - if MPI.COMM_WORLD.Get_rank() == 0: - record.rank = MPI.COMM_WORLD.Get_rank() - record.size = MPI.COMM_WORLD.Get_size() + if self.comm.Get_rank() == 0: + record.rank = self.comm.Get_rank() + record.size = self.comm.Get_size() return True else: return False @@ -31,6 +36,11 @@ def filter(self, record): class MPIFilterAll(logging.Filter): """Log output Filter wrapper class for the all-MPI-ranks log""" + comm = None + + def __init__(self, comm=MPI.COMM_WORLD): + self.comm = comm + def filter(self, record): """Log event message filter @@ -41,8 +51,8 @@ def filter(self, record): """ if record.funcName == "": record.funcName = "main" - record.rank = MPI.COMM_WORLD.Get_rank() - record.size = MPI.COMM_WORLD.Get_size() + record.rank = self.comm.Get_rank() + record.size = self.comm.Get_size() return True @@ -96,7 +106,7 @@ class Logger: all_ranks = logging.getLogger("HyMD.all_ranks") @classmethod - def setup(cls, default_level=logging.INFO, log_file=None, verbose=False): + def setup(cls, default_level=logging.INFO, log_file=None, verbose=False, comm=MPI.COMM_WORLD): """Sets up the logger object. If a :code:`log_file` path is provided, log event messages are output @@ -127,8 +137,10 @@ def setup(cls, default_level=logging.INFO, log_file=None, verbose=False): cls.rank0.setLevel(level) cls.all_ranks.setLevel(level) - cls.rank0.addFilter(MPIFilterRoot()) - cls.all_ranks.addFilter(MPIFilterAll()) + root_filter = MPIFilterRoot(comm) + all_filter = MPIFilterAll(comm) + cls.rank0.addFilter(root_filter) + cls.all_ranks.addFilter(all_filter) if (not log_file) and (not log_to_stdout): return diff --git a/hymd/main.py b/hymd/main.py index 6efab412..b0fbcb06 100644 --- a/hymd/main.py +++ b/hymd/main.py @@ -42,13 +42,15 @@ def main(): the molecular dynamics loop. """ comm = MPI.COMM_WORLD - rank = comm.Get_rank() - size = comm.Get_size() - if rank == 0: + if comm.Get_rank() == 0: start_time = datetime.datetime.now() - args, config, prng, topol = configure_runtime(sys.argv[1:], comm) + args, config, prng, topol, intracomm, intercomm = configure_runtime(sys.argv[1:], comm) + + # Get rank and size + rank = intracomm.Get_rank() + size = intracomm.Get_size() if args.double_precision: dtype = np.float64 @@ -65,30 +67,6 @@ def main(): from .force import compute_angle_forces__fortran as compute_angle_forces from .force import compute_dihedral_forces__fortran as compute_dihedral_forces - # multi replica setup - n_replicas = len(args.replica_dirs) - if n_replicas > 1: - # split COMM_WORLD in intracomm and intercomm - n_intra = int(size / n_replicas) - if (n_replicas * n_intra != size): - err_str = "Inconsistent number of ranks per replica" - Logger.rank0.log(logging.ERROR, err_str) - - if rank == 0: - raise AssertionError(err_str) - intracomm = comm.Split(int(rank / n_intra), rank) - intercomm = comm.Split(rank % n_intra, rank) - - # assign directory to each rank - os.chdir(args.replica_dirs[int(rank / n_intra)]) - - # update rank and size - rank = intracomm.Get_rank() - size = intracomm.Get_size() - else: - intracomm = comm - intercomm = None - # read input .hdf5 driver = "mpio" if not args.disable_mpio else None _kwargs = {"driver": driver, "comm": intracomm} if not args.disable_mpio else {} diff --git a/test/test_configure_runtime.py b/test/test_configure_runtime.py index 65cd3e87..963f2598 100644 --- a/test/test_configure_runtime.py +++ b/test/test_configure_runtime.py @@ -23,11 +23,10 @@ def test_extant_file(tmp_path, caplog): f.write("test") ret_path = extant_file(file_path) - assert ret_path == file_path + assert ret_path == os.path.abspath(file_path) caplog.clear() - def test_configure_runtime(h5py_molecules_file, config_toml, change_tmp_dir, caplog): caplog.set_level(logging.ERROR) @@ -36,63 +35,71 @@ def test_configure_runtime(h5py_molecules_file, config_toml, basearg = [config_toml_file, h5py_molecules_file] - parsed, config, prng, topol = configure_runtime(basearg, comm) + parsed, config, prng, topol, intracomm, intercomm = configure_runtime(basearg, comm) assert isinstance(parsed, Namespace) assert isinstance(config, Config) assert isinstance(prng, np.random.Generator) + assert isinstance(intracomm, MPI.Comm) + assert isinstance(intercomm, (MPI.Comm, type(None))) assert parsed.destdir == "." + assert len(parsed.replica_dirs) == 0 assert parsed.logfile == "sim.log" assert parsed.plumed_outfile == "plumed.out" - parsed, _, _, _ = configure_runtime(["-v", "2"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["-v", "2"]+basearg, comm) assert parsed.verbose == 2 - parsed, _, _, _ = configure_runtime(["--disable-field"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-field"]+basearg, comm) assert parsed.disable_field - parsed, _, _, _ = configure_runtime(["--disable-bonds"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-bonds"]+basearg, comm) assert parsed.disable_bonds - parsed, _, _, _ = configure_runtime(["--disable-angle-bonds"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-angle-bonds"]+basearg, comm) assert parsed.disable_angle_bonds - parsed, _, _, _ = configure_runtime(["--disable-dihedrals"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-dihedrals"]+basearg, comm) assert parsed.disable_dihedrals - parsed, _, _, _ = configure_runtime(["--disable-dipole"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-dipole"]+basearg, comm) assert parsed.disable_dipole - parsed, _, _, _ = configure_runtime(["--double-precision"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--double-precision"]+basearg, comm) assert parsed.double_precision - parsed, _, _, _ = configure_runtime(["--double-output"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--double-output"]+basearg, comm) assert parsed.double_output - parsed, _, _, _ = configure_runtime(["--dump-per-particle"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--dump-per-particle"]+basearg, comm) assert parsed.dump_per_particle - parsed, _, _, _ = configure_runtime(["--force-output"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--force-output"]+basearg, comm) assert parsed.force_output - parsed, _, _, _ = configure_runtime(["--velocity-output"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--velocity-output"]+basearg, comm) assert parsed.velocity_output - parsed, _, _, _ = configure_runtime(["--disable-mpio"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-mpio"]+basearg, comm) assert parsed.disable_mpio - parsed, _, _, _ = configure_runtime(["--destdir", "testdir"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--destdir", "testdir"]+basearg, comm) assert parsed.destdir == "testdir" - parsed, _, _, _ = configure_runtime(["--seed", "54321"]+basearg, comm) + # parsed, _, _, _, _, _ = configure_runtime(basearg+["--replica-dirs", "a", "b"], comm) + # assert len(parsed.replica_dirs) == 2 + # assert parsed.replica_dirs[0] == "a" + # assert parsed.replica_dirs[1] == "b" + + parsed, _, _, _, _, _ = configure_runtime(["--seed", "54321"]+basearg, comm) assert parsed.seed == 54321 - parsed, _, _, _ = configure_runtime(["--logfile", "test.log"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--logfile", "test.log"]+basearg, comm) assert parsed.logfile == "test.log" - parsed, _, _, _ = configure_runtime(["--plumed", "test.log"]+basearg, comm) - assert parsed.plumed == "test.log" + parsed, _, _, _, _, _ = configure_runtime(["--plumed", "test.log"]+basearg, comm) + assert parsed.plumed == os.path.abspath("test.log") - parsed, _, _, _ = configure_runtime( + parsed, _, _, _, _, _ = configure_runtime( ["--plumed-outfile", "test.plumed.out"]+basearg, comm ) diff --git a/test/test_plumed.py b/test/test_plumed.py index be7cfd35..0de47312 100644 --- a/test/test_plumed.py +++ b/test/test_plumed.py @@ -122,7 +122,7 @@ def test_plumed_bias_obj(molecules_with_solvent, change_test_dir, tmp_path, @pytest.mark.mpi() @pytest.mark.skip(reason="Currently fails in CI due to environment variable") -def test_fail_plumed_bias_obj(monkeypatch): +def test_fail_plumed_bias_obj(monkeypatch, caplog): pytest.importorskip("plumed") comm = MPI.COMM_WORLD @@ -168,7 +168,7 @@ def test_fail_plumed_bias_obj(monkeypatch): assert all([(s in message) for s in cmp_strings]) -def test_unavailable_plumed(hide_available_plumed): +def test_unavailable_plumed(hide_available_plumed, caplog): import hymd.plumed comm = MPI.COMM_WORLD