diff --git a/bin/predict.py b/bin/predict.py index def13b88f..79ec5ab35 100644 --- a/bin/predict.py +++ b/bin/predict.py @@ -240,6 +240,7 @@ def main( epoch: int = typer.Option(...), batch_size: int = None, num_samples: int = 3, + input_transform_dataset: str = None, input_transform_key: str = None, ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER, ): @@ -247,6 +248,10 @@ def main( config = load_config(config_path) if batch_size is not None: config.eval.batch_size = batch_size + if input_transform_dataset is not None: + config.data.input_transform_dataset = input_transform_dataset + else: + config.data.input_transform_dataset = dataset if input_transform_key is not None: config.data.input_transform_key = input_transform_key @@ -254,7 +259,7 @@ def main( workdir=workdir, checkpoint=f"epoch-{epoch}", dataset=dataset, - input_xfm=config.data.input_transform_key, + input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}", split=split, ensemble_member=ensemble_member, ) @@ -270,6 +275,7 @@ def main( eval_dl, _, target_transform = get_dataloader( dataset, config.data.dataset_name, + config.data.input_transform_dataset, config.data.input_transform_key, config.data.target_transform_key, transform_dir, diff --git a/src/ml_downscaling_emulator/bin/evaluate.py b/src/ml_downscaling_emulator/bin/evaluate.py index e6f78f36e..5915002cc 100644 --- a/src/ml_downscaling_emulator/bin/evaluate.py +++ b/src/ml_downscaling_emulator/bin/evaluate.py @@ -61,6 +61,7 @@ def sample( epoch: int = typer.Option(...), batch_size: int = None, num_samples: int = 1, + input_transform_dataset: str = None, input_transform_key: str = None, ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER, ): @@ -70,6 +71,10 @@ def sample( if batch_size is not None: config.eval.batch_size = batch_size + if input_transform_dataset is not None: + config.data.input_transform_dataset = input_transform_dataset + else: + config.data.input_transform_dataset = dataset if input_transform_key is not None: config.data.input_transform_key = input_transform_key @@ -77,7 +82,7 @@ def sample( workdir=workdir, checkpoint=f"epoch-{epoch}", dataset=dataset, - input_xfm=config.data.input_transform_key, + input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}", split=split, ensemble_member=ensemble_member, ) @@ -88,6 +93,7 @@ def sample( eval_dl, _, target_transform = get_dataloader( dataset, config.data.dataset_name, + config.data.input_transform_dataset, config.data.input_transform_key, config.data.target_transform_key, transform_dir, diff --git a/src/ml_downscaling_emulator/deterministic/run_lib.py b/src/ml_downscaling_emulator/deterministic/run_lib.py index 407444cd8..79af70e34 100644 --- a/src/ml_downscaling_emulator/deterministic/run_lib.py +++ b/src/ml_downscaling_emulator/deterministic/run_lib.py @@ -77,6 +77,7 @@ def train(config, workdir): # Build dataloaders train_dl, _, _ = get_dataloader( + config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, @@ -89,6 +90,7 @@ def train(config, workdir): evaluation=False, ) val_dl, _, _ = get_dataloader( + config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py index b51672bd5..f82cd76ee 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/run_lib.py @@ -110,8 +110,8 @@ def train(config, workdir): ) as (wandb_run, writer): # Build dataloaders dataset_meta = DatasetMetadata(config.data.dataset_name) - train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False) - eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False) + train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False) + eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False) # Initialize model. score_model = mutils.create_model(config) diff --git a/src/ml_downscaling_emulator/torch.py b/src/ml_downscaling_emulator/torch.py index aa0b6a90b..857af654c 100644 --- a/src/ml_downscaling_emulator/torch.py +++ b/src/ml_downscaling_emulator/torch.py @@ -112,6 +112,7 @@ def custom_collate(batch): def get_dataloader( active_dataset_name, model_src_dataset_name, + input_transform_dataset_name, input_transform_key, target_transform_key, transform_dir, @@ -127,6 +128,7 @@ def get_dataloader( Args: active_dataset_name: Name of dataset from which to load data splits model_src_dataset_name: Name of dataset used to train the diffusion model (may be the same) + input_transform_dataset_name: Name of dataset to use for fitting input transform (may be the same as active_dataset_name or model_src_dataset_name) transform_dir: Path to where transforms should be stored input_transform_key: Name of input transform pipeline to use target_transform_key: Name of target transform pipeline to use @@ -140,6 +142,7 @@ def get_dataloader( xr_data, transform, target_transform = get_dataset( active_dataset_name, model_src_dataset_name, + input_transform_dataset_name, input_transform_key, target_transform_key, transform_dir,