Skip to content

Commit

Permalink
Allow overriding the target transform on a per-variable basis
Browse files Browse the repository at this point in the history
though I don't think the overrides can yet be supplied entirely on the fly from the CLI
  • Loading branch information
henryaddison committed Oct 4, 2024
1 parent 279cacf commit 122b88d
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 10 deletions.
11 changes: 10 additions & 1 deletion bin/predict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Generate samples"""

from collections import defaultdict
import itertools
import os
from pathlib import Path
Expand Down Expand Up @@ -234,6 +235,10 @@ def main(
config.data.input_transform_dataset = input_transform_dataset
else:
config.data.input_transform_dataset = dataset

if "target_transform_overrides" not in config.data:
config.data.target_transform_overrides = config_dict.ConfigDict()

if input_transform_key is not None:
config.data.input_transform_key = input_transform_key

Expand All @@ -253,13 +258,17 @@ def main(

transform_dir = os.path.join(workdir, "transforms")

target_xfm_keys = defaultdict(lambda: config.data.target_transform_key) | dict(
config.data.target_transform_overrides
)

# Data
eval_dl, _, target_transform = get_dataloader(
dataset,
config.data.dataset_name,
config.data.input_transform_dataset,
config.data.input_transform_key,
config.data.target_transform_key,
target_xfm_keys,
transform_dir,
split=split,
ensemble_members=[ensemble_member],
Expand Down
2 changes: 1 addition & 1 deletion environment.lock.yml
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ dependencies:
- pip:
- atpublic==3.1.1
- flufl-lock==7.1.1
- mlde-utils==0.2.0a4
- mlde-utils==0.2.0a5
- netcdf4==1.6.3
- python-cmethods==1.0.1
prefix: /home/henry/miniforge3/envs/mv-mlde
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
flufl-lock
mlde-utils~=0.2.0a4
mlde-utils~=0.2.0a5
python-cmethods
6 changes: 3 additions & 3 deletions src/ml_downscaling_emulator/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def get_dataloader(
model_src_dataset_name,
input_transform_dataset_name,
input_transform_key,
target_transform_key,
target_transform_keys,
transform_dir,
batch_size,
split,
Expand All @@ -126,7 +126,7 @@ def get_dataloader(
input_transform_dataset_name: Name of dataset to use for fitting input transform (may be the same as active_dataset_name or model_src_dataset_name)
transform_dir: Path to where transforms should be stored
input_transform_key: Name of input transform pipeline to use
target_transform_key: Name of target transform pipeline to use
target_transform_keys: Mapping from target variable name to target transform pipeline to use
batch_size: Size of batch to use for DataLoaders
split: Split of the active dataset to load
evaluation: If `True`, fix number of epochs to 1.
Expand All @@ -139,7 +139,7 @@ def get_dataloader(
model_src_dataset_name,
input_transform_dataset_name,
input_transform_key,
target_transform_key,
target_transform_keys,
transform_dir,
split,
ensemble_members,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def get_default_configs():
data.input_transform_dataset = None
data.input_transform_key = "stan"
data.target_transform_key = "sqrturrecen"
data.target_transform_overrides = ml_collections.ConfigDict()

data.time_inputs = False

# model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_config():
data = config.data
data.centered = True
data.dataset_name = 'bham64_ccpm-4x_12em_mv'
data.target_transform_overrides.target_tmean150cm = "mm;recen"

# model
model = config.model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_config():
data = config.data
data.centered = True
data.dataset_name = 'debug-sample-mv'
data.target_transform_overrides.target_tmean150cm = "mm;recen"

# model
model = config.model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ def get_config():

# data
data = config.data
data.target_transform_key = 'stanmmrecen'
data.target_transform_key = 'mm;recen'

return config
9 changes: 6 additions & 3 deletions src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# pylint: skip-file
"""Training for score-based generative models. """

from collections import defaultdict
import itertools
import os

Expand Down Expand Up @@ -100,11 +101,13 @@ def train(config, workdir):
tb_dir = os.path.join(workdir, "tensorboard")
os.makedirs(tb_dir, exist_ok=True)

target_xfm_keys = defaultdict(lambda: config.data.target_transform_key) | dict(config.data.target_transform_overrides)

run_name = os.path.basename(workdir)
run_config = dict(
dataset=config.data.dataset_name,
input_transform_key=config.data.input_transform_key,
target_transform_key=config.data.target_transform_key,
target_transform_keys=target_xfm_keys,
architecture=config.model.name,
sde=config.training.sde,
name=run_name,
Expand All @@ -115,8 +118,8 @@ def train(config, workdir):
) as (wandb_run, writer):
# Build dataloaders
dataset_meta = DatasetMetadata(config.data.dataset_name)
train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False, shuffle=False)
train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, target_xfm_keys, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, target_xfm_keys, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False, shuffle=False)

# Initialize model.
score_model = mutils.create_model(config)
Expand Down

0 comments on commit 122b88d

Please sign in to comment.