Skip to content

Commit

Permalink
Fixes logging with multiple replicas
Browse files Browse the repository at this point in the history
  • Loading branch information
hmcezar committed Sep 28, 2023
1 parent a4a8431 commit 0ab127a
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 62 deletions.
38 changes: 35 additions & 3 deletions hymd/configure_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def configure_runtime(args_in, comm):
Namespace containing command line arguments.
config : hymd.input_parser.Config
Parsed configuration object.
prng : np.random.Generator
Random number generator for this rank.
topol : dict
Topology dictionary.
intracomm : mpi4py.Comm
MPI communicator to use for rank communication within a replica.
intercomm : mpi4py.Comm
MPI communicator to use for rank communication between replicas.
"""
ap = argparse.ArgumentParser()

Expand Down Expand Up @@ -151,11 +159,11 @@ def configure_runtime(args_in, comm):
args = ap.parse_args(args_in)

# check if we have at least one rank per replica
if comm.Get_size() < len(args.replica_dirs):
if comm.Get_size() < len(args.replica_dirs) and comm.Get_rank() == 0:
raise ValueError("You should have at least one MPI rank per replica.")

# block destdir with replicas
if (len(args.replica_dirs) > 0) and args.destdir != ".":
if (len(args.replica_dirs) > 0) and args.destdir != "." and comm.Get_rank() == 0:
raise ValueError("You should not specify a destination directory when using replicas.")

if comm.Get_rank() == 0:
Expand All @@ -177,11 +185,35 @@ def configure_runtime(args_in, comm):
# Setup a PRNG for each rank
prng = np.random.default_rng(seeds[comm.Get_rank()])

# multi replica setup
n_replicas = len(args.replica_dirs)
if n_replicas > 1:
# split comm=COMM_WORLD in intracomm and intercomm
size = comm.Get_size()
rank = comm.Get_rank()

n_intra = int(size / n_replicas)
if (n_replicas * n_intra != size):
err_str = "Inconsistent number of ranks per replica"

if rank == 0:
raise AssertionError(err_str)
intracomm = comm.Split(int(rank / n_intra), rank)
intercomm = comm.Split(rank % n_intra, rank)

# assign directory to each rank
os.chdir(args.replica_dirs[int(rank / n_intra)])

else:
intracomm = comm
intercomm = None

# Setup logger
Logger.setup(
default_level=logging.INFO,
log_file=f"{args.destdir}/{args.logfile}",
verbose=args.verbose,
comm=intracomm,
)

# print header info
Expand Down Expand Up @@ -241,7 +273,7 @@ def profile_atexit():
f"Unable to parse configuration file {args.config}"
f"\n\ntoml parse traceback:" + repr(ve)
)
return args, config, prng, topol
return args, config, prng, topol, intracomm, intercomm


def extant_file(x):
Expand Down
28 changes: 20 additions & 8 deletions hymd/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@

class MPIFilterRoot(logging.Filter):
"""Log output Filter wrapper class for the root MPI rank log"""

comm = None

def __init__(self, comm=MPI.COMM_WORLD):
self.comm = comm

def filter(self, record):
"""Log event message filter
Expand All @@ -20,9 +25,9 @@ def filter(self, record):
"""
if record.funcName == "<module>":
record.funcName = "main"
if MPI.COMM_WORLD.Get_rank() == 0:
record.rank = MPI.COMM_WORLD.Get_rank()
record.size = MPI.COMM_WORLD.Get_size()
if self.comm.Get_rank() == 0:
record.rank = self.comm.Get_rank()
record.size = self.comm.Get_size()
return True
else:
return False
Expand All @@ -31,6 +36,11 @@ def filter(self, record):
class MPIFilterAll(logging.Filter):
"""Log output Filter wrapper class for the all-MPI-ranks log"""

comm = None

def __init__(self, comm=MPI.COMM_WORLD):
self.comm = comm

def filter(self, record):
"""Log event message filter
Expand All @@ -41,8 +51,8 @@ def filter(self, record):
"""
if record.funcName == "<module>":
record.funcName = "main"
record.rank = MPI.COMM_WORLD.Get_rank()
record.size = MPI.COMM_WORLD.Get_size()
record.rank = self.comm.Get_rank()
record.size = self.comm.Get_size()
return True


Expand Down Expand Up @@ -96,7 +106,7 @@ class Logger:
all_ranks = logging.getLogger("HyMD.all_ranks")

@classmethod
def setup(cls, default_level=logging.INFO, log_file=None, verbose=False):
def setup(cls, default_level=logging.INFO, log_file=None, verbose=False, comm=MPI.COMM_WORLD):
"""Sets up the logger object.
If a :code:`log_file` path is provided, log event messages are output
Expand Down Expand Up @@ -127,8 +137,10 @@ def setup(cls, default_level=logging.INFO, log_file=None, verbose=False):
cls.rank0.setLevel(level)
cls.all_ranks.setLevel(level)

