From c3eaf4aef69a608ff07f96ca2d5ead6730d6a32e Mon Sep 17 00:00:00 2001
From: Joseph Attieh
Date: Fri, 29 Sep 2023 16:02:24 +0300
Subject: [PATCH 1/5] Logging valid steps in another file

---
 mammoth/bin/train.py            |  4 +++-
 mammoth/opts.py                 |  1 +
 mammoth/utils/logging.py        | 31 ++++++++++++++++++++++++++++++-
 mammoth/utils/report_manager.py |  8 ++++++--
 4 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/mammoth/bin/train.py b/mammoth/bin/train.py
index 4c1d9072..8327176f 100644
--- a/mammoth/bin/train.py
+++ b/mammoth/bin/train.py
@@ -15,7 +15,7 @@
 )
 from mammoth.utils.misc import set_random_seed
 # from mammoth.modules.embeddings import prepare_pretrained_embeddings
-from mammoth.utils.logging import init_logger, logger
+from mammoth.utils.logging import init_logger, logger,init_valid_logger
 from mammoth.models.model_saver import load_checkpoint
 from mammoth.train_single import main as single_main
 
@@ -145,6 +145,8 @@ def validate_slurm_node_opts(current_env, world_context, opts):
 
 def train(opts):
     init_logger(opts.log_file)
+    if opts.valid_log_file is not None and opts.valid_log_file !="":
+        init_valid_logger(opts.valid_log_file)
     ArgumentParser.validate_train_opts(opts)
     ArgumentParser.update_model_opts(opts)
     ArgumentParser.validate_model_opts(opts)
diff --git a/mammoth/opts.py b/mammoth/opts.py
index 22a3dfd6..4a105f62 100644
--- a/mammoth/opts.py
+++ b/mammoth/opts.py
@@ -23,6 +23,7 @@ def config_opts(parser):
 def _add_logging_opts(parser, is_train=True):
     group = parser.add_argument_group('Logging')
     group.add('--log_file', '-log_file', type=str, default="", help="Output logs to a file under this path.")
+    group.add('--valid_log_file', '-valid_log_file', type=str, default="", help="Output logs to a file under this path.")
     group.add(
         '--log_file_level',
         '-log_file_level',
diff --git a/mammoth/utils/logging.py b/mammoth/utils/logging.py
index 895a89b1..7ccce83a 100644
--- a/mammoth/utils/logging.py
+++ b/mammoth/utils/logging.py
@@ -6,6 +6,9 @@
 
 logger = logging.getLogger()
+valid_logger = logging.getLogger("valid_logger")
+if not valid_logger.hasHandlers:
+    valid_logger =None
 
 
 def init_logger(
     log_file=None,
@@ -29,11 +32,37 @@
         file_handler = logging.FileHandler(log_file)
         file_handler.setLevel(log_file_level)
         file_handler.setFormatter(log_format)
-        logger.addHandler(file_handler)
+        logger.addHandler(file_handler)
 
     return logger
 
 
+def init_valid_logger(
+    log_file=None,
+    log_file_level=logging.DEBUG,
+    rotate=False,
+    log_level=logging.DEBUG,
+    gpu_id='',
+):
+    log_format = logging.Formatter(f"[%(asctime)s %(process)s {gpu_id} %(levelname)s] %(message)s")
+    logger = logging.getLogger("valid_logger")
+    logger.setLevel(log_level)
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(log_format)
+    logger.handlers = [console_handler]
+
+    if log_file and log_file != '':
+        if rotate:
+            file_handler = RotatingFileHandler(log_file, maxBytes=1000000, backupCount=10, mode='a', buffering=1, delay=True)
+        else:
+            file_handler = logging.FileHandler(log_file, mode='a', buffering=1, delay=True)
+        file_handler.setLevel(log_file_level)
+        file_handler.setFormatter(log_format)
+        logger.addHandler(file_handler)
+    logger.propagate = False
+    return logger
+
+
 def log_lca_values(step, lca_logs, lca_params, opath, dump_logs=False):
     for k, v in lca_params.items():
         lca_sum = v.sum().item()
diff --git a/mammoth/utils/report_manager.py b/mammoth/utils/report_manager.py
index 822938d0..97e44e8a 100644
--- a/mammoth/utils/report_manager.py
+++ b/mammoth/utils/report_manager.py
@@ -4,7 +4,7 @@
 
 import mammoth
 
-from mammoth.utils.logging import logger
+from mammoth.utils.logging import logger, valid_logger
 
 
 def build_report_manager(opts, node_rank, local_rank):
@@ -51,6 +51,10 @@ def start(self):
     def log(self, *args, **kwargs):
         logger.info(*args, **kwargs)
 
+    def log_valid(self, *args, **kwargs):
+        if valid_logger is not None:
+            valid_logger.info(*args, **kwargs)
+
     def report_training(self, step, num_steps, learning_rate, patience, report_stats, multigpu=False):
         """
         This is the user-defined batch-level traing progress
@@ -142,5 +146,5 @@ def _report_step(self, lr, patience, step, train_stats=None, valid_stats=None):
         if valid_stats is not None:
             self.log('Validation perplexity: %g' % valid_stats.ppl())
             self.log('Validation accuracy: %g' % valid_stats.accuracy())
-
+            self.log_valid(f'Step {step}; lr: {lr}; ppl: {valid_stats.ppl()}; acc: {valid_stats.accuracy()}; xent: {valid_stats.xent()};' )
         self.maybe_log_tensorboard(valid_stats, "valid", lr, patience, step)

From ec596f4f770ffe4d060a15dac2a3a0d302da93fc Mon Sep 17 00:00:00 2001
From: Joseph Attieh
Date: Fri, 29 Sep 2023 16:02:54 +0300
Subject: [PATCH 2/5] Changed the logging message to log only metrics

---
 mammoth/utils/logging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mammoth/utils/logging.py b/mammoth/utils/logging.py
index 7ccce83a..18b884e8 100644
--- a/mammoth/utils/logging.py
+++ b/mammoth/utils/logging.py
@@ -43,7 +43,7 @@ def init_valid_logger(
     log_level=logging.DEBUG,
     gpu_id='',
 ):
-    log_format = logging.Formatter(f"[%(asctime)s %(process)s {gpu_id} %(levelname)s] %(message)s")
+    log_format = logging.Formatter(f"%(message)s")
     logger = logging.getLogger("valid_logger")
     logger.setLevel(log_level)
     console_handler = logging.StreamHandler()

From 41acdbbd16f40b5bbfd8a7fe09f40bf54c2ed863 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 2 Oct 2023 16:50:27 +0300
Subject: [PATCH 3/5] WIP: structured logging

---
 mammoth/bin/train.py            |  6 ++--
 mammoth/opts.py                 |  8 ++++-
 mammoth/utils/logging.py        | 63 +++++++++++++--------------------
 mammoth/utils/report_manager.py | 21 ++++++-----
 4 files changed, 46 insertions(+), 52 deletions(-)

diff --git a/mammoth/bin/train.py b/mammoth/bin/train.py
index 8327176f..330b6dde 100644
--- a/mammoth/bin/train.py
+++ b/mammoth/bin/train.py
@@ -15,7 +15,7 @@
 )
 from mammoth.utils.misc import set_random_seed
 # from mammoth.modules.embeddings import prepare_pretrained_embeddings
-from mammoth.utils.logging import init_logger, logger,init_valid_logger
+from mammoth.utils.logging import init_logger, logger
 from mammoth.models.model_saver import load_checkpoint
 from mammoth.train_single import main as single_main
 
@@ -144,9 +144,7 @@ def validate_slurm_node_opts(current_env, world_context, opts):
 
 
 def train(opts):
-    init_logger(opts.log_file)
-    if opts.valid_log_file is not None and opts.valid_log_file !="":
-        init_valid_logger(opts.valid_log_file)
+    init_logger(opts.log_file, structured_log_file=opts.structured_log_file)
     ArgumentParser.validate_train_opts(opts)
     ArgumentParser.update_model_opts(opts)
     ArgumentParser.validate_model_opts(opts)
diff --git a/mammoth/opts.py b/mammoth/opts.py
index 4a105f62..ca8a31d6 100644
--- a/mammoth/opts.py
+++ b/mammoth/opts.py
@@ -23,7 +23,13 @@ def config_opts(parser):
 def _add_logging_opts(parser, is_train=True):
     group = parser.add_argument_group('Logging')
     group.add('--log_file', '-log_file', type=str, default="", help="Output logs to a file under this path.")
-    group.add('--valid_log_file', '-valid_log_file', type=str, default="", help="Output logs to a file under this path.")
+    group.add(
+        '--structured_log_file',
+        '-structured_log_file',
+        type=str,
+        default="",
+        help="Output machine-readable structured logs to a file under this path."
+    )
     group.add(
         '--log_file_level',
         '-log_file_level',
diff --git a/mammoth/utils/logging.py b/mammoth/utils/logging.py
index 18b884e8..e904846b 100644
--- a/mammoth/utils/logging.py
+++ b/mammoth/utils/logging.py
@@ -1,14 +1,11 @@
 # -*- coding: utf-8 -*-
-import os
 import json
 import logging
 from logging.handlers import RotatingFileHandler
+from typing import Dict, Union
 
 logger = logging.getLogger()
-valid_logger = logging.getLogger("valid_logger")
-if not valid_logger.hasHandlers:
-    valid_logger =None
 
 
 def init_logger(
     log_file=None,
@@ -16,6 +13,7 @@
     rotate=False,
     log_level=logging.INFO,
     gpu_id='',
+    structured_log_file=None,
 ):
     log_format = logging.Formatter(f"[%(asctime)s %(process)s {gpu_id} %(levelname)s] %(message)s")
     logger = logging.getLogger()
@@ -32,46 +30,33 @@
         file_handler = logging.FileHandler(log_file)
         file_handler.setLevel(log_file_level)
         file_handler.setFormatter(log_format)
-        logger.addHandler(file_handler)
+        logger.addHandler(file_handler)
+
+    if structured_log_file:
+        init_structured_logger(structured_log_file)
 
     return logger
 
 
-def init_valid_logger(
+def init_structured_logger(
     log_file=None,
-    log_file_level=logging.DEBUG,
-    rotate=False,
-    log_level=logging.DEBUG,
-    gpu_id='',
 ):
-    log_format = logging.Formatter(f"%(message)s")
-    logger = logging.getLogger("valid_logger")
-    logger.setLevel(log_level)
-    console_handler = logging.StreamHandler()
-    console_handler.setFormatter(log_format)
-    logger.handlers = [console_handler]
-
-    if log_file and log_file != '':
-        if rotate:
-            file_handler = RotatingFileHandler(log_file, maxBytes=1000000, backupCount=10, mode='a', buffering=1, delay=True)
-        else:
-            file_handler = logging.FileHandler(log_file, mode='a', buffering=1, delay=True)
-        file_handler.setLevel(log_file_level)
-        file_handler.setFormatter(log_format)
-        logger.addHandler(file_handler)
+    # Log should be parseable as a jsonl file. Format should not include anything extra.
+    log_format = logging.Formatter("%(message)s")
+    logger = logging.getLogger("structured_logger")
+    logger.setLevel(logging.INFO)
+    file_handler = logging.FileHandler(log_file, mode='a', buffering=1, delay=True)
+    file_handler.setLevel(logging.INFO)
+    file_handler.setFormatter(log_format)
+    logger.handlers = [file_handler]
     logger.propagate = False
-    return logger
-
-
-def log_lca_values(step, lca_logs, lca_params, opath, dump_logs=False):
-    for k, v in lca_params.items():
-        lca_sum = v.sum().item()
-        lca_mean = v.mean().item()
-        lca_logs[k][f'STEP_{step}'] = {'sum': lca_sum, 'mean': lca_mean}
-        if dump_logs:
-            if os.path.exists(opath):
-                os.system(f'cp {opath} {opath}.previous')
-            with open(opath, 'w+') as f:
-                json.dump(lca_logs, f)
-            logger.info(f'dumped LCA logs in {opath}')
+
+
+def structured_logging(obj: Dict[str, Union[str, int, float]]):
+    structured_logger = logging.getLogger("structured_logger")
+    if not structured_logger.hasHandlers:
+        return
+    try:
+        structured_logger.info(json.dumps(obj))
+    except Exception:
+        pass
diff --git a/mammoth/utils/report_manager.py b/mammoth/utils/report_manager.py
index 97e44e8a..894f757a 100644
--- a/mammoth/utils/report_manager.py
+++ b/mammoth/utils/report_manager.py
@@ -4,7 +4,7 @@
 
 import mammoth
 
-from mammoth.utils.logging import logger, valid_logger
+from mammoth.utils.logging import logger, structured_logging
 
 
 def build_report_manager(opts, node_rank, local_rank):
@@ -51,10 +51,6 @@ def start(self):
     def log(self, *args, **kwargs):
         logger.info(*args, **kwargs)
 
-    def log_valid(self, *args, **kwargs):
-        if valid_logger is not None:
-            valid_logger.info(*args, **kwargs)
-
     def report_training(self, step, num_steps, learning_rate, patience, report_stats, multigpu=False):
         """
         This is the user-defined batch-level traing progress
@@ -144,7 +140,16 @@ def _report_step(self, lr, patience, step, train_stats=None, valid_stats=None):
             self.maybe_log_tensorboard(train_stats, "train", lr, patience, step)
 
         if valid_stats is not None:
-            self.log('Validation perplexity: %g' % valid_stats.ppl())
-            self.log('Validation accuracy: %g' % valid_stats.accuracy())
-            self.log_valid(f'Step {step}; lr: {lr}; ppl: {valid_stats.ppl()}; acc: {valid_stats.accuracy()}; xent: {valid_stats.xent()};' )
+            ppl = valid_stats.ppl()
+            acc = valid_stats.accuracy()
+            self.log('Validation perplexity: %g', ppl)
+            self.log('Validation accuracy: %g', acc)
+            structured_logging({
+                'type': 'validation',
+                'step': step,
+                'learning_rate': lr,
+                'perplexity': ppl,
+                'accuracy': acc,
+                'crossentropy': valid_stats.xent(),
+            })
         self.maybe_log_tensorboard(valid_stats, "valid", lr, patience, step)

From db11ffa09ecbff8fa0145a8c849a8451f11369b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 2 Oct 2023 18:04:27 +0300
Subject: [PATCH 4/5] Opts is Namespace, not a dict

configargparse keeps on sucking
---
 mammoth/inputters/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mammoth/inputters/dataset.py b/mammoth/inputters/dataset.py
index 6bb2d1a9..fe9e5cf2 100644
--- a/mammoth/inputters/dataset.py
+++ b/mammoth/inputters/dataset.py
@@ -179,7 +179,7 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool =
     vocabs = {'src': src_vocab, 'tgt': tgt_vocab}
     corpus_opts = opts.tasks[task.corpus_id]
     transforms_to_apply = corpus_opts.get('transforms', None)
-    transforms_to_apply = transforms_to_apply or opts.get('transforms', None)
+    transforms_to_apply = transforms_to_apply or opts.transforms
     transforms_to_apply = transforms_to_apply or []
     transforms_cls = make_transforms(
         opts,

From 1fdd7a65ca9db2063e09f3d7f44bd51e3fd6b983 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 2 Oct 2023 18:13:14 +0300
Subject: [PATCH 5/5] No such flag buffering

---
 mammoth/utils/logging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mammoth/utils/logging.py b/mammoth/utils/logging.py
index e904846b..3b923955 100644
--- a/mammoth/utils/logging.py
+++ b/mammoth/utils/logging.py
@@ -45,7 +45,7 @@ def init_structured_logger(
     log_format = logging.Formatter("%(message)s")
     logger = logging.getLogger("structured_logger")
     logger.setLevel(logging.INFO)
-    file_handler = logging.FileHandler(log_file, mode='a', buffering=1, delay=True)
+    file_handler = logging.FileHandler(log_file, mode='a', delay=True)
    file_handler.setLevel(logging.INFO)
     file_handler.setFormatter(log_format)
     logger.handlers = [file_handler]
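
Note on consuming the structured log: after PATCH 3/5 and 5/5, init_structured_logger writes whatever structured_logging() receives as one json.dumps-encoded object per line, so the file passed via --structured_log_file can be read back as JSON Lines. Below is a minimal reader sketch; the script name, command-line handling and printing are illustrative and not part of the patches, while the field names ('type', 'step', 'learning_rate', 'perplexity', 'accuracy', 'crossentropy') come from the structured_logging(...) call in _report_step.

# read_structured_log.py -- illustrative sketch, not part of the patch series above.
# Assumes the file written by init_structured_logger: one JSON object per line (jsonl).
import json
import sys


def read_validation_records(path):
    """Yield the validation records emitted by structured_logging()."""
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            # Key names follow the structured_logging() call in PATCH 3/5.
            if record.get('type') == 'validation':
                yield record


if __name__ == '__main__':
    # Usage: python read_structured_log.py path/to/structured.log
    for rec in read_validation_records(sys.argv[1]):
        print(f"step {rec['step']}: ppl={rec['perplexity']:.2f} acc={rec['accuracy']:.2f}")

Because each validation step appends exactly one record, the file can be tailed while training is running or loaded afterwards for plotting, without parsing the human-readable console log.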