Skip to content

Commit

Permalink
Merge pull request #26 from Helsinki-NLP/fix/valid_logs
Browse files Browse the repository at this point in the history
feature: valid logs
  • Loading branch information
Waino authored Oct 2, 2023
2 parents 9232208 + 1fdd7a6 commit bf87d1e
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 19 deletions.
2 changes: 1 addition & 1 deletion mammoth/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def validate_slurm_node_opts(current_env, world_context, opts):


def train(opts):
init_logger(opts.log_file)
init_logger(opts.log_file, structured_log_file=opts.structured_log_file)
ArgumentParser.validate_train_opts(opts)
ArgumentParser.update_model_opts(opts)
ArgumentParser.validate_model_opts(opts)
Expand Down
2 changes: 1 addition & 1 deletion mammoth/inputters/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool =
vocabs = {'src': src_vocab, 'tgt': tgt_vocab}
corpus_opts = opts.tasks[task.corpus_id]
transforms_to_apply = corpus_opts.get('transforms', None)
transforms_to_apply = transforms_to_apply or opts.get('transforms', None)
transforms_to_apply = transforms_to_apply or opts.transforms
transforms_to_apply = transforms_to_apply or []
transforms_cls = make_transforms(
opts,
Expand Down
7 changes: 7 additions & 0 deletions mammoth/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ def config_opts(parser):
def _add_logging_opts(parser, is_train=True):
group = parser.add_argument_group('Logging')
group.add('--log_file', '-log_file', type=str, default="", help="Output logs to a file under this path.")
group.add(
'--structured_log_file',
'-structured_log_file',
type=str,
default="",
help="Output machine-readable structured logs to a file under this path."
)
group.add(
'--log_file_level',
'-log_file_level',
Expand Down
40 changes: 27 additions & 13 deletions mammoth/utils/logging.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
import os
import json
import logging
from logging.handlers import RotatingFileHandler
from typing import Dict, Union

logger = logging.getLogger()

Expand All @@ -13,6 +13,7 @@ def init_logger(
rotate=False,
log_level=logging.INFO,
gpu_id='',
structured_log_file=None,
):
log_format = logging.Formatter(f"[%(asctime)s %(process)s {gpu_id} %(levelname)s] %(message)s")
logger = logging.getLogger()
Expand All @@ -31,18 +32,31 @@ def init_logger(
file_handler.setFormatter(log_format)
logger.addHandler(file_handler)

return logger
if structured_log_file:
init_structured_logger(structured_log_file)

return logger

def log_lca_values(step, lca_logs, lca_params, opath, dump_logs=False):
    """Record per-parameter LCA statistics for one step and optionally dump them.

    For every entry of ``lca_params`` the sum and mean of its value are stored
    under ``lca_logs[key]['STEP_<step>']``. ``lca_logs`` is mutated in place.

    Args:
        step: Step identifier used to build the ``STEP_<step>`` key.
        lca_logs: Mapping from parameter key to a per-step dict; an entry for
            each key of ``lca_params`` must already exist.
        lca_params: Mapping from parameter key to an object exposing
            ``.sum().item()`` and ``.mean().item()``.
        opath: Path of the JSON file the logs are dumped to.
        dump_logs: When true, write ``lca_logs`` to ``opath`` as JSON, first
            copying any existing file to ``<opath>.previous``.
    """
    # Function-scope import so this block does not depend on a file-level import.
    import shutil

    for name, values in lca_params.items():
        lca_logs[name][f'STEP_{step}'] = {
            'sum': values.sum().item(),
            'mean': values.mean().item(),
        }

    if dump_logs:
        if os.path.exists(opath):
            # Copy without going through a shell: paths containing spaces or
            # shell metacharacters are handled correctly, and no external
            # `cp` binary is required.
            shutil.copy(opath, f'{opath}.previous')
        with open(opath, 'w') as f:
            json.dump(lca_logs, f)
        logger.info('dumped LCA logs in %s', opath)
def init_structured_logger(
    log_file=None,
):
    """Configure the ``"structured_logger"`` logger to append to ``log_file``.

    The logger gets exactly one INFO-level file handler whose format is the
    bare message, and propagation is disabled so structured records do not
    reach the handlers of ancestor loggers.

    Args:
        log_file: Path of the file to append to. If it is ``None`` or empty,
            the function returns without changing the logger.
    """
    if not log_file:
        # FileHandler cannot be built from None; without a path there is
        # nothing to configure.
        return

    # Log should be parseable as a jsonl file. Format should not include anything extra.
    log_format = logging.Formatter("%(message)s")
    structured_logger = logging.getLogger("structured_logger")
    structured_logger.setLevel(logging.INFO)
    # delay=True postpones opening the file until the first record is emitted.
    file_handler = logging.FileHandler(log_file, mode='a', delay=True)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(log_format)

    # Close handlers left over from an earlier initialisation so that their
    # open file streams are released instead of being silently dropped.
    for stale_handler in list(structured_logger.handlers):
        structured_logger.removeHandler(stale_handler)
        stale_handler.close()
    structured_logger.addHandler(file_handler)
    structured_logger.propagate = False


def structured_logging(obj: Dict[str, Union[str, int, float]]):
    """Emit ``obj`` as a single JSON line on the ``"structured_logger"`` logger.

    This is a no-op when that logger has no handlers of its own. Failures
    while serialising or writing are reported as a warning on the root logger
    and never raised to the caller.

    Args:
        obj: JSON-serialisable mapping to log.
    """
    structured_logger = logging.getLogger("structured_logger")
    # Inspect this logger's own handler list. `hasHandlers` is a method, so the
    # bare attribute is always truthy; calling it would also count handlers of
    # ancestor loggers, which would leak JSON lines into the regular log
    # whenever the structured logger has not been initialised.
    if not structured_logger.handlers:
        return
    try:
        structured_logger.info(json.dumps(obj))
    except Exception:
        # Best effort: structured logging must never interrupt the caller,
        # but the failure should still be visible in the regular log.
        logging.getLogger().warning("Failed to write structured log entry", exc_info=True)
17 changes: 13 additions & 4 deletions mammoth/utils/report_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import mammoth

from mammoth.utils.logging import logger
from mammoth.utils.logging import logger, structured_logging


def build_report_manager(opts, node_rank, local_rank):
Expand Down Expand Up @@ -140,7 +140,16 @@ def _report_step(self, lr, patience, step, train_stats=None, valid_stats=None):
self.maybe_log_tensorboard(train_stats, "train", lr, patience, step)

if valid_stats is not None:
self.log('Validation perplexity: %g' % valid_stats.ppl())
self.log('Validation accuracy: %g' % valid_stats.accuracy())

ppl = valid_stats.ppl()
acc = valid_stats.accuracy()
self.log('Validation perplexity: %g', ppl)
self.log('Validation accuracy: %g', acc)
structured_logging({
'type': 'validation',
'step': step,
'learning_rate': lr,
'perplexity': ppl,
'accuracy': acc,
'crossentropy': valid_stats.xent(),
})
self.maybe_log_tensorboard(valid_stats, "valid", lr, patience, step)

0 comments on commit bf87d1e

Please sign in to comment.