From 6deab10563a4b22be26479d0a9b3db637c6099a9 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Thu, 20 Jun 2024 03:02:07 +0100 Subject: [PATCH] don't shuffle val dataset now val loss uses fixed random numbers --- src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py index 998df6920..5dd339038 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py @@ -112,7 +112,7 @@ def train(config, workdir): # Build dataloaders dataset_meta = DatasetMetadata(config.data.dataset_name) train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False) - eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False) + eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False, shuffle=False) # Initialize model. score_model = mutils.create_model(config)