Skip to content

Commit

Permalink
compute and log val loss before 1st epoch
Browse files Browse the repository at this point in the history
partly for debugging but also to see improvements
  • Loading branch information
henryaddison committed Jun 24, 2024
1 parent df90518 commit 67cd68b
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ def train(config, workdir):
if config.training.random_crop_size > 0:
random_crop = torchvision.transforms.RandomCrop(config.training.random_crop_size)

# log val loss before any training
if int(state['epoch']) == 0:
val_set_loss = val_loss(config, eval_dl, eval_step_fn, state)
epoch_metrics = {"epoch/val/loss": val_set_loss}
log_epoch(state['epoch'], epoch_metrics, wandb_run, writer)

for epoch in range(initial_epoch, num_train_epochs + 1):
state['epoch'] = epoch
train_set_loss = 0.0
Expand Down

0 comments on commit 67cd68b

Please sign in to comment.