diff --git a/hymd/configure_runtime.py b/hymd/configure_runtime.py index c374c87e..37c0d46b 100644 --- a/hymd/configure_runtime.py +++ b/hymd/configure_runtime.py @@ -161,15 +161,6 @@ def configure_runtime(args_in, comm): # check if we have at least one rank per replica if comm.Get_size() < len(args.replica_dirs) and comm.Get_rank() == 0: raise ValueError("You should have at least one MPI rank per replica.") - - # block destdir with replicas - if (len(args.replica_dirs) > 0) and args.destdir != "." and comm.Get_rank() == 0: - raise ValueError("You should not specify a destination directory when using replicas.") - - if comm.Get_rank() == 0: - os.makedirs(args.destdir, exist_ok=True) - - comm.barrier() # Safely define seeds seeds = None @@ -208,6 +199,11 @@ def configure_runtime(args_in, comm): intracomm = comm intercomm = None + if intracomm.Get_rank() == 0: + os.makedirs(args.destdir, exist_ok=True) + + intracomm.barrier() + # Setup logger Logger.setup( default_level=logging.INFO, diff --git a/hymd/logger.py b/hymd/logger.py index 822b757e..41deaade 100644 --- a/hymd/logger.py +++ b/hymd/logger.py @@ -56,6 +56,28 @@ def filter(self, record): return True +class MPIFilterStdout(logging.Filter): + """Log output Filter wrapper class for filtering STDOUT log""" + + def filter(self, record): + """Log event message filter + + Parameters + ---------- + record : logging.LogRecord + LogRecord object corresponding to the log event. + """ + if record.funcName == "": + record.funcName = "main" + record.funcName = "replica 0 " + record.funcName + if MPI.COMM_WORLD.Get_rank() == 0: + record.rank = MPI.COMM_WORLD.Get_rank() + record.size = MPI.COMM_WORLD.Get_size() + return True + else: + return False + + class Logger: """Log output handler class @@ -157,6 +179,7 @@ def setup(cls, default_level=logging.INFO, log_file=None, verbose=False, comm=MP cls.stdout_handler.setLevel(level) cls.stdout_handler.setStream(sys.stdout) cls.stdout_handler.setFormatter(cls.formatter) + cls.stdout_handler.addFilter(MPIFilterStdout()) cls.rank0.addHandler(cls.stdout_handler) cls.all_ranks.addHandler(cls.stdout_handler)