allow for setting of which dataset a transform should be fitted to
So it doesn't have to be the same as the dataset it is applied to during sampling.
During training it is always the same dataset for both (though potentially even this could be uncoupled).
henryaddison committed Feb 14, 2024
1 parent 394454e commit 8894f7f
Showing 5 changed files with 21 additions and 4 deletions.
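
For context, a hypothetical sampling invocation using the new option (a sketch only: flag names assume typer's default kebab-case conversion, and the workdir, dataset names, and transform key below are placeholders):

    # Sample dataset-A with a trained model, but fit the input transform to
    # dataset-B rather than to dataset-A itself (all names here are made up).
    python bin/predict.py path/to/workdir dataset-A \
        --split val \
        --epoch 20 \
        --input-transform-dataset dataset-B \
        --input-transform-key standardize

When --input-transform-dataset is omitted, config.data.input_transform_dataset falls back to the dataset being sampled, preserving the previous behaviour.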
8 changes: 7 additions & 1 deletion bin/predict.py
@@ -240,21 +240,26 @@ def main(
     epoch: int = typer.Option(...),
     batch_size: int = None,
     num_samples: int = 3,
+    input_transform_dataset: str = None,
     input_transform_key: str = None,
     ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
 ):
     config_path = os.path.join(workdir, "config.yml")
     config = load_config(config_path)
     if batch_size is not None:
         config.eval.batch_size = batch_size
+    if input_transform_dataset is not None:
+        config.data.input_transform_dataset = input_transform_dataset
+    else:
+        config.data.input_transform_dataset = dataset
     if input_transform_key is not None:
         config.data.input_transform_key = input_transform_key
 
     output_dirpath = samples_path(
         workdir=workdir,
         checkpoint=f"epoch-{epoch}",
         dataset=dataset,
-        input_xfm=config.data.input_transform_key,
+        input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
         split=split,
         ensemble_member=ensemble_member,
     )
@@ -270,6 +275,7 @@ def main(
     eval_dl, _, target_transform = get_dataloader(
         dataset,
         config.data.dataset_name,
+        config.data.input_transform_dataset,
         config.data.input_transform_key,
         config.data.target_transform_key,
         transform_dir,
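
A side effect worth noting: because input_xfm now combines the transform-fitting dataset with the transform key (e.g. a hypothetical "dataset-B-standardize" rather than just "standardize"), samples generated with transforms fitted to different datasets are written to distinct output directories and cannot overwrite one another.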
8 changes: 7 additions & 1 deletion src/ml_downscaling_emulator/bin/evaluate.py
@@ -61,6 +61,7 @@ def sample(
     epoch: int = typer.Option(...),
     batch_size: int = None,
     num_samples: int = 1,
+    input_transform_dataset: str = None,
     input_transform_key: str = None,
     ensemble_member: str = DEFAULT_ENSEMBLE_MEMBER,
 ):
@@ -70,14 +71,18 @@ def sample(
 
     if batch_size is not None:
         config.eval.batch_size = batch_size
+    if input_transform_dataset is not None:
+        config.data.input_transform_dataset = input_transform_dataset
+    else:
+        config.data.input_transform_dataset = dataset
     if input_transform_key is not None:
         config.data.input_transform_key = input_transform_key
 
     output_dirpath = samples_path(
         workdir=workdir,
         checkpoint=f"epoch-{epoch}",
         dataset=dataset,
-        input_xfm=config.data.input_transform_key,
+        input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
         split=split,
         ensemble_member=ensemble_member,
     )
@@ -88,6 +93,7 @@ def sample(
     eval_dl, _, target_transform = get_dataloader(
         dataset,
         config.data.dataset_name,
+        config.data.input_transform_dataset,
         config.data.input_transform_key,
         config.data.target_transform_key,
         transform_dir,
2 changes: 2 additions & 0 deletions src/ml_downscaling_emulator/deterministic/run_lib.py
@@ -77,6 +77,7 @@ def train(config, workdir):
 
     # Build dataloaders
     train_dl, _, _ = get_dataloader(
         config.data.dataset_name,
         config.data.dataset_name,
+        config.data.dataset_name,
         config.data.input_transform_key,
@@ -89,6 +90,7 @@ def train(config, workdir):
         evaluation=False,
     )
     val_dl, _, _ = get_dataloader(
         config.data.dataset_name,
         config.data.dataset_name,
+        config.data.dataset_name,
         config.data.input_transform_key,
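
As the commit message notes, training keeps the roles coupled: both calls above pass config.data.dataset_name as the active dataset, the model-source dataset, and the transform-fitting dataset alike.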
4 changes: 2 additions & 2 deletions (file path not shown)
@@ -110,8 +110,8 @@ def train(config, workdir):
     ) as (wandb_run, writer):
         # Build dataloaders
         dataset_meta = DatasetMetadata(config.data.dataset_name)
-        train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
-        eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
+        train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
+        eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
 
         # Initialize model.
         score_model = mutils.create_model(config)
3 changes: 3 additions & 0 deletions src/ml_downscaling_emulator/torch.py
@@ -112,6 +112,7 @@ def custom_collate(batch):
 def get_dataloader(
     active_dataset_name,
     model_src_dataset_name,
+    input_transform_dataset_name,
     input_transform_key,
     target_transform_key,
     transform_dir,
@@ -127,6 +128,7 @@ def get_dataloader(
     Args:
         active_dataset_name: Name of dataset from which to load data splits
         model_src_dataset_name: Name of dataset used to train the diffusion model (may be the same)
+        input_transform_dataset_name: Name of dataset to use for fitting input transform (may be the same as active_dataset_name or model_src_dataset_name)
         transform_dir: Path to where transforms should be stored
         input_transform_key: Name of input transform pipeline to use
         target_transform_key: Name of target transform pipeline to use
@@ -140,6 +142,7 @@ def get_dataloader(
     xr_data, transform, target_transform = get_dataset(
         active_dataset_name,
         model_src_dataset_name,
+        input_transform_dataset_name,
         input_transform_key,
         target_transform_key,
         transform_dir,
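
To make the widened signature concrete, here is a minimal sampling-time sketch in which all three dataset roles are uncoupled (dataset names, transform keys, paths, and keyword values are hypothetical; the keyword arguments mirror the calls shown in the diffs above):

    # Hypothetical call: load the "val" split of dataset-A, pair it with a model
    # trained on dataset-B, and fit the input transform to dataset-C.
    eval_dl, _, target_transform = get_dataloader(
        "dataset-A",            # active_dataset_name: where the data splits come from
        "dataset-B",            # model_src_dataset_name: dataset the model was trained on
        "dataset-C",            # input_transform_dataset_name: dataset the input transform is fitted to
        "standardize",          # input_transform_key (placeholder pipeline name)
        "sqrt",                 # target_transform_key (placeholder pipeline name)
        "path/to/transforms",   # transform_dir: where fitted transforms are stored
        batch_size=16,
        split="val",
        ensemble_members=["01"],
        include_time_inputs=False,
        evaluation=True,
    )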
