From 8894f7f99f4f24e2512b66be1035263088611032 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Wed, 14 Feb 2024 14:37:21 +0000 Subject: [PATCH] allow setting which dataset a transform should be fitted to, so it doesn't have to be the same as the dataset it is applied to during sampling. During training it is always the same dataset for both (though potentially even this could be uncoupled) --- bin/predict.py | 8 +++++++- src/ml_downscaling_emulator/bin/evaluate.py | 8 +++++++- src/ml_downscaling_emulator/deterministic/run_lib.py | 2 ++ .../score_sde_pytorch_hja22/run_lib.py | 4 ++-- src/ml_downscaling_emulator/torch.py | 3 +++ 5 files changed, 21 insertions(+), 4 deletions(-) diff --git a/bin/predict.py b/bin/predict.py index def13b88f..79ec5ab35 100644 --- a/bin/predict.py +++ b/bin/predict.py @@ -240,6 +240,7 @@ def main( epoch: int = typer.Option(...), batch_size: int = None, num_samples: int = 3, + input_transform_dataset: str = None, input_transform_key: str = None, ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER, ): @@ -247,6 +248,10 @@ def main( config = load_config(config_path) if batch_size is not None: config.eval.batch_size = batch_size + if input_transform_dataset is not None: + config.data.input_transform_dataset = input_transform_dataset + else: + config.data.input_transform_dataset = dataset if input_transform_key is not None: config.data.input_transform_key = input_transform_key @@ -254,7 +259,7 @@ def main( workdir=workdir, checkpoint=f"epoch-{epoch}", dataset=dataset, - input_xfm=config.data.input_transform_key, + input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}", split=split, ensemble_member=ensemble_member, ) @@ -270,6 +275,7 @@ def main( eval_dl, _, target_transform = get_dataloader( dataset, config.data.dataset_name, + config.data.input_transform_dataset, config.data.input_transform_key, config.data.target_transform_key, transform_dir, diff --git a/src/ml_downscaling_emulator/bin/evaluate.py 
b/src/ml_downscaling_emulator/bin/evaluate.py index e6f78f36e..5915002cc 100644 --- a/src/ml_downscaling_emulator/bin/evaluate.py +++ b/src/ml_downscaling_emulator/bin/evaluate.py @@ -61,6 +61,7 @@ def sample( epoch: int = typer.Option(...), batch_size: int = None, num_samples: int = 1, + input_transform_dataset: str = None, input_transform_key: str = None, ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER, ): @@ -70,6 +71,10 @@ def sample( if batch_size is not None: config.eval.batch_size = batch_size + if input_transform_dataset is not None: + config.data.input_transform_dataset = input_transform_dataset + else: + config.data.input_transform_dataset = dataset if input_transform_key is not None: config.data.input_transform_key = input_transform_key @@ -77,7 +82,7 @@ def sample( workdir=workdir, checkpoint=f"epoch-{epoch}", dataset=dataset, - input_xfm=config.data.input_transform_key, + input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}", split=split, ensemble_member=ensemble_member, ) @@ -88,6 +93,7 @@ def sample( eval_dl, _, target_transform = get_dataloader( dataset, config.data.dataset_name, + config.data.input_transform_dataset, config.data.input_transform_key, config.data.target_transform_key, transform_dir, diff --git a/src/ml_downscaling_emulator/deterministic/run_lib.py b/src/ml_downscaling_emulator/deterministic/run_lib.py index 407444cd8..79af70e34 100644 --- a/src/ml_downscaling_emulator/deterministic/run_lib.py +++ b/src/ml_downscaling_emulator/deterministic/run_lib.py @@ -77,6 +77,7 @@ def train(config, workdir): # Build dataloaders train_dl, _, _ = get_dataloader( + config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, @@ -89,6 +90,7 @@ def train(config, workdir): evaluation=False, ) val_dl, _, _ = get_dataloader( + config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, diff --git 
a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py index b51672bd5..f82cd76ee 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py @@ -110,8 +110,8 @@ def train(config, workdir): ) as (wandb_run, writer): # Build dataloaders dataset_meta = DatasetMetadata(config.data.dataset_name) - train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False) - eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False) + train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False) + eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False) # Initialize model. 
score_model = mutils.create_model(config) diff --git a/src/ml_downscaling_emulator/torch.py b/src/ml_downscaling_emulator/torch.py index aa0b6a90b..857af654c 100644 --- a/src/ml_downscaling_emulator/torch.py +++ b/src/ml_downscaling_emulator/torch.py @@ -112,6 +112,7 @@ def custom_collate(batch): def get_dataloader( active_dataset_name, model_src_dataset_name, + input_transform_dataset_name, input_transform_key, target_transform_key, transform_dir, @@ -127,6 +128,7 @@ def get_dataloader( Args: active_dataset_name: Name of dataset from which to load data splits model_src_dataset_name: Name of dataset used to train the diffusion model (may be the same) + input_transform_dataset_name: Name of dataset to use for fitting input transform (may be the same as active_dataset_name or model_src_dataset_name) transform_dir: Path to where transforms should be stored input_transform_key: Name of input transform pipeline to use target_transform_key: Name of target transform pipeline to use @@ -140,6 +142,7 @@ def get_dataloader( xr_data, transform, target_transform = get_dataset( active_dataset_name, model_src_dataset_name, + input_transform_dataset_name, input_transform_key, target_transform_key, transform_dir,