diff --git a/mammoth/bin/train.py b/mammoth/bin/train.py index 4c1d9072..8327176f 100644 --- a/mammoth/bin/train.py +++ b/mammoth/bin/train.py @@ -15,7 +15,7 @@ ) from mammoth.utils.misc import set_random_seed # from mammoth.modules.embeddings import prepare_pretrained_embeddings -from mammoth.utils.logging import init_logger, logger +from mammoth.utils.logging import init_logger, logger,init_valid_logger from mammoth.models.model_saver import load_checkpoint from mammoth.train_single import main as single_main @@ -145,6 +145,8 @@ def validate_slurm_node_opts(current_env, world_context, opts): def train(opts): init_logger(opts.log_file) + if opts.valid_log_file is not None and opts.valid_log_file !="": + init_valid_logger(opts.valid_log_file) ArgumentParser.validate_train_opts(opts) ArgumentParser.update_model_opts(opts) ArgumentParser.validate_model_opts(opts) diff --git a/mammoth/opts.py b/mammoth/opts.py index 22a3dfd6..4a105f62 100644 --- a/mammoth/opts.py +++ b/mammoth/opts.py @@ -23,6 +23,7 @@ def config_opts(parser): def _add_logging_opts(parser, is_train=True): group = parser.add_argument_group('Logging') group.add('--log_file', '-log_file', type=str, default="", help="Output logs to a file under this path.") + group.add('--valid_log_file', '-valid_log_file', type=str, default="", help="Output logs to a file under this path.") group.add( '--log_file_level', '-log_file_level', diff --git a/mammoth/utils/logging.py b/mammoth/utils/logging.py index 895a89b1..7ccce83a 100644 --- a/mammoth/utils/logging.py +++ b/mammoth/utils/logging.py @@ -6,6 +6,9 @@ logger = logging.getLogger() +valid_logger = logging.getLogger("valid_logger") +if not valid_logger.hasHandlers: + valid_logger =None def init_logger( log_file=None, @@ -29,11 +32,37 @@ def init_logger( file_handler = logging.FileHandler(log_file) file_handler.setLevel(log_file_level) file_handler.setFormatter(log_format) - logger.addHandler(file_handler) + logger.addHandler(file_handler) return logger +def init_valid_logger( + log_file=None, + log_file_level=logging.DEBUG, + rotate=False, + log_level=logging.DEBUG, + gpu_id='', +): + log_format = logging.Formatter(f"[%(asctime)s %(process)s {gpu_id} %(levelname)s] %(message)s") + logger = logging.getLogger("valid_logger") + logger.setLevel(log_level) + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_format) + logger.handlers = [console_handler] + + if log_file and log_file != '': + if rotate: + file_handler = RotatingFileHandler(log_file, maxBytes=1000000, backupCount=10, mode='a', buffering=1, delay=True) + else: + file_handler = logging.FileHandler(log_file, mode='a', buffering=1, delay=True) + file_handler.setLevel(log_file_level) + file_handler.setFormatter(log_format) + logger.addHandler(file_handler) + logger.propagate = False + return logger + + def log_lca_values(step, lca_logs, lca_params, opath, dump_logs=False): for k, v in lca_params.items(): lca_sum = v.sum().item() diff --git a/mammoth/utils/report_manager.py b/mammoth/utils/report_manager.py index 822938d0..97e44e8a 100644 --- a/mammoth/utils/report_manager.py +++ b/mammoth/utils/report_manager.py @@ -4,7 +4,7 @@ import mammoth -from mammoth.utils.logging import logger +from mammoth.utils.logging import logger, valid_logger def build_report_manager(opts, node_rank, local_rank): @@ -51,6 +51,10 @@ def start(self): def log(self, *args, **kwargs): logger.info(*args, **kwargs) + def log_valid(self, *args, **kwargs): + if valid_logger is not None: + valid_logger.info(*args, **kwargs) + def report_training(self, step, num_steps, learning_rate, patience, report_stats, multigpu=False): """ This is the user-defined batch-level traing progress @@ -142,5 +146,5 @@ def _report_step(self, lr, patience, step, train_stats=None, valid_stats=None): if valid_stats is not None: self.log('Validation perplexity: %g' % valid_stats.ppl()) self.log('Validation accuracy: %g' % valid_stats.accuracy()) - + self.log_valid(f'Step {step}; lr: {lr}; ppl: {valid_stats.ppl()}; acc: {valid_stats.accuracy()}; xent: {valid_stats.xent()};' ) self.maybe_log_tensorboard(valid_stats, "valid", lr, patience, step)