Skip to content

Commit

Permalink
log change in loc spec params
Browse files Browse the repository at this point in the history
just to be sure they are being trained
  • Loading branch information
henryaddison committed Aug 15, 2024
1 parent 72c70e3 commit 0f28426
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ def train(config, workdir):

if config.training.random_crop_size > 0:
random_crop = torchvision.transforms.RandomCrop(config.training.random_crop_size)

if config.model.loc_spec_channels > 0:
before = state['location_params'].module.params[0,20,20:30].clone().detach()
for epoch in range(initial_epoch, num_train_epochs + 1):
state['epoch'] = epoch
train_set_loss = 0.0
Expand Down Expand Up @@ -226,3 +227,9 @@ def train(config, workdir):
checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{state['epoch']}.pth")
save_checkpoint(checkpoint_path, state)
logging.info(f"epoch: {state['epoch']}, checkpoint saved to {checkpoint_path}")

if config.model.loc_spec_channels > 0:
after = state['location_params'].module.params[0,20,20:30].clone().detach()
logging.info(f"Before: {before}")
logging.info(f"After: {after}")
logging.info(f"Diff: {after-before}")

0 comments on commit 0f28426

Please sign in to comment.