cls.rank0.addFilter(MPIFilterRoot())
cls.all_ranks.addFilter(MPIFilterAll())
root_filter = MPIFilterRoot(comm)
all_filter = MPIFilterAll(comm)
cls.rank0.addFilter(root_filter)
cls.all_ranks.addFilter(all_filter)

if (not log_file) and (not log_to_stdout):
return
Expand Down
34 changes: 6 additions & 28 deletions hymd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ def main():
the molecular dynamics loop.
"""
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

if rank == 0:
if comm.Get_rank() == 0:
start_time = datetime.datetime.now()

args, config, prng, topol = configure_runtime(sys.argv[1:], comm)
args, config, prng, topol, intracomm, intercomm = configure_runtime(sys.argv[1:], comm)

# Get rank and size
rank = intracomm.Get_rank()
size = intracomm.Get_size()

if args.double_precision:
dtype = np.float64
Expand All @@ -65,30 +67,6 @@ def main():
from .force import compute_angle_forces__fortran as compute_angle_forces
from .force import compute_dihedral_forces__fortran as compute_dihedral_forces

# multi replica setup
n_replicas = len(args.replica_dirs)
if n_replicas > 1:
# split COMM_WORLD in intracomm and intercomm
n_intra = int(size / n_replicas)
if (n_replicas * n_intra != size):
err_str = "Inconsistent number of ranks per replica"
Logger.rank0.log(logging.ERROR, err_str)

if rank == 0:
raise AssertionError(err_str)
intracomm = comm.Split(int(rank / n_intra), rank)
intercomm = comm.Split(rank % n_intra, rank)

# assign directory to each rank
os.chdir(args.replica_dirs[int(rank / n_intra)])

# update rank and size
rank = intracomm.Get_rank()
size = intracomm.Get_size()
else:
intracomm = comm
intercomm = None

# read input .hdf5
driver = "mpio" if not args.disable_mpio else None
_kwargs = {"driver": driver, "comm": intracomm} if not args.disable_mpio else {}
Expand Down
49 changes: 28 additions & 21 deletions test/test_configure_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ def test_extant_file(tmp_path, caplog):
f.write("test")

ret_path = extant_file(file_path)
assert ret_path == file_path
assert ret_path == os.path.abspath(file_path)

caplog.clear()


def test_configure_runtime(h5py_molecules_file, config_toml,
change_tmp_dir, caplog):
caplog.set_level(logging.ERROR)
Expand All @@ -36,63 +35,71 @@ def test_configure_runtime(h5py_molecules_file, config_toml,

basearg = [config_toml_file, h5py_molecules_file]

parsed, config, prng, topol = configure_runtime(basearg, comm)
parsed, config, prng, topol, intracomm, intercomm = configure_runtime(basearg, comm)
assert isinstance(parsed, Namespace)
assert isinstance(config, Config)
assert isinstance(prng, np.random.Generator)
assert isinstance(intracomm, MPI.Comm)
assert isinstance(intercomm, (MPI.Comm, type(None)))
assert parsed.destdir == "."
assert len(parsed.replica_dirs) == 0
assert parsed.logfile == "sim.log"
assert parsed.plumed_outfile == "plumed.out"

parsed, _, _, _ = configure_runtime(["-v", "2"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["-v", "2"]+basearg, comm)
assert parsed.verbose == 2

parsed, _, _, _ = configure_runtime(["--disable-field"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--disable-field"]+basearg, comm)
assert parsed.disable_field

parsed, _, _, _ = configure_runtime(["--disable-bonds"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--disable-bonds"]+basearg, comm)
assert parsed.disable_bonds

parsed, _, _, _ = configure_runtime(["--disable-angle-bonds"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--disable-angle-bonds"]+basearg, comm)
assert parsed.disable_angle_bonds

parsed, _, _, _ = configure_runtime(["--disable-dihedrals"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--disable-dihedrals"]+basearg, comm)
assert parsed.disable_dihedrals

parsed, _, _, _ = configure_runtime(["--disable-dipole"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--disable-dipole"]+basearg, comm)
assert parsed.disable_dipole

parsed, _, _, _ = configure_runtime(["--double-precision"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--double-precision"]+basearg, comm)
assert parsed.double_precision

parsed, _, _, _ = configure_runtime(["--double-output"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--double-output"]+basearg, comm)
assert parsed.double_output

parsed, _, _, _ = configure_runtime(["--dump-per-particle"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--dump-per-particle"]+basearg, comm)
assert parsed.dump_per_particle

parsed, _, _, _ = configure_runtime(["--force-output"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--force-output"]+basearg, comm)
assert parsed.force_output

parsed, _, _, _ = configure_runtime(["--velocity-output"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--velocity-output"]+basearg, comm)
assert parsed.velocity_output

parsed, _, _, _ = configure_runtime(["--disable-mpio"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--disable-mpio"]+basearg, comm)
assert parsed.disable_mpio

parsed, _, _, _ = configure_runtime(["--destdir", "testdir"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--destdir", "testdir"]+basearg, comm)
assert parsed.destdir == "testdir"

parsed, _, _, _ = configure_runtime(["--seed", "54321"]+basearg, comm)
# parsed, _, _, _, _, _ = configure_runtime(basearg+["--replica-dirs", "a", "b"], comm)
# assert len(parsed.replica_dirs) == 2
# assert parsed.replica_dirs[0] == "a"
# assert parsed.replica_dirs[1] == "b"

parsed, _, _, _, _, _ = configure_runtime(["--seed", "54321"]+basearg, comm)
assert parsed.seed == 54321

parsed, _, _, _ = configure_runtime(["--logfile", "test.log"]+basearg, comm)
parsed, _, _, _, _, _ = configure_runtime(["--logfile", "test.log"]+basearg, comm)
assert parsed.logfile == "test.log"

parsed, _, _, _ = configure_runtime(["--plumed", "test.log"]+basearg, comm)
assert parsed.plumed == "test.log"
parsed, _, _, _, _, _ = configure_runtime(["--plumed", "test.log"]+basearg, comm)
assert parsed.plumed == os.path.abspath("test.log")

parsed, _, _, _ = configure_runtime(
parsed, _, _, _, _, _ = configure_runtime(
["--plumed-outfile", "test.plumed.out"]+basearg,
comm
)
Expand Down
4 changes: 2 additions & 2 deletions test/test_plumed.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_plumed_bias_obj(molecules_with_solvent, change_test_dir, tmp_path,

@pytest.mark.mpi()
@pytest.mark.skip(reason="Currently fails in CI due to environment variable")
def test_fail_plumed_bias_obj(monkeypatch):
def test_fail_plumed_bias_obj(monkeypatch, caplog):
pytest.importorskip("plumed")

comm = MPI.COMM_WORLD
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_fail_plumed_bias_obj(monkeypatch):
assert all([(s in message) for s in cmp_strings])


def test_unavailable_plumed(hide_available_plumed):
def test_unavailable_plumed(hide_available_plumed, caplog):
import hymd.plumed

comm = MPI.COMM_WORLD
Expand Down

0 comments on commit 0ab127a

Please sign in to comment.