From 6f1060bf6ea45c2e75e36b4b1b9a86f2858327e1 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Mon, 24 Jun 2024 14:18:30 +0100
Subject: [PATCH] don't duplicate logging for validation loss each epoch

---
 src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
index 51e15ea1..1371e61e 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
@@ -218,9 +218,7 @@ def train(config, workdir):
         val_set_loss = val_loss(config, eval_dl, eval_step_fn, state)
         epoch_metrics = {"epoch/train/loss": train_set_loss, "epoch/val/loss": val_set_loss}
 
-        logging.info("epoch: %d, val_loss: %.5e" % (state['epoch'], val_set_loss))
-        writer.add_scalar("epoch/val/loss", val_set_loss, global_step=state['epoch'])
-        log_epoch(state['epoch'], epoch_metrics, wandb_run,writer)
+        log_epoch(state['epoch'], epoch_metrics, wandb_run, writer)
 
         if (state['epoch'] != 0 and state['epoch'] % config.training.snapshot_freq == 0) or state['epoch'] == num_train_epochs:
             # Save the checkpoint.
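
The removed logging.info and writer.add_scalar calls repeated what log_epoch already records for "epoch/val/loss", so each epoch's validation loss was logged twice. Below is a minimal sketch of such a helper, assuming it fans each metric out to the TensorBoard writer, the wandb run, and the text log; the actual log_epoch in ml_downscaling_emulator may differ in signature and detail.

import logging


def log_epoch(epoch, epoch_metrics, wandb_run, writer):
    """Record per-epoch metrics once, to every configured sink (sketch).

    Assumed behaviour: each metric is written to the TensorBoard
    SummaryWriter, logged to the wandb run, and summarised in the text
    log, so callers do not need their own logging.info/add_scalar calls.
    """
    for name, value in epoch_metrics.items():
        # e.g. name = "epoch/val/loss", value = scalar loss for this epoch
        writer.add_scalar(name, value, global_step=epoch)
    if wandb_run is not None:
        wandb_run.log({"epoch": epoch, **epoch_metrics})
    logging.info(
        "epoch: %d, %s",
        epoch,
        ", ".join(f"{k}: {v:.5e}" for k, v in epoch_metrics.items()),
    )

With the helper as the single sink, the validation loss reaches the writer, wandb, and the log exactly once per epoch, which is why the two dropped lines are redundant rather than lost information.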