Skip to content

Commit

Permalink
Improve logging for multi replicas
Browse files Browse the repository at this point in the history
and enable destdir when replica-dirs is used
  • Loading branch information
hmcezar committed Sep 29, 2023
1 parent 8601483 commit 0e2145c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
14 changes: 5 additions & 9 deletions hymd/configure_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,6 @@ def configure_runtime(args_in, comm):
# check if we have at least one rank per replica
if comm.Get_size() < len(args.replica_dirs) and comm.Get_rank() == 0:
raise ValueError("You should have at least one MPI rank per replica.")

# block destdir with replicas
if (len(args.replica_dirs) > 0) and args.destdir != "." and comm.Get_rank() == 0:
raise ValueError("You should not specify a destination directory when using replicas.")

if comm.Get_rank() == 0:
os.makedirs(args.destdir, exist_ok=True)

comm.barrier()

# Safely define seeds
seeds = None
Expand Down Expand Up @@ -208,6 +199,11 @@ def configure_runtime(args_in, comm):
intracomm = comm
intercomm = None

if intracomm.Get_rank() == 0:
os.makedirs(args.destdir, exist_ok=True)

intracomm.barrier()

# Setup logger
Logger.setup(
default_level=logging.INFO,
Expand Down
23 changes: 23 additions & 0 deletions hymd/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,28 @@ def filter(self, record):
return True


class MPIFilterStdout(logging.Filter):
"""Log output Filter wrapper class for filtering STDOUT log"""

def filter(self, record):
"""Log event message filter
Parameters
----------
record : logging.LogRecord
LogRecord object corresponding to the log event.
"""
if record.funcName == "<module>":
record.funcName = "main"
record.funcName = "replica 0 " + record.funcName
if MPI.COMM_WORLD.Get_rank() == 0:
record.rank = MPI.COMM_WORLD.Get_rank()
record.size = MPI.COMM_WORLD.Get_size()
return True
else:
return False


class Logger:
"""Log output handler class
Expand Down Expand Up @@ -157,6 +179,7 @@ def setup(cls, default_level=logging.INFO, log_file=None, verbose=False, comm=MP
cls.stdout_handler.setLevel(level)
cls.stdout_handler.setStream(sys.stdout)
cls.stdout_handler.setFormatter(cls.formatter)
cls.stdout_handler.addFilter(MPIFilterStdout())

cls.rank0.addHandler(cls.stdout_handler)
cls.all_ranks.addHandler(cls.stdout_handler)
Expand Down

0 comments on commit 0e2145c

Please sign in to comment.