From 1ebd479f3c3248ad6431667924f5dafa60497006 Mon Sep 17 00:00:00 2001 From: Henrique Musseli Cezar Date: Thu, 28 Sep 2023 16:20:54 +0200 Subject: [PATCH] Initial attempt of multi replica --- hymd/configure_runtime.py | 16 ++++- hymd/file_io.py | 14 +---- hymd/main.py | 121 +++++++++++++++++++++++--------------- hymd/plumed.py | 33 +++++++---- test/test_plumed.py | 8 +-- 5 files changed, 116 insertions(+), 76 deletions(-) diff --git a/hymd/configure_runtime.py b/hymd/configure_runtime.py index e1a5a67f..0221603f 100644 --- a/hymd/configure_runtime.py +++ b/hymd/configure_runtime.py @@ -115,6 +115,9 @@ def configure_runtime(args_in, comm): ap.add_argument( "--destdir", default=".", help="Write output to specified directory" ) + ap.add_argument( + "--replica-dirs", type=str, nargs="+", default=[], help="Directories to store results for each replica" + ) ap.add_argument( "--seed", default=None, @@ -143,12 +146,21 @@ def configure_runtime(args_in, comm): type=extant_file, help="Gmx-like topology file in toml format" ) - ap.add_argument("config", help="Config .py or .toml input configuration script") + ap.add_argument("config", type=extant_file, help="Config .py or .toml input configuration script") ap.add_argument("input", help="input.hdf5") args = ap.parse_args(args_in) + # check if we have at least one rank per replica + if comm.Get_size() < len(args.replica_dirs): + raise ValueError("You should have at least one MPI rank per replica.") + + # block destdir with replicas + if (len(args.replica_dirs) > 0) and args.destdir != ".": + raise ValueError("You should not specify a destination directory when using replicas.") + if comm.Get_rank() == 0: os.makedirs(args.destdir, exist_ok=True) + comm.barrier() # Safely define seeds @@ -241,5 +253,5 @@ def extant_file(x): # Argparse uses the ArgumentTypeError to give a rejection message like: # error: argument input: x does not exist raise argparse.ArgumentTypeError("{0} does not exist".format(x)) - return x + return 
os.path.abspath(x) diff --git a/hymd/file_io.py b/hymd/file_io.py index 0a918617..f8c81d04 100644 --- a/hymd/file_io.py +++ b/hymd/file_io.py @@ -197,20 +197,8 @@ def store_static( creator_group = h5md_group.create_group("creator") creator_group.attrs["name"] = np.string_("Hylleraas MD") - # Get HyMD version. Also grab the user email from git config if we - # can find it. + # Get HyMD version creator_group.attrs["version"] = np.string_(get_version()) - try: - import git - - try: - reader = repo.config_reader() - user_email = reader.get_value("user", "email") - author_group.attrs["email"] = np.string_(user_email) - except: - pass - except: - pass h5md.particles_group = h5md.file.create_group("/particles") h5md.all_particles = h5md.particles_group.create_group("all") diff --git a/hymd/main.py b/hymd/main.py index c05952a2..be3ecc67 100644 --- a/hymd/main.py +++ b/hymd/main.py @@ -4,6 +4,7 @@ import h5py import logging import sys +import os from mpi4py import MPI import numpy as np import pmesh.pm as pmesh @@ -64,9 +65,33 @@ def main(): from .force import compute_angle_forces__fortran as compute_angle_forces from .force import compute_dihedral_forces__fortran as compute_dihedral_forces + # multi replica setup + n_replicas = len(args.replica_dirs) + if n_replicas > 1: + # split COMM_WORLD in intracomm and intercomm + n_intra = int(size / n_replicas) + if (n_replicas * n_intra != size): + err_str = "Inconsistent number of ranks per replica" + Logger.rank0.log(logging.ERROR, err_str) + + if rank == 0: + raise AssertionError(err_str) + intracomm = comm.Split(int(rank / n_intra), rank) + intercomm = comm.Split(rank % n_intra, rank) + + # assign directory to each replica; index with the world rank, before rank is rebound to the intra-replica rank + os.chdir(args.replica_dirs[int(rank / n_intra)]) + + rank = intracomm.Get_rank() + size = intracomm.Get_size() + else: + intracomm = comm + intercomm = None + + # read input .hdf5 driver = "mpio" if not args.disable_mpio else None - _kwargs = {"driver": driver, "comm": comm} if not args.disable_mpio else {} 
+ _kwargs = {"driver": driver, "comm": intracomm} if not args.disable_mpio else {} with h5py.File(args.input, "r", **_kwargs) as in_file: rank_range, molecules_flag = distribute_input( in_file, @@ -74,7 +99,7 @@ def main(): size, config.n_particles, config.max_molecule_size if config.max_molecule_size else 201, - comm=comm, + comm=intracomm, ) indices = in_file["indices"][rank_range] positions = in_file["coordinates"][-1, rank_range, :] @@ -121,7 +146,7 @@ def main(): input_box = np.array([None, None, None]) # finishes config setup and checks - config = check_config(config, indices, names, types, charges, input_box, comm=comm) + config = check_config(config, indices, names, types, charges, input_box, comm=intracomm) # import barostat if necessary if config.barostat_type == "berendsen": @@ -164,13 +189,13 @@ def conv_fun(comm, diffmesh): dielectric_flag = False if config.start_temperature: - velocities = generate_initial_velocities(velocities, config, prng, comm=comm) + velocities = generate_initial_velocities(velocities, config, prng, comm=intracomm) elif config.cancel_com_momentum: - velocities = cancel_com_momentum(velocities, config, comm=comm) + velocities = cancel_com_momentum(velocities, config, comm=intracomm) # set all PRNG to the root's PRNG # this is done to avoid communication between ranks in the thermostat - prng = comm.bcast(prng, root=0) + prng = intracomm.bcast(prng, root=0) bond_forces = np.zeros_like(positions) angle_forces = np.zeros_like(positions) @@ -208,7 +233,7 @@ def conv_fun(comm, diffmesh): hamiltonian = get_hamiltonian(config) - pm_objs = initialize_pm(pmesh, config, comm) + pm_objs = initialize_pm(pmesh, config, intracomm) pm, field_list, elec_common_list, coulomb_list = pm_objs ( phi, @@ -288,7 +313,7 @@ def conv_fun(comm, diffmesh): bonds=bonds if molecules_flag else None, topol=False if topol is None else True, verbose=args.verbose, - comm=comm, + comm=intracomm, ) if charges_flag: @@ -360,13 +385,13 @@ def conv_fun(comm, diffmesh): 
v_ext, config, layouts, - comm=comm, + comm=intracomm, ) compute_field_force( layouts, positions, force_on_grid, field_forces, types, config.n_types ) else: - kinetic_energy = comm.allreduce(0.5 * config.mass * np.sum(velocities**2)) + kinetic_energy = intracomm.allreduce(0.5 * config.mass * np.sum(velocities**2)) if charges_flag: layout_q = pm.decompose(positions) @@ -398,11 +423,11 @@ def conv_fun(comm, diffmesh): pm, positions, config, - comm=comm, + comm=intracomm, ) field_q_energy = compute_field_energy_q_GPE( - config, phi_eps, field_q_energy, elec_dot, comm=comm + config, phi_eps, field_q_energy, elec_dot, comm=intracomm ) if config.coulombtype == "PIC_Spectral": @@ -460,7 +485,7 @@ def conv_fun(comm, diffmesh): raise NotImplementedError(err_str) # Check if we have a protein - protein_flag = comm.allreduce(bonds_4_type.any() == 1) + protein_flag = intracomm.allreduce(bonds_4_type.any() == 1) if protein_flag and not args.disable_dipole: # Each rank will have different n_tors, we do not need to # domain decompose dipoles @@ -511,7 +536,7 @@ def conv_fun(comm, diffmesh): bonds_2_equilibrium, bonds_2_stength, ) - bond_energy = comm.allreduce(bond_energy_, MPI.SUM) + bond_energy = intracomm.allreduce(bond_energy_, MPI.SUM) else: bonds_2_atom1, bonds_2_atom2 = [], [] if not args.disable_angle_bonds: @@ -525,7 +550,7 @@ def conv_fun(comm, diffmesh): bonds_3_equilibrium, bonds_3_stength, ) - angle_energy = comm.allreduce(angle_energy_, MPI.SUM) + angle_energy = intracomm.allreduce(angle_energy_, MPI.SUM) # angle_pr = comm.allreduce(angle_pr_, MPI.SUM) if not args.disable_dihedrals: @@ -544,7 +569,7 @@ def conv_fun(comm, diffmesh): bonds_4_last, dipole_flag, ) - dihedral_energy = comm.allreduce(dihedral_energy_, MPI.SUM) + dihedral_energy = intracomm.allreduce(dihedral_energy_, MPI.SUM) if protein_flag and not args.disable_dipole: dipole_positions = np.reshape(dipole_positions, (4 * n_tors, 3)) @@ -600,6 +625,7 @@ def conv_fun(comm, diffmesh): config, 
double_out=args.double_output, disable_mpio=args.disable_mpio, + comm=intracomm, ) store_static( @@ -617,7 +643,7 @@ def conv_fun(comm, diffmesh): charges=charges if charges_flag else False, dielectrics=dielectric_sorted if dielectric_flag else False, plumed_out=True if args.plumed else False, - comm=comm, + comm=intracomm, ) if config.n_print > 0: @@ -639,10 +665,10 @@ def conv_fun(comm, diffmesh): v_ext, config, layouts, - comm=comm, + comm=intracomm, ) else: - kinetic_energy = comm.allreduce(0.5 * config.mass * np.sum(velocities**2)) + kinetic_energy = intracomm.allreduce(0.5 * config.mass * np.sum(velocities**2)) temperature = ( (2 / 3) @@ -664,7 +690,7 @@ def conv_fun(comm, diffmesh): positions, bond_pr_, angle_pr_, - comm=comm, + comm=intracomm, ) # if rank ==0 : print(pressure[9:12]) @@ -688,7 +714,8 @@ def conv_fun(comm, diffmesh): config, args.plumed, args.plumed_outfile, - comm=comm, + intracomm=intracomm, + intercomm=intercomm, verbose=args.verbose ) @@ -747,7 +774,7 @@ def conv_fun(comm, diffmesh): charge_out=charges_flag, plumed_out=True if args.plumed else False, dump_per_particle=args.dump_per_particle, - comm=comm, + comm=intracomm, ) if rank == 0: @@ -900,7 +927,7 @@ def conv_fun(comm, diffmesh): angle_pr_, step, prng, - comm=comm, + comm=intracomm, ) elif config.barostat.lower() == "semiisotropic": @@ -921,7 +948,7 @@ def conv_fun(comm, diffmesh): angle_pr_, step, prng, - comm=comm, + comm=intracomm, ) pm, field_list, elec_common_list, coulomb_list = pm_objs ( @@ -960,11 +987,11 @@ def conv_fun(comm, diffmesh): # inner step if molecules_flag: if not args.disable_bonds: - bond_energy = comm.allreduce(bond_energy_, MPI.SUM) + bond_energy = intracomm.allreduce(bond_energy_, MPI.SUM) if not args.disable_angle_bonds: - angle_energy = comm.allreduce(angle_energy_, MPI.SUM) + angle_energy = intracomm.allreduce(angle_energy_, MPI.SUM) if not args.disable_dihedrals: - dihedral_energy = comm.allreduce(dihedral_energy_, MPI.SUM) + dihedral_energy = 
intracomm.allreduce(dihedral_energy_, MPI.SUM) # Update slow forces if not args.disable_field: @@ -1027,11 +1054,11 @@ def conv_fun(comm, diffmesh): pm, positions, config, - comm=comm, + comm=intracomm, ) field_q_energy = compute_field_energy_q_GPE( - config, phi_eps, field_q_energy, elec_dot, comm=comm + config, phi_eps, field_q_energy, elec_dot, comm=intracomm ) if config.coulombtype == "PIC_Spectral": @@ -1113,13 +1140,13 @@ def conv_fun(comm, diffmesh): field_energy, kinetic_energy, ) = compute_field_and_kinetic_energy( phi, velocities, hamiltonian, positions, types, v_ext, - config, layouts, comm=comm, + config, layouts, comm=intracomm, ) if charges_flag and config.coulombtype == "PIC_Spectral": field_q_energy = compute_field_energy_q( config, phi_q_fourier, elec_energy_field, - field_q_energy, comm=comm, + field_q_energy, comm=intracomm, ) else: field_energy = 0.0 @@ -1127,7 +1154,7 @@ def conv_fun(comm, diffmesh): # PLUMED likes the energy per rank poteng = (field_energy + bond_energy + angle_energy - + dihedral_energy + field_q_energy) / size + + dihedral_energy + field_q_energy) / size else: poteng = 0.0 @@ -1191,7 +1218,7 @@ def conv_fun(comm, diffmesh): bonds=bonds if molecules_flag else None, topol=False if topol is None else True, verbose=args.verbose, - comm=comm, + comm=intracomm, ) if charges_flag: @@ -1307,12 +1334,12 @@ def conv_fun(comm, diffmesh): # Thermostat if config.target_temperature and np.mod(step, config.n_b) == 0: - csvr_thermostat(velocities, names, config, prng, comm=comm) + csvr_thermostat(velocities, names, config, prng, comm=intracomm) # Remove total linear momentum if config.cancel_com_momentum: if np.mod(step, config.cancel_com_momentum) == 0: - velocities = cancel_com_momentum(velocities, config, comm=comm) + velocities = cancel_com_momentum(velocities, config, comm=intracomm) # Print trajectory if config.n_print > 0: @@ -1355,16 +1382,16 @@ def conv_fun(comm, diffmesh): v_ext, config, layouts, - comm=comm, + comm=intracomm, ) 
if charges_flag: if config.coulombtype == "PIC_Spectral_GPE": field_q_energy = compute_field_energy_q_GPE( - config, phi_eps, field_q_energy, elec_dot, comm=comm + config, phi_eps, field_q_energy, elec_dot, comm=intracomm ) else: - kinetic_energy = comm.allreduce( + kinetic_energy = intracomm.allreduce( 0.5 * config.mass * np.sum(velocities**2) ) temperature = ( @@ -1389,7 +1416,7 @@ def conv_fun(comm, diffmesh): positions, bond_pr_, angle_pr_, - comm=comm, + comm=intracomm, ) else: pressure = 0.0 # 0.0 indicates not calculated. To be changed. @@ -1430,7 +1457,7 @@ def conv_fun(comm, diffmesh): charge_out=charges_flag, plumed_out=True if args.plumed else False, dump_per_particle=args.dump_per_particle, - comm=comm, + comm=intracomm, ) if np.mod(step, config.n_print * config.n_flush) == 0: out_dataset.flush() @@ -1485,7 +1512,7 @@ def conv_fun(comm, diffmesh): v_ext, config, layouts, - comm=comm, + comm=intracomm, ) if charges_flag: @@ -1518,11 +1545,11 @@ def conv_fun(comm, diffmesh): pm, positions, config, - comm=comm, + comm=intracomm, ) field_q_energy = compute_field_energy_q_GPE( - config, phi_eps, field_q_energy, elec_dot, comm=comm + config, phi_eps, field_q_energy, elec_dot, comm=intracomm ) if config.coulombtype == "PIC_Spectral": @@ -1543,7 +1570,7 @@ def conv_fun(comm, diffmesh): ) else: - kinetic_energy = comm.allreduce(0.5 * config.mass * np.sum(velocities**2)) + kinetic_energy = intracomm.allreduce(0.5 * config.mass * np.sum(velocities**2)) frame = (step + 1) // config.n_print temperature = ( (2 / 3) * kinetic_energy / (config.gas_constant * config.n_particles) @@ -1566,7 +1593,7 @@ def conv_fun(comm, diffmesh): bond_pr_, angle_pr_, Vbar_elec, - comm=comm, + comm=intracomm, ) else: pressure = 0.0 # 0.0 indicates not calculated. To be changed. 
@@ -1607,9 +1634,9 @@ def conv_fun(comm, diffmesh): charge_out=charges_flag, plumed_out=True if args.plumed else False, dump_per_particle=args.dump_per_particle, - comm=comm, + comm=intracomm, ) - out_dataset.close_file() + out_dataset.close_file(comm=intracomm) if args.plumed: plumed.finalize() diff --git a/hymd/plumed.py b/hymd/plumed.py index d50f91a3..87c0acb7 100644 --- a/hymd/plumed.py +++ b/hymd/plumed.py @@ -36,8 +36,10 @@ class PlumedBias: Used as a pointer to an :code:`double` to store the bias energy. plumed_version : (1,) numpy.ndarray Used as a pointer to an :code:`int` to store PLUMED API version. - comm : mpi4py.Comm - MPI communicator to use for rank communication. + intracomm : mpi4py.Comm + MPI communicator to use for rank communication within a replica. + intercomm : mpi4py.Comm + MPI communicator to use for rank communication between replicas. ready : bool Stores wether the :code:`calc()` method can be called or not. """ @@ -48,11 +50,12 @@ class PlumedBias: charges = None plumed_bias = np.zeros(1, np.double) plumed_version = np.zeros(1, dtype=np.intc) - comm = None + intracomm = None + intercomm = None ready = False verbose = None - def __init__(self, config, plumeddat, logfile, comm=MPI.COMM_WORLD, verbose=1): + def __init__(self, config, plumeddat, logfile, intracomm=MPI.COMM_WORLD, intercomm=None, verbose=1): """Constructor Parameters @@ -63,8 +66,10 @@ def __init__(self, config, plumeddat, logfile, comm=MPI.COMM_WORLD, verbose=1): Path file containing PLUMED input. logfile : str Path to PLUMED's output file - comm : mpi4py.Comm, optional - MPI communicator to use for rank communication. + intracomm : mpi4py.Comm, optional + MPI communicator to use for rank communication within a replica. + intercomm : mpi4py.Comm, optional + MPI communicator to use for rank communication between replicas. verbose : int, optional Specify the logging event verbosity of this object. 
@@ -76,12 +81,13 @@ def __init__(self, config, plumeddat, logfile, comm=MPI.COMM_WORLD, verbose=1): import plumed except ImportError: err_str = ( - "You are trying to use PLUMED " "but HyMD could not import py-plumed." + "You are trying to use PLUMED but HyMD could not import py-plumed." ) Logger.rank0.log(logging.ERROR, err_str) raise ImportError(err_str) - self.comm = comm + self.intracomm = intracomm + self.intercomm = intercomm self.verbose = verbose try: @@ -103,12 +109,19 @@ def __init__(self, config, plumeddat, logfile, comm=MPI.COMM_WORLD, verbose=1): self.plumed_obj.cmd("getApiVersion", self.plumed_version) if self.plumed_version[0] <= 3: - err_str = "HyMD requires a PLUMED API > 3. " "Use a newer PLUMED kernel." + err_str = "HyMD requires a PLUMED API > 3. Use a newer PLUMED kernel." Logger.rank0.log(logging.ERROR, err_str) raise AssertionError(err_str) self.plumed_obj.cmd("setMDEngine", "HyMD") - self.plumed_obj.cmd("setMPIComm", comm) + + if intercomm is not None: + if intracomm.Get_rank() == 0: + self.plumed_obj.cmd("GREX setMPIIntercomm", intercomm) + self.plumed_obj.cmd("GREX setMPIIntracomm", intracomm) + self.plumed_obj.cmd("GREX init") + + self.plumed_obj.cmd("setMPIComm", intracomm) Logger.rank0.log( logging.INFO, f"Attempting to read PLUMED input from {plumeddat}" diff --git a/test/test_plumed.py b/test/test_plumed.py index 1a2c8d74..be7cfd35 100644 --- a/test/test_plumed.py +++ b/test/test_plumed.py @@ -52,7 +52,7 @@ def test_plumed_bias_obj(molecules_with_solvent, change_test_dir, tmp_path, config, os.path.join(tmp_path,"plumed.dat"), os.path.join(tmp_path,"plumed.out"), - comm=comm, + intracomm=comm, verbose=2 ) @@ -143,13 +143,13 @@ def test_fail_plumed_bias_obj(monkeypatch): break import hymd.plumed - + with pytest.raises(RuntimeError) as recorded_error: with hymd.plumed.PlumedBias( config, "test.in", "test.out", - comm=comm, + intracomm=comm, verbose=2 ) as _: if rank == 0: @@ -185,7 +185,7 @@ def 
test_unavailable_plumed(hide_available_plumed): config, "test.in", "test.out", - comm=comm, + intracomm=comm, verbose=2 ) as _: if rank == 0: