diff --git a/mammoth/bin/train.py b/mammoth/bin/train.py
index 4c1d9072..330b6dde 100644
--- a/mammoth/bin/train.py
+++ b/mammoth/bin/train.py
@@ -144,7 +144,7 @@ def validate_slurm_node_opts(current_env, world_context, opts):
 
 
 def train(opts):
-    init_logger(opts.log_file)
+    init_logger(opts.log_file, structured_log_file=opts.structured_log_file)
     ArgumentParser.validate_train_opts(opts)
     ArgumentParser.update_model_opts(opts)
     ArgumentParser.validate_model_opts(opts)
diff --git a/mammoth/inputters/dataset.py b/mammoth/inputters/dataset.py
index 6bb2d1a9..fe9e5cf2 100644
--- a/mammoth/inputters/dataset.py
+++ b/mammoth/inputters/dataset.py
@@ -179,7 +179,7 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool =
     vocabs = {'src': src_vocab, 'tgt': tgt_vocab}
     corpus_opts = opts.tasks[task.corpus_id]
     transforms_to_apply = corpus_opts.get('transforms', None)
-    transforms_to_apply = transforms_to_apply or opts.get('transforms', None)
+    transforms_to_apply = transforms_to_apply or opts.transforms
     transforms_to_apply = transforms_to_apply or []
     transforms_cls = make_transforms(
         opts,
diff --git a/mammoth/opts.py b/mammoth/opts.py
index 22a3dfd6..ca8a31d6 100644
--- a/mammoth/opts.py
+++ b/mammoth/opts.py
@@ -23,6 +23,13 @@ def config_opts(parser):
 def _add_logging_opts(parser, is_train=True):
     group = parser.add_argument_group('Logging')
     group.add('--log_file', '-log_file', type=str, default="", help="Output logs to a file under this path.")
+    group.add(
+        '--structured_log_file',
+        '-structured_log_file',
+        type=str,
+        default="",
+        help="Output machine-readable structured logs to a file under this path."
+    )
     group.add(
         '--log_file_level',
         '-log_file_level',
diff --git a/mammoth/utils/logging.py b/mammoth/utils/logging.py
index 895a89b1..3b923955 100644
--- a/mammoth/utils/logging.py
+++ b/mammoth/utils/logging.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
-import os
 import json
 import logging
 from logging.handlers import RotatingFileHandler
+from typing import Dict, Union
 
 
 logger = logging.getLogger()
@@ -13,6 +13,7 @@ def init_logger(
     rotate=False,
     log_level=logging.INFO,
     gpu_id='',
+    structured_log_file=None,
 ):
     log_format = logging.Formatter(f"[%(asctime)s %(process)s {gpu_id} %(levelname)s] %(message)s")
     logger = logging.getLogger()
@@ -31,18 +32,31 @@ def init_logger(
         file_handler.setFormatter(log_format)
         logger.addHandler(file_handler)
 
-    return logger
+    if structured_log_file:
+        init_structured_logger(structured_log_file)
+    return logger
 
 
-def log_lca_values(step, lca_logs, lca_params, opath, dump_logs=False):
-    for k, v in lca_params.items():
-        lca_sum = v.sum().item()
-        lca_mean = v.mean().item()
-        lca_logs[k][f'STEP_{step}'] = {'sum': lca_sum, 'mean': lca_mean}
-    if dump_logs:
-        if os.path.exists(opath):
-            os.system(f'cp {opath} {opath}.previous')
-        with open(opath, 'w+') as f:
-            json.dump(lca_logs, f)
-        logger.info(f'dumped LCA logs in {opath}')
+def init_structured_logger(
+    log_file=None,
+):
+    # Log should be parseable as a JSONL file; the format should not include anything extra.
+    log_format = logging.Formatter("%(message)s")
+    logger = logging.getLogger("structured_logger")
+    logger.setLevel(logging.INFO)
+    file_handler = logging.FileHandler(log_file, mode='a', delay=True)
+    file_handler.setLevel(logging.INFO)
+    file_handler.setFormatter(log_format)
+    logger.handlers = [file_handler]
+    logger.propagate = False
+
+
+def structured_logging(obj: Dict[str, Union[str, int, float]]):
+    structured_logger = logging.getLogger("structured_logger")
+    if not structured_logger.handlers:
+        return
+    try:
+        structured_logger.info(json.dumps(obj))
+    except Exception:
+        pass
 
diff --git a/mammoth/utils/report_manager.py b/mammoth/utils/report_manager.py
index 822938d0..894f757a 100644
--- a/mammoth/utils/report_manager.py
+++ b/mammoth/utils/report_manager.py
@@ -4,7 +4,7 @@
 
 import mammoth
 
-from mammoth.utils.logging import logger
+from mammoth.utils.logging import logger, structured_logging
 
 
 def build_report_manager(opts, node_rank, local_rank):
@@ -140,7 +140,16 @@ def _report_step(self, lr, patience, step, train_stats=None, valid_stats=None):
             self.maybe_log_tensorboard(train_stats, "train", lr, patience, step)
 
         if valid_stats is not None:
-            self.log('Validation perplexity: %g' % valid_stats.ppl())
-            self.log('Validation accuracy: %g' % valid_stats.accuracy())
-
+            ppl = valid_stats.ppl()
+            acc = valid_stats.accuracy()
+            self.log('Validation perplexity: %g', ppl)
+            self.log('Validation accuracy: %g', acc)
+            structured_logging({
+                'type': 'validation',
+                'step': step,
+                'learning_rate': lr,
+                'perplexity': ppl,
+                'accuracy': acc,
+                'crossentropy': valid_stats.xent(),
+            })
             self.maybe_log_tensorboard(valid_stats, "valid", lr, patience, step)
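Note, not part of the patch: with --structured_log_file set, each call to structured_logging appends one JSON object per line, so the resulting file can be read as JSONL. A minimal consumption sketch, assuming a hypothetical path logs/structured.jsonl passed via --structured_log_file:

    import json

    # Each line of the structured log is a standalone JSON object written by structured_logging().
    records = []
    with open('logs/structured.jsonl') as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))

    # Example: print validation perplexity per training step.
    for record in records:
        if record.get('type') == 'validation':
            print(record['step'], record['perplexity'])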