Logging valid steps in another file
Authored and committed by Joseph Attieh on Sep 29, 2023
Parent: 61910fc · Commit: 9325e01
Showing 4 changed files with 40 additions and 4 deletions.
mammoth/bin/train.py (4 changes: 3 additions & 1 deletion)
@@ -15,7 +15,7 @@
 )
 from mammoth.utils.misc import set_random_seed
 # from mammoth.modules.embeddings import prepare_pretrained_embeddings
-from mammoth.utils.logging import init_logger, logger
+from mammoth.utils.logging import init_logger, logger, init_valid_logger
 
 from mammoth.models.model_saver import load_checkpoint
 from mammoth.train_single import main as single_main
@@ -145,6 +145,8 @@ def validate_slurm_node_opts(current_env, world_context, opts):
 
 def train(opts):
     init_logger(opts.log_file)
+    if opts.valid_log_file is not None and opts.valid_log_file != "":
+        init_valid_logger(opts.valid_log_file)
     ArgumentParser.validate_train_opts(opts)
     ArgumentParser.update_model_opts(opts)
     ArgumentParser.validate_model_opts(opts)
mammoth/opts.py (1 change: 1 addition & 0 deletions)
@@ -23,6 +23,7 @@ def config_opts(parser):
 def _add_logging_opts(parser, is_train=True):
     group = parser.add_argument_group('Logging')
     group.add('--log_file', '-log_file', type=str, default="", help="Output logs to a file under this path.")
+    group.add('--valid_log_file', '-valid_log_file', type=str, default="", help="Output validation logs to a file under this path.")
     group.add(
         '--log_file_level',
         '-log_file_level',
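Together with the guard added in mammoth/bin/train.py above, the new option is opt-in: the default empty string means no separate validation log is opened. A minimal sketch of that behaviour, using plain argparse as a stand-in for mammoth's option parser (everything below is illustrative and not part of the commit):

# Sketch only: plain argparse standing in for mammoth's parser, showing how the
# guard in train() reacts to the new flag; the default "" means "no separate valid log".
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--valid_log_file', '-valid_log_file', type=str, default="")

for argv in ([], ['-valid_log_file', 'valid.log']):
    opts = parser.parse_args(argv)
    wants_valid_log = opts.valid_log_file is not None and opts.valid_log_file != ""
    print(argv, '->', wants_valid_log)  # [] -> False, ['-valid_log_file', 'valid.log'] -> True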
mammoth/utils/logging.py (31 changes: 30 additions & 1 deletion)
@@ -6,6 +6,9 @@
 
 logger = logging.getLogger()
 
+valid_logger = logging.getLogger("valid_logger")
+if not valid_logger.hasHandlers:
+    valid_logger = None
 
 def init_logger(
     log_file=None,
@@ -29,11 +32,37 @@ def init_logger(
         file_handler = logging.FileHandler(log_file)
         file_handler.setLevel(log_file_level)
         file_handler.setFormatter(log_format)
-        logger.addHandler(file_handler)
+        logger.addHandler(file_handler)
 
     return logger
+
+def init_valid_logger(
+    log_file=None,
+    log_file_level=logging.DEBUG,
+    rotate=False,
+    log_level=logging.DEBUG,
+    gpu_id='',
+):
+    log_format = logging.Formatter(f"[%(asctime)s %(process)s {gpu_id} %(levelname)s] %(message)s")
+    logger = logging.getLogger("valid_logger")
+    logger.setLevel(log_level)
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(log_format)
+    logger.handlers = [console_handler]
+
+    if log_file and log_file != '':
+        if rotate:
+            file_handler = RotatingFileHandler(log_file, maxBytes=1000000, backupCount=10, mode='a', delay=True)
+        else:
+            file_handler = logging.FileHandler(log_file, mode='a', delay=True)
+        file_handler.setLevel(log_file_level)
+        file_handler.setFormatter(log_format)
+        logger.addHandler(file_handler)
+    logger.propagate = False
+    return logger
+
 
 
 def log_lca_values(step, lca_logs, lca_params, opath, dump_logs=False):
     for k, v in lca_params.items():
         lca_sum = v.sum().item()
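The new function relies on Python's named-logger registry: every call to logging.getLogger("valid_logger") returns the same underlying Logger object, and propagate = False keeps its records out of the root logger's handlers (and therefore out of the main -log_file). This is also why the handler attached here is later visible in report_manager.py, which imports valid_logger at module load time. A standalone sketch of the pattern, using only the standard library and made-up file names:

# Standalone sketch of the pattern used by init_valid_logger (stdlib only; the
# file names are made up). The named logger writes to its own file and, because
# propagate is False, its records never reach the root logger's handlers.
import logging

logging.basicConfig(filename="train.log", level=logging.INFO)  # stand-in for init_logger()

valid_logger = logging.getLogger("valid_logger")
valid_logger.setLevel(logging.INFO)
valid_logger.addHandler(logging.FileHandler("valid.log", mode="a", delay=True))
valid_logger.propagate = False

logging.getLogger().info("ends up in train.log only")
valid_logger.info("ends up in valid.log only")
assert logging.getLogger("valid_logger") is valid_logger  # same object everywhere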
mammoth/utils/report_manager.py (8 changes: 6 additions & 2 deletions)
@@ -4,7 +4,7 @@
 
 import mammoth
 
-from mammoth.utils.logging import logger
+from mammoth.utils.logging import logger, valid_logger
 
 
 def build_report_manager(opts, node_rank, local_rank):
@@ -51,6 +51,10 @@ def start(self):
     def log(self, *args, **kwargs):
         logger.info(*args, **kwargs)
 
+    def log_valid(self, *args, **kwargs):
+        if valid_logger is not None:
+            valid_logger.info(*args, **kwargs)
+
     def report_training(self, step, num_steps, learning_rate, patience, report_stats, multigpu=False):
         """
         This is the user-defined batch-level training progress
@@ -142,5 +146,5 @@ def _report_step(self, lr, patience, step, train_stats=None, valid_stats=None):
         if valid_stats is not None:
             self.log('Validation perplexity: %g' % valid_stats.ppl())
             self.log('Validation accuracy: %g' % valid_stats.accuracy())
-
+            self.log_valid(f'Step {step}; lr: {lr}; ppl: {valid_stats.ppl()}; acc: {valid_stats.accuracy()}; xent: {valid_stats.xent()};')
             self.maybe_log_tensorboard(valid_stats, "valid", lr, patience, step)
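Net effect: when a validation log file is configured, each validation report appends one line to it through log_valid(). A rough illustration of the record format (the step, learning rate, and metric values below are placeholders, not real results; the timestamp prefix comes from the formatter set in init_valid_logger):

# Illustration only: placeholder numbers and an assumed valid.log path.
step, lr = 5000, 0.0002
ppl, acc, xent = 12.3, 61.2, 2.5
line = f'Step {step}; lr: {lr}; ppl: {ppl}; acc: {acc}; xent: {xent};'
# written to valid.log roughly as:
# [2023-09-29 12:00:00,000 12345  INFO] Step 5000; lr: 0.0002; ppl: 12.3; acc: 61.2; xent: 2.5;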