diff --git a/hymd/configure_runtime.py b/hymd/configure_runtime.py index e1a5a67f..37c0d46b 100644 --- a/hymd/configure_runtime.py +++ b/hymd/configure_runtime.py @@ -29,6 +29,14 @@ def configure_runtime(args_in, comm): Namespace containing command line arguments. config : hymd.input_parser.Config Parsed configuration object. + prng : np.random.Generator + Random number generator for this rank. + topol : dict + Topology dictionary. + intracomm : mpi4py.Comm + MPI communicator to use for rank communication within a replica. + intercomm : mpi4py.Comm + MPI communicator to use for rank communication between replicas. """ ap = argparse.ArgumentParser() @@ -115,6 +123,9 @@ def configure_runtime(args_in, comm): ap.add_argument( "--destdir", default=".", help="Write output to specified directory" ) + ap.add_argument( + "--replica-dirs", type=str, nargs="+", default=[], help="Directories to store results for each replica" + ) ap.add_argument( "--seed", default=None, @@ -143,13 +154,13 @@ def configure_runtime(args_in, comm): type=extant_file, help="Gmx-like topology file in toml format" ) - ap.add_argument("config", help="Config .py or .toml input configuration script") + ap.add_argument("config", type=extant_file, help="Config .py or .toml input configuration script") ap.add_argument("input", help="input.hdf5") args = ap.parse_args(args_in) - if comm.Get_rank() == 0: - os.makedirs(args.destdir, exist_ok=True) - comm.barrier() + # check if we have at least one rank per replica + if comm.Get_size() < len(args.replica_dirs) and comm.Get_rank() == 0: + raise ValueError("You should have at least one MPI rank per replica.") # Safely define seeds seeds = None @@ -165,11 +176,40 @@ def configure_runtime(args_in, comm): # Setup a PRNG for each rank prng = np.random.default_rng(seeds[comm.Get_rank()]) + # multi replica setup + n_replicas = len(args.replica_dirs) + if n_replicas > 1: + # split comm=COMM_WORLD in intracomm and intercomm + size = comm.Get_size() + rank = comm.Get_rank() + + n_intra = int(size / n_replicas) + if (n_replicas * n_intra != size): + err_str = "Inconsistent number of ranks per replica" + + if rank == 0: + raise AssertionError(err_str) + intracomm = comm.Split(int(rank / n_intra), rank) + intercomm = comm.Split(rank % n_intra, rank) + + # assign directory to each rank + os.chdir(args.replica_dirs[int(rank / n_intra)]) + + else: + intracomm = comm + intercomm = None + + if intracomm.Get_rank() == 0: + os.makedirs(args.destdir, exist_ok=True) + + intracomm.barrier() + # Setup logger Logger.setup( default_level=logging.INFO, log_file=f"{args.destdir}/{args.logfile}", verbose=args.verbose, + comm=intracomm, ) # print header info @@ -229,7 +269,7 @@ def profile_atexit(): f"Unable to parse configuration file {args.config}" f"\n\ntoml parse traceback:" + repr(ve) ) - return args, config, prng, topol + return args, config, prng, topol, intracomm, intercomm def extant_file(x): @@ -241,5 +281,5 @@ def extant_file(x): # Argparse uses the ArgumentTypeError to give a rejection message like: # error: argument input: x does not exist raise argparse.ArgumentTypeError("{0} does not exist".format(x)) - return x + return os.path.abspath(x) diff --git a/hymd/file_io.py b/hymd/file_io.py index 0a918617..f8c81d04 100644 --- a/hymd/file_io.py +++ b/hymd/file_io.py @@ -197,20 +197,8 @@ def store_static( creator_group = h5md_group.create_group("creator") creator_group.attrs["name"] = np.string_("Hylleraas MD") - # Get HyMD version. Also grab the user email from git config if we - # can find it. + # Get HyMD version creator_group.attrs["version"] = np.string_(get_version()) - try: - import git - - try: - reader = repo.config_reader() - user_email = reader.get_value("user", "email") - author_group.attrs["email"] = np.string_(user_email) - except: - pass - except: - pass h5md.particles_group = h5md.file.create_group("/particles") h5md.all_particles = h5md.particles_group.create_group("all") diff --git a/hymd/logger.py b/hymd/logger.py index 6f2c1d78..41deaade 100644 --- a/hymd/logger.py +++ b/hymd/logger.py @@ -9,6 +9,11 @@ class MPIFilterRoot(logging.Filter): """Log output Filter wrapper class for the root MPI rank log""" + + comm = None + + def __init__(self, comm=MPI.COMM_WORLD): + self.comm = comm def filter(self, record): """Log event message filter @@ -20,9 +25,9 @@ def filter(self, record): """ if record.funcName == "": record.funcName = "main" - if MPI.COMM_WORLD.Get_rank() == 0: - record.rank = MPI.COMM_WORLD.Get_rank() - record.size = MPI.COMM_WORLD.Get_size() + if self.comm.Get_rank() == 0: + record.rank = self.comm.Get_rank() + record.size = self.comm.Get_size() return True else: return False @@ -31,6 +36,11 @@ def filter(self, record): class MPIFilterAll(logging.Filter): """Log output Filter wrapper class for the all-MPI-ranks log""" + comm = None + + def __init__(self, comm=MPI.COMM_WORLD): + self.comm = comm + def filter(self, record): """Log event message filter @@ -41,11 +51,33 @@ def filter(self, record): """ if record.funcName == "": record.funcName = "main" - record.rank = MPI.COMM_WORLD.Get_rank() - record.size = MPI.COMM_WORLD.Get_size() + record.rank = self.comm.Get_rank() + record.size = self.comm.Get_size() return True +class MPIFilterStdout(logging.Filter): + """Log output Filter wrapper class for filtering STDOUT log""" + + def filter(self, record): + """Log event message filter + + Parameters + ---------- + record : logging.LogRecord + LogRecord object corresponding to the log event. + """ + if record.funcName == "": + record.funcName = "main" + record.funcName = "replica 0 " + record.funcName + if MPI.COMM_WORLD.Get_rank() == 0: + record.rank = MPI.COMM_WORLD.Get_rank() + record.size = MPI.COMM_WORLD.Get_size() + return True + else: + return False + + class Logger: """Log output handler class @@ -96,7 +128,7 @@ class Logger: all_ranks = logging.getLogger("HyMD.all_ranks") @classmethod - def setup(cls, default_level=logging.INFO, log_file=None, verbose=False): + def setup(cls, default_level=logging.INFO, log_file=None, verbose=False, comm=MPI.COMM_WORLD): """Sets up the logger object. If a :code:`log_file` path is provided, log event messages are output @@ -127,8 +159,10 @@ def setup(cls, default_level=logging.INFO, log_file=None, verbose=False): cls.rank0.setLevel(level) cls.all_ranks.setLevel(level) - cls.rank0.addFilter(MPIFilterRoot()) - cls.all_ranks.addFilter(MPIFilterAll()) + root_filter = MPIFilterRoot(comm) + all_filter = MPIFilterAll(comm) + cls.rank0.addFilter(root_filter) + cls.all_ranks.addFilter(all_filter) if (not log_file) and (not log_to_stdout): return @@ -145,6 +179,7 @@ def setup(cls, default_level=logging.INFO, log_file=None, verbose=False): cls.stdout_handler.setLevel(level) cls.stdout_handler.setStream(sys.stdout) cls.stdout_handler.setFormatter(cls.formatter) + cls.stdout_handler.addFilter(MPIFilterStdout()) cls.rank0.addHandler(cls.stdout_handler) cls.all_ranks.addHandler(cls.stdout_handler) diff --git a/hymd/main.py b/hymd/main.py index c05952a2..e3676623 100644 --- a/hymd/main.py +++ b/hymd/main.py @@ -4,6 +4,7 @@ import h5py import logging import sys +import os from mpi4py import MPI import numpy as np import pmesh.pm as pmesh @@ -41,14 +42,17 @@ def main(): the molecular dynamics loop. """ comm = MPI.COMM_WORLD - rank = comm.Get_rank() - size = comm.Get_size() + args, config, prng, topol, intracomm, intercomm = configure_runtime(sys.argv[1:], comm) + + # Get rank and size + rank = intracomm.Get_rank() + size = intracomm.Get_size() + + # Start timer after configure_runtime if rank == 0: start_time = datetime.datetime.now() - args, config, prng, topol = configure_runtime(sys.argv[1:], comm) - if args.double_precision: dtype = np.float64 config.dtype = dtype @@ -66,7 +70,7 @@ def main(): # read input .hdf5 driver = "mpio" if not args.disable_mpio else None - _kwargs = {"driver": driver, "comm": comm} if not args.disable_mpio else {} + _kwargs = {"driver": driver, "comm": intracomm} if not args.disable_mpio else {} with h5py.File(args.input, "r", **_kwargs) as in_file: rank_range, molecules_flag = distribute_input( in_file, @@ -74,7 +78,7 @@ def main(): size, config.n_particles, config.max_molecule_size if config.max_molecule_size else 201, - comm=comm, + comm=intracomm, ) indices = in_file["indices"][rank_range] positions = in_file["coordinates"][-1, rank_range, :] @@ -121,7 +125,7 @@ def main(): input_box = np.array([None, None, None]) # finishes config setup and checks - config = check_config(config, indices, names, types, charges, input_box, comm=comm) + config = check_config(config, indices, names, types, charges, input_box, comm=intracomm) # import barostat if necessary if config.barostat_type == "berendsen": @@ -164,13 +168,13 @@ def conv_fun(comm, diffmesh): dielectric_flag = False if config.start_temperature: - velocities = generate_initial_velocities(velocities, config, prng, comm=comm) + velocities = generate_initial_velocities(velocities, config, prng, comm=intracomm) elif config.cancel_com_momentum: - velocities = cancel_com_momentum(velocities, config, comm=comm) + velocities = cancel_com_momentum(velocities, config, comm=intracomm) # set all PRNG to the root's PRNG # this is done to avoid communication between ranks in the thermostat - prng = comm.bcast(prng, root=0) + prng = intracomm.bcast(prng, root=0) bond_forces = np.zeros_like(positions) angle_forces = np.zeros_like(positions) @@ -208,7 +212,7 @@ def conv_fun(comm, diffmesh): hamiltonian = get_hamiltonian(config) - pm_objs = initialize_pm(pmesh, config, comm) + pm_objs = initialize_pm(pmesh, config, intracomm) pm, field_list, elec_common_list, coulomb_list = pm_objs ( phi, @@ -288,7 +292,7 @@ def conv_fun(comm, diffmesh): bonds=bonds if molecules_flag else None, topol=False if topol is None else True, verbose=args.verbose, - comm=comm, + comm=intracomm, ) if charges_flag: @@ -360,13 +364,13 @@ def conv_fun(comm, diffmesh): v_ext, config, layouts, - comm=comm, + comm=intracomm, ) compute_field_force( layouts, positions, force_on_grid, field_forces, types, config.n_types ) else: - kinetic_energy = comm.allreduce(0.5 * config.mass * np.sum(velocities**2)) + kinetic_energy = intracomm.allreduce(0.5 * config.mass * np.sum(velocities**2)) if charges_flag: layout_q = pm.decompose(positions) @@ -398,11 +402,11 @@ def conv_fun(comm, diffmesh): pm, positions, config, - comm=comm, + comm=intracomm, ) field_q_energy = compute_field_energy_q_GPE( - config, phi_eps, field_q_energy, elec_dot, comm=comm + config, phi_eps, field_q_energy, elec_dot, comm=intracomm ) if config.coulombtype == "PIC_Spectral": @@ -460,7 +464,7 @@ def conv_fun(comm, diffmesh): raise NotImplementedError(err_str) # Check if we have a protein - protein_flag = comm.allreduce(bonds_4_type.any() == 1) + protein_flag = intracomm.allreduce(bonds_4_type.any() == 1) if protein_flag and not args.disable_dipole: # Each rank will have different n_tors, we do not need to # domain decompose dipoles @@ -511,7 +515,7 @@ def conv_fun(comm, diffmesh): bonds_2_equilibrium, bonds_2_stength, ) - bond_energy = comm.allreduce(bond_energy_, MPI.SUM) + bond_energy = intracomm.allreduce(bond_energy_, MPI.SUM) else: bonds_2_atom1, bonds_2_atom2 = [], [] if not args.disable_angle_bonds: @@ -525,7 +529,7 @@ def conv_fun(comm, diffmesh): bonds_3_equilibrium, bonds_3_stength, ) - angle_energy = comm.allreduce(angle_energy_, MPI.SUM) + angle_energy = intracomm.allreduce(angle_energy_, MPI.SUM) # angle_pr = comm.allreduce(angle_pr_, MPI.SUM) if not args.disable_dihedrals: @@ -544,7 +548,7 @@ def conv_fun(comm, diffmesh): bonds_4_last, dipole_flag, ) - dihedral_energy = comm.allreduce(dihedral_energy_, MPI.SUM) + dihedral_energy = intracomm.allreduce(dihedral_energy_, MPI.SUM) if protein_flag and not args.disable_dipole: dipole_positions = np.reshape(dipole_positions, (4 * n_tors, 3)) @@ -600,6 +604,7 @@ def conv_fun(comm, diffmesh): config, double_out=args.double_output, disable_mpio=args.disable_mpio, + comm=intracomm, ) store_static( @@ -617,7 +622,7 @@ def conv_fun(comm, diffmesh): charges=charges if charges_flag else False, dielectrics=dielectric_sorted if dielectric_flag else False, plumed_out=True if args.plumed else False, - comm=comm, + comm=intracomm, ) if config.n_print > 0: @@ -639,10 +644,10 @@ def conv_fun(comm, diffmesh): v_ext, config, layouts, - comm=comm, + comm=intracomm, ) else: - kinetic_energy = comm.allreduce(0.5 * config.mass * np.sum(velocities**2)) + kinetic_energy = intracomm.allreduce(0.5 * config.mass * np.sum(velocities**2)) temperature = ( (2 / 3) @@ -664,7 +669,7 @@ def conv_fun(comm, diffmesh): positions, bond_pr_, angle_pr_, - comm=comm, + comm=intracomm, ) # if rank ==0 : print(pressure[9:12]) @@ -688,7 +693,8 @@ def conv_fun(comm, diffmesh): config, args.plumed, args.plumed_outfile, - comm=comm, + intracomm=intracomm, + intercomm=intercomm, verbose=args.verbose ) @@ -747,7 +753,7 @@ def conv_fun(comm, diffmesh): charge_out=charges_flag, plumed_out=True if args.plumed else False, dump_per_particle=args.dump_per_particle, - comm=comm, + comm=intracomm, ) if rank == 0: @@ -900,7 +906,7 @@ def conv_fun(comm, diffmesh): angle_pr_, step, prng, - comm=comm, + comm=intracomm, ) elif config.barostat.lower() == "semiisotropic": @@ -921,7 +927,7 @@ def conv_fun(comm, diffmesh): angle_pr_, step, prng, - comm=comm, + comm=intracomm, ) pm, field_list, elec_common_list, coulomb_list = pm_objs ( @@ -960,11 +966,11 @@ def conv_fun(comm, diffmesh): # inner step if molecules_flag: if not args.disable_bonds: - bond_energy = comm.allreduce(bond_energy_, MPI.SUM) + bond_energy = intracomm.allreduce(bond_energy_, MPI.SUM) if not args.disable_angle_bonds: - angle_energy = comm.allreduce(angle_energy_, MPI.SUM) + angle_energy = intracomm.allreduce(angle_energy_, MPI.SUM) if not args.disable_dihedrals: - dihedral_energy = comm.allreduce(dihedral_energy_, MPI.SUM) + dihedral_energy = intracomm.allreduce(dihedral_energy_, MPI.SUM) # Update slow forces if not args.disable_field: @@ -1027,11 +1033,11 @@ def conv_fun(comm, diffmesh): pm, positions, config, - comm=comm, + comm=intracomm, ) field_q_energy = compute_field_energy_q_GPE( - config, phi_eps, field_q_energy, elec_dot, comm=comm + config, phi_eps, field_q_energy, elec_dot, comm=intracomm ) if config.coulombtype == "PIC_Spectral": @@ -1113,13 +1119,13 @@ def conv_fun(comm, diffmesh): field_energy, kinetic_energy, ) = compute_field_and_kinetic_energy( phi, velocities, hamiltonian, positions, types, v_ext, - config, layouts, comm=comm, + config, layouts, comm=intracomm, ) if charges_flag and config.coulombtype == "PIC_Spectral": field_q_energy = compute_field_energy_q( config, phi_q_fourier, elec_energy_field, - field_q_energy, comm=comm, + field_q_energy, comm=intracomm, ) else: field_energy = 0.0 @@ -1127,7 +1133,7 @@ def conv_fun(comm, diffmesh): # PLUMED likes the energy per rank poteng = (field_energy + bond_energy + angle_energy - + dihedral_energy + field_q_energy) / size + + dihedral_energy + field_q_energy) / size else: poteng = 0.0 @@ -1191,7 +1197,7 @@ def conv_fun(comm, diffmesh): bonds=bonds if molecules_flag else None, topol=False if topol is None else True, verbose=args.verbose, - comm=comm, + comm=intracomm, ) if charges_flag: @@ -1307,12 +1313,12 @@ def conv_fun(comm, diffmesh): # Thermostat if config.target_temperature and np.mod(step, config.n_b) == 0: - csvr_thermostat(velocities, names, config, prng, comm=comm) + csvr_thermostat(velocities, names, config, prng, comm=intracomm) # Remove total linear momentum if config.cancel_com_momentum: if np.mod(step, config.cancel_com_momentum) == 0: - velocities = cancel_com_momentum(velocities, config, comm=comm) + velocities = cancel_com_momentum(velocities, config, comm=intracomm) # Print trajectory if config.n_print > 0: @@ -1355,16 +1361,16 @@ def conv_fun(comm, diffmesh): v_ext, config, layouts, - comm=comm, + comm=intracomm, ) if charges_flag: if config.coulombtype == "PIC_Spectral_GPE": field_q_energy = compute_field_energy_q_GPE( - config, phi_eps, field_q_energy, elec_dot, comm=comm + config, phi_eps, field_q_energy, elec_dot, comm=intracomm ) else: - kinetic_energy = comm.allreduce( + kinetic_energy = intracomm.allreduce( 0.5 * config.mass * np.sum(velocities**2) ) temperature = ( @@ -1389,7 +1395,7 @@ def conv_fun(comm, diffmesh): positions, bond_pr_, angle_pr_, - comm=comm, + comm=intracomm, ) else: pressure = 0.0 # 0.0 indicates not calculated. To be changed. @@ -1430,7 +1436,7 @@ def conv_fun(comm, diffmesh): charge_out=charges_flag, plumed_out=True if args.plumed else False, dump_per_particle=args.dump_per_particle, - comm=comm, + comm=intracomm, ) if np.mod(step, config.n_print * config.n_flush) == 0: out_dataset.flush() @@ -1485,7 +1491,7 @@ def conv_fun(comm, diffmesh): v_ext, config, layouts, - comm=comm, + comm=intracomm, ) if charges_flag: @@ -1518,11 +1524,11 @@ def conv_fun(comm, diffmesh): pm, positions, config, - comm=comm, + comm=intracomm, ) field_q_energy = compute_field_energy_q_GPE( - config, phi_eps, field_q_energy, elec_dot, comm=comm + config, phi_eps, field_q_energy, elec_dot, comm=intracomm ) if config.coulombtype == "PIC_Spectral": @@ -1543,7 +1549,7 @@ def conv_fun(comm, diffmesh): ) else: - kinetic_energy = comm.allreduce(0.5 * config.mass * np.sum(velocities**2)) + kinetic_energy = intracomm.allreduce(0.5 * config.mass * np.sum(velocities**2)) frame = (step + 1) // config.n_print temperature = ( (2 / 3) * kinetic_energy / (config.gas_constant * config.n_particles) @@ -1566,7 +1572,7 @@ def conv_fun(comm, diffmesh): bond_pr_, angle_pr_, Vbar_elec, - comm=comm, + comm=intracomm, ) else: pressure = 0.0 # 0.0 indicates not calculated. To be changed. @@ -1607,9 +1613,9 @@ def conv_fun(comm, diffmesh): charge_out=charges_flag, plumed_out=True if args.plumed else False, dump_per_particle=args.dump_per_particle, - comm=comm, + comm=intracomm, ) - out_dataset.close_file() + out_dataset.close_file(comm=intracomm) if args.plumed: plumed.finalize() diff --git a/hymd/plumed.py b/hymd/plumed.py index d50f91a3..87c0acb7 100644 --- a/hymd/plumed.py +++ b/hymd/plumed.py @@ -36,8 +36,10 @@ class PlumedBias: Used as a pointer to an :code:`double` to store the bias energy. plumed_version : (1,) numpy.ndarray Used as a pointer to an :code:`int` to store PLUMED API version. - comm : mpi4py.Comm - MPI communicator to use for rank communication. + intracomm : mpi4py.Comm + MPI communicator to use for rank communication within a replica. + intercomm : mpi4py.Comm + MPI communicator to use for rank communication between replicas. ready : bool Stores wether the :code:`calc()` method can be called or not. """ @@ -48,11 +50,12 @@ class PlumedBias: charges = None plumed_bias = np.zeros(1, np.double) plumed_version = np.zeros(1, dtype=np.intc) - comm = None + intracomm = None + intercomm = None ready = False verbose = None - def __init__(self, config, plumeddat, logfile, comm=MPI.COMM_WORLD, verbose=1): + def __init__(self, config, plumeddat, logfile, intracomm=MPI.COMM_WORLD, intercomm=None, verbose=1): """Constructor Parameters @@ -63,8 +66,10 @@ def __init__(self, config, plumeddat, logfile, comm=MPI.COMM_WORLD, verbose=1): Path file containing PLUMED input. logfile : str Path to PLUMED's output file - comm : mpi4py.Comm, optional - MPI communicator to use for rank communication. + intracomm : mpi4py.Comm, optional + MPI communicator to use for rank communication within a replica. + intercomm : mpi4py.Comm, optional + MPI communicator to use for rank communication between replicas. verbose : int, optional Specify the logging event verbosity of this object. @@ -76,12 +81,13 @@ def __init__(self, config, plumeddat, logfile, comm=MPI.COMM_WORLD, verbose=1): import plumed except ImportError: err_str = ( - "You are trying to use PLUMED " "but HyMD could not import py-plumed." + "You are trying to use PLUMED but HyMD could not import py-plumed." ) Logger.rank0.log(logging.ERROR, err_str) raise ImportError(err_str) - self.comm = comm + self.intracomm = intracomm + self.intercomm = intercomm self.verbose = verbose try: @@ -103,12 +109,19 @@ def __init__(self, config, plumeddat, logfile, comm=MPI.COMM_WORLD, verbose=1): self.plumed_obj.cmd("getApiVersion", self.plumed_version) if self.plumed_version[0] <= 3: - err_str = "HyMD requires a PLUMED API > 3. " "Use a newer PLUMED kernel." + err_str = "HyMD requires a PLUMED API > 3. Use a newer PLUMED kernel." Logger.rank0.log(logging.ERROR, err_str) raise AssertionError(err_str) self.plumed_obj.cmd("setMDEngine", "HyMD") - self.plumed_obj.cmd("setMPIComm", comm) + + if intercomm is not None: + if intracomm.Get_rank() == 0: + self.plumed_obj.cmd("GREX setMPIIntercomm", intercomm) + self.plumed_obj.cmd("GREX setMPIIntracomm", intracomm) + self.plumed_obj.cmd("GREX init") + + self.plumed_obj.cmd("setMPIComm", intracomm) Logger.rank0.log( logging.INFO, f"Attempting to read PLUMED input from {plumeddat}" diff --git a/test/test_configure_runtime.py b/test/test_configure_runtime.py index 65cd3e87..963f2598 100644 --- a/test/test_configure_runtime.py +++ b/test/test_configure_runtime.py @@ -23,11 +23,10 @@ def test_extant_file(tmp_path, caplog): f.write("test") ret_path = extant_file(file_path) - assert ret_path == file_path + assert ret_path == os.path.abspath(file_path) caplog.clear() - def test_configure_runtime(h5py_molecules_file, config_toml, change_tmp_dir, caplog): caplog.set_level(logging.ERROR) @@ -36,63 +35,71 @@ def test_configure_runtime(h5py_molecules_file, config_toml, basearg = [config_toml_file, h5py_molecules_file] - parsed, config, prng, topol = configure_runtime(basearg, comm) + parsed, config, prng, topol, intracomm, intercomm = configure_runtime(basearg, comm) assert isinstance(parsed, Namespace) assert isinstance(config, Config) assert isinstance(prng, np.random.Generator) + assert isinstance(intracomm, MPI.Comm) + assert isinstance(intercomm, (MPI.Comm, type(None))) assert parsed.destdir == "." + assert len(parsed.replica_dirs) == 0 assert parsed.logfile == "sim.log" assert parsed.plumed_outfile == "plumed.out" - parsed, _, _, _ = configure_runtime(["-v", "2"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["-v", "2"]+basearg, comm) assert parsed.verbose == 2 - parsed, _, _, _ = configure_runtime(["--disable-field"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-field"]+basearg, comm) assert parsed.disable_field - parsed, _, _, _ = configure_runtime(["--disable-bonds"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-bonds"]+basearg, comm) assert parsed.disable_bonds - parsed, _, _, _ = configure_runtime(["--disable-angle-bonds"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-angle-bonds"]+basearg, comm) assert parsed.disable_angle_bonds - parsed, _, _, _ = configure_runtime(["--disable-dihedrals"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-dihedrals"]+basearg, comm) assert parsed.disable_dihedrals - parsed, _, _, _ = configure_runtime(["--disable-dipole"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-dipole"]+basearg, comm) assert parsed.disable_dipole - parsed, _, _, _ = configure_runtime(["--double-precision"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--double-precision"]+basearg, comm) assert parsed.double_precision - parsed, _, _, _ = configure_runtime(["--double-output"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--double-output"]+basearg, comm) assert parsed.double_output - parsed, _, _, _ = configure_runtime(["--dump-per-particle"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--dump-per-particle"]+basearg, comm) assert parsed.dump_per_particle - parsed, _, _, _ = configure_runtime(["--force-output"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--force-output"]+basearg, comm) assert parsed.force_output - parsed, _, _, _ = configure_runtime(["--velocity-output"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--velocity-output"]+basearg, comm) assert parsed.velocity_output - parsed, _, _, _ = configure_runtime(["--disable-mpio"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--disable-mpio"]+basearg, comm) assert parsed.disable_mpio - parsed, _, _, _ = configure_runtime(["--destdir", "testdir"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--destdir", "testdir"]+basearg, comm) assert parsed.destdir == "testdir" - parsed, _, _, _ = configure_runtime(["--seed", "54321"]+basearg, comm) + # parsed, _, _, _, _, _ = configure_runtime(basearg+["--replica-dirs", "a", "b"], comm) + # assert len(parsed.replica_dirs) == 2 + # assert parsed.replica_dirs[0] == "a" + # assert parsed.replica_dirs[1] == "b" + + parsed, _, _, _, _, _ = configure_runtime(["--seed", "54321"]+basearg, comm) assert parsed.seed == 54321 - parsed, _, _, _ = configure_runtime(["--logfile", "test.log"]+basearg, comm) + parsed, _, _, _, _, _ = configure_runtime(["--logfile", "test.log"]+basearg, comm) assert parsed.logfile == "test.log" - parsed, _, _, _ = configure_runtime(["--plumed", "test.log"]+basearg, comm) - assert parsed.plumed == "test.log" + parsed, _, _, _, _, _ = configure_runtime(["--plumed", "test.log"]+basearg, comm) + assert parsed.plumed == os.path.abspath("test.log") - parsed, _, _, _ = configure_runtime( + parsed, _, _, _, _, _ = configure_runtime( ["--plumed-outfile", "test.plumed.out"]+basearg, comm ) diff --git a/test/test_plumed.py b/test/test_plumed.py index 1a2c8d74..0de47312 100644 --- a/test/test_plumed.py +++ b/test/test_plumed.py @@ -52,7 +52,7 @@ def test_plumed_bias_obj(molecules_with_solvent, change_test_dir, tmp_path, config, os.path.join(tmp_path,"plumed.dat"), os.path.join(tmp_path,"plumed.out"), - comm=comm, + intracomm=comm, verbose=2 ) @@ -122,7 +122,7 @@ def test_plumed_bias_obj(molecules_with_solvent, change_test_dir, tmp_path, @pytest.mark.mpi() @pytest.mark.skip(reason="Currently fails in CI due to environment variable") -def test_fail_plumed_bias_obj(monkeypatch): +def test_fail_plumed_bias_obj(monkeypatch, caplog): pytest.importorskip("plumed") comm = MPI.COMM_WORLD @@ -143,13 +143,13 @@ def test_fail_plumed_bias_obj(monkeypatch): break import hymd.plumed - + with pytest.raises(RuntimeError) as recorded_error: with hymd.plumed.PlumedBias( config, "test.in", "test.out", - comm=comm, + intracomm=comm, verbose=2 ) as _: if rank == 0: @@ -168,7 +168,7 @@ def test_fail_plumed_bias_obj(monkeypatch): assert all([(s in message) for s in cmp_strings]) -def test_unavailable_plumed(hide_available_plumed): +def test_unavailable_plumed(hide_available_plumed, caplog): import hymd.plumed comm = MPI.COMM_WORLD @@ -185,7 +185,7 @@ def test_unavailable_plumed(hide_available_plumed): config, "test.in", "test.out", - comm=comm, + intracomm=comm, verbose=2 ) as _: if rank == 0: diff --git a/utils/aggregates.py b/utils/aggregates.py index dbefa930..596aeecc 100644 --- a/utils/aggregates.py +++ b/utils/aggregates.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator import os +import sys import argparse from tqdm import tqdm import warnings @@ -114,6 +115,7 @@ def compute_clusters( print_sel.write(f"./colored_pdbs/snap_{frame}.pdb") + plt.close() return clusters @@ -132,6 +134,7 @@ def aggregates_clustering( plot_dendrograms, traj_in_memory, save_solvent, + summary_fig_size=(12, 8), ): u = mda.Universe(grofile, h5mdfile, in_memory=traj_in_memory) @@ -178,38 +181,85 @@ def aggregates_clustering( clusters = dask.compute(job_list, num_workers=nworkers) n_clusters = [] - clust_sizes = [] all_sizes = [] + total_clust_by_size = {} for c in clusters[0]: # get the number of clusters and sizes unique_clusts, clust_counts = np.unique(c, return_counts=True) + for size in clust_counts: + if size not in total_clust_by_size: + total_clust_by_size[size] = 1 + else: + total_clust_by_size[size] += 1 n_clusters.append(len(unique_clusts)) - clust_sizes.append(clust_counts) all_sizes += clust_counts.tolist() # based on cluster sizes get occurence of each size sizes, freq = np.unique(all_sizes, return_counts=True) freq = freq / len(u.trajectory[skip:end:stride]) - # write sizes and freq to file + # overall average number of aggregates + avg_n_aggs = np.average(n_clusters) + + # compute probability of picking cluster of size n + prob_by_size = {} + for k, v in total_clust_by_size.items(): + prob_by_size[k] = v / np.sum(n_clusters) + prob_by_size = dict(sorted(prob_by_size.items())) + + # compute probability of picking a random molecule and + # it belonging to a cluster of size n + prob_mol_size = {} + norm = len(at_sel) * len(u.trajectory[skip:end:stride]) + for k, v in total_clust_by_size.items(): + prob_mol_size[k] = k * v / norm + prob_mol_size = dict(sorted(prob_mol_size.items())) + + # write summary to file with open("summary_clustering.dat", "w") as of: + of.write("Executed command: " + " ".join(sys.argv) + "\n") + + of.write(f"\nAverage number of aggregates: {avg_n_aggs}\n") + + of.write("\nsize\tfrequency\n") for s, f in zip(sizes, freq): of.write(f"{s}\t{f}\n") + + of.write("\nsize\tprobability\n") + for s, p in prob_by_size.items(): + of.write(f"{s}\t{p}\n") + + of.write("\nsize\tprob molecule\n") + for s, p in prob_mol_size.items(): + of.write(f"{s}\t{p}\n") # plot results - fig, (ax1, ax2) = plt.subplots(2, 1) + _, axs = plt.subplots(2, 2, figsize=summary_fig_size) - ax1.plot(frames, n_clusters) - ax1.set_ylabel("Number of aggregates") - ax1.set_xlabel("Frame") - ax1.yaxis.set_major_locator(MaxNLocator(integer=True)) + axs[0, 0].plot(frames, n_clusters) + axs[0, 0].axhline(avg_n_aggs, linestyle="--") + axs[0, 0].set_ylabel("Num. of aggregates") + axs[0, 0].set_xlabel("Frame") + axs[0, 0].yaxis.set_major_locator(MaxNLocator(integer=True)) xticklabels = [f"{sizes[i]}" for i in range(len(sizes))] - ax2.bar(xticklabels, freq, width=0.8) - ax2.set_ylabel("Frequency") - ax2.set_xlabel("Aggregate size") - ax2.tick_params("x", labelrotation=60) + axs[0, 1].bar(xticklabels, freq, width=0.8) + axs[0, 1].set_ylabel("Avg. num. per snapshot") + axs[0, 1].set_xlabel("Aggregate size") + axs[0, 1].tick_params("x", labelrotation=60) + + xticklabels = [f"{k}" for k in prob_by_size.keys()] + axs[1, 0].bar(xticklabels, prob_by_size.values(), width=0.8) + axs[1, 0].set_ylabel("Prob.") + axs[1, 0].set_xlabel("Aggregate size") + axs[1, 0].tick_params("x", labelrotation=60) + + xticklabels = [f"{k}" for k in prob_mol_size.keys()] + axs[1, 1].bar(xticklabels, prob_mol_size.values(), width=0.8) + axs[1, 1].set_ylabel("Prob. molecule in agg.") + axs[1, 1].set_xlabel("Aggregate size") + axs[1, 1].tick_params("x", labelrotation=60) plt.tight_layout() plt.savefig("summary_clustering.pdf", bbox_inches="tight") @@ -294,6 +344,13 @@ def aggregates_clustering( default=False, help="plot the dendrograms (saved in ./dendrograms) (use with stride because its ~10x slower)", ) + parser.add_argument( + "--summary-fig-size", + type=int, + nargs=2, + default=(8, 6), + help="two integers to define the size of the summary figure (default = 8 6)", + ) parser.add_argument( "--traj-in-memory", action="store_true", @@ -347,4 +404,5 @@ def aggregates_clustering( args.plot_dendrograms, args.traj_in_memory, args.save_solvent, + tuple(args.summary_fig_size), ) diff --git a/utils/compute_rdf_aggregates.py b/utils/compute_rdf_aggregates.py index 6836cb22..5c8b0604 100644 --- a/utils/compute_rdf_aggregates.py +++ b/utils/compute_rdf_aggregates.py @@ -45,6 +45,7 @@ def compute_rdfs( compute_i_rg, compute_pddf, save_centered, + save_agg_only, ): rdfs = {} for i in range(len(selections)): @@ -104,30 +105,38 @@ def compute_rdfs( cog = (box_vectors[:3] * tetha_bar) / (2.0 * np.pi) - if save_centered: - dirname = f"./centered_{agg_size}" + # since some methods are not PBC aware, we center the + # aggregate in the box so it does not split in the PBCs + if save_agg_only or save_centered or compute_i_rg: box_center = box_vectors[:3] / 2.0 u.atoms.translate(box_center - cog) u.atoms.wrap(compound="atoms") cog = box_center + if save_agg_only: + agg_sel = u.select_atoms(f"resid {resid}") + + dirname = f"./agg_only_{agg_size}" + + if not os.path.exists(dirname): + os.mkdir(dirname) + + agg_sel.atoms.write( + os.path.join(dirname, f"centered_{os.path.basename(snapshot)}") + ) + + if save_centered: + dirname = f"./centered_{agg_size}" + if not os.path.exists(dirname): os.mkdir(dirname) u.atoms.write( - f"./centered_{agg_size}/centered_{os.path.basename(snapshot)}" + os.path.join(dirname, f"centered_{os.path.basename(snapshot)}") ) # compute the principal moment of inertia and Rg if compute_i_rg: - # since the methods are not PBC aware, we center the - # aggregate in the box so it does not split in the PBCs - if not save_centered: - box_center = box_vectors[:3] / 2.0 - u.atoms.translate(box_center - cog) - u.atoms.wrap(compound="atoms") - cog = box_center - # set the masses agg_sel = u.select_atoms(f"resid {resid}") @@ -163,7 +172,7 @@ def compute_rdfs( n = len(agg_sel) cond_distmat = np.zeros((int((n * n - n) / 2),), dtype=np.float64) - all_distances = distances.self_distance_array( + distances.self_distance_array( agg_sel.positions, box=box_vectors, result=cond_distmat, @@ -341,6 +350,12 @@ def compute_rdfs( default=False, help='for each snapshot containing an aggregate of selected size, save the snapshot with the aggregate centered in the "centered" directory', ) + parser.add_argument( + "--save-aggregate-only", + action="store_true", + default=False, + help="for each snapshot, save the configuration of the aggregate only" + ) parser.add_argument( "--do-not-compute-rdfs", action="store_true", @@ -385,4 +400,5 @@ def compute_rdfs( args.principal_moments_rg, args.compute_pddf, args.save_centered_aggregate, + args.save_aggregate_only, )