diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py index 14c398cba..79d5ecd73 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py @@ -177,7 +177,8 @@ def train(config, workdir): if config.training.random_crop_size > 0: random_crop = torchvision.transforms.RandomCrop(config.training.random_crop_size) - + if config.model.loc_spec_channels > 0: + before = state['location_params'].module.params[0,20,20:30].clone().detach() for epoch in range(initial_epoch, num_train_epochs + 1): state['epoch'] = epoch train_set_loss = 0.0 @@ -226,3 +227,9 @@ def train(config, workdir): checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{state['epoch']}.pth") save_checkpoint(checkpoint_path, state) logging.info(f"epoch: {state['epoch']}, checkpoint saved to {checkpoint_path}") + + if config.model.loc_spec_channels > 0: + after = state['location_params'].module.params[0,20,20:30].clone().detach() + logging.info(f"Before: {before}") + logging.info(f"After: {after}") + logging.info(f"Diff: {after-before}")