diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
index 51e15ea1..1371e61e 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
@@ -218,9 +218,7 @@ def train(config, workdir):
         val_set_loss = val_loss(config, eval_dl, eval_step_fn, state)

         epoch_metrics = {"epoch/train/loss": train_set_loss, "epoch/val/loss": val_set_loss}
-        logging.info("epoch: %d, val_loss: %.5e" % (state['epoch'], val_set_loss))
-        writer.add_scalar("epoch/val/loss", val_set_loss, global_step=state['epoch'])
-        log_epoch(state['epoch'], epoch_metrics, wandb_run,writer)
+        log_epoch(state['epoch'], epoch_metrics, wandb_run, writer)

         if (state['epoch'] != 0 and state['epoch'] % config.training.snapshot_freq == 0) or state['epoch'] == num_train_epochs:
             # Save the checkpoint.