make predict.py multivariate aware
henryaddison committed Mar 29, 2024
1 parent 26a3bff commit e4ee079
Showing 2 changed files with 67 additions and 43 deletions.
61 changes: 18 additions & 43 deletions bin/predict.py
@@ -8,7 +8,6 @@
 from dotenv import load_dotenv
 from knockknock import slack_sender
 from ml_collections import config_dict
-import numpy as np
 import shortuuid
 import torch
 import typer
@@ -18,7 +17,7 @@
 import xarray as xr
 import yaml
 
-from ml_downscaling_emulator.data import get_dataloader
+from ml_downscaling_emulator.data import get_dataloader, np_samples_to_xr
 from mlde_utils import samples_path, DEFAULT_ENSEMBLE_MEMBER
 from mlde_utils.training.dataset import get_variables

@@ -129,7 +128,8 @@ def load_model(config, ckpt_filename):
     state["ema"].copy_to(state["model"].parameters())
 
     # Sampling
-    num_output_channels = len(get_variables(config.data.dataset_name)[1])
+    input_variables, target_vars = get_variables(config.data.dataset_name)
+    num_output_channels = len(target_vars)
     sampling_shape = (
         config.eval.batch_size,
         num_output_channels,
@@ -138,48 +138,20 @@ def load_model(config, ckpt_filename):
     )
     sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, sampling_eps)
 
-    return state, sampling_fn
+    return state, sampling_fn, target_vars
 
 
-def generate_np_samples(sampling_fn, score_model, config, cond_batch):
+def generate_np_sample_batch(sampling_fn, score_model, config, cond_batch):
     cond_batch = cond_batch.to(config.device)
 
     samples = sampling_fn(score_model, cond_batch)[0]
-    # drop the feature channel dimension (only have target pr as output)
-    samples = samples.squeeze(dim=1)
 
     # extract numpy array
     samples = samples.cpu().numpy()
     return samples
 
 
-def np_samples_to_xr(np_samples, target_transform, coords, cf_data_vars):
-    coords = {**dict(coords)}
-
-    pred_pr_dims = ["ensemble_member", "time", "grid_latitude", "grid_longitude"]
-    pred_pr_attrs = {
-        "grid_mapping": "rotated_latitude_longitude",
-        "standard_name": "pred_pr",
-        "units": "kg m-2 s-1",
-    }
-    # add ensemble member axis to np samples
-    np_samples = np_samples[np.newaxis, :]
-    pred_pr_var = (pred_pr_dims, np_samples, pred_pr_attrs)
-    raw_pred_var = (
-        pred_pr_dims,
-        np_samples,
-        {"grid_mapping": "rotated_latitude_longitude"},
-    )
-    data_vars = {**cf_data_vars, "target_pr": pred_pr_var, "raw_pred": raw_pred_var}
-
-    samples_ds = target_transform.invert(
-        xr.Dataset(data_vars=data_vars, coords=coords, attrs={})
-    )
-    samples_ds = samples_ds.rename({"target_pr": "pred_pr"})
-    samples_ds["pred_pr"] = samples_ds["pred_pr"].assign_attrs(pred_pr_attrs)
-    return samples_ds
-
-
-def sample(sampling_fn, state, config, eval_dl, target_transform):
+def sample(sampling_fn, state, config, eval_dl, target_transform, target_vars):
     score_model = state["model"]
     location_params = state["location_params"]

@@ -193,7 +165,7 @@ def sample(sampling_fn, state, config, eval_dl, target_transform):
         ]
     }
 
-    preds = []
+    xr_sample_batches = []
     with logging_redirect_tqdm():
         with tqdm(
             total=len(eval_dl.dataset),
@@ -206,23 +178,24 @@ def sample(sampling_fn, state, config, eval_dl, target_transform):
 
                 coords = eval_dl.dataset.ds.sel(time=time_batch).coords
 
-                np_samples = generate_np_samples(
+                np_sample_batch = generate_np_sample_batch(
                     sampling_fn, score_model, config, cond_batch
                 )
 
-                xr_samples = np_samples_to_xr(
-                    np_samples,
+                xr_sample_batch = np_samples_to_xr(
+                    np_sample_batch,
                     target_transform,
+                    target_vars,
                     coords,
                     cf_data_vars,
                 )
 
-                preds.append(xr_samples)
+                xr_sample_batches.append(xr_sample_batch)
 
                 pbar.update(cond_batch.shape[0])
 
     ds = xr.combine_by_coords(
-        preds,
+        xr_sample_batches,
         compat="no_conflicts",
         combine_attrs="drop_conflicts",
         coords="all",
@@ -292,11 +265,13 @@ def main(
 
     ckpt_filename = os.path.join(workdir, "checkpoints", f"{checkpoint}.pth")
     logger.info(f"Loading model from {ckpt_filename}")
-    state, sampling_fn = load_model(config, ckpt_filename)
+    state, sampling_fn, target_vars = load_model(config, ckpt_filename)
 
     for sample_id in range(num_samples):
         typer.echo(f"Sample run {sample_id}...")
-        xr_samples = sample(sampling_fn, state, config, eval_dl, target_transform)
+        xr_samples = sample(
+            sampling_fn, state, config, eval_dl, target_transform, target_vars
+        )
 
         output_filepath = output_dirpath / f"predictions-{shortuuid.uuid()}.nc"

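For illustration only (not part of the diff): a minimal sketch of the shape handling this change implies. With the squeeze removed, the channel axis of the sampled batch survives, and np_samples_to_xr slices out one channel per target variable while prepending an ensemble_member axis. The batch size, variable count and grid size below are assumed values.

import numpy as np

# Hypothetical sizes, chosen only for the example.
batch_size, n_target_vars, n_lat, n_lon = 2, 2, 64, 64

# generate_np_sample_batch now returns samples with the channel axis intact:
# (batch, target-variable channel, grid_latitude, grid_longitude).
np_sample_batch = np.random.rand(batch_size, n_target_vars, n_lat, n_lon)

# np_samples_to_xr picks out each variable's channel and adds an
# ensemble_member axis, mirroring np_samples[np.newaxis, :, var_idx, :]:
var_idx = 0
np_var_pred = np_sample_batch[np.newaxis, :, var_idx, :]
assert np_var_pred.shape == (1, batch_size, n_lat, n_lon)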
49 changes: 49 additions & 0 deletions src/ml_downscaling_emulator/data.py
@@ -4,6 +4,7 @@
 import numpy as np
 import torch
 from torch.utils.data import Dataset, DataLoader
+import xarray as xr
 
 from mlde_utils.training.dataset import get_dataset, get_variables

@@ -157,3 +158,51 @@ def get_dataloader(
     )
 
     return data_loader, transform, target_transform
+
+
+def np_samples_to_xr(np_samples, target_transform, target_vars, coords, cf_data_vars):
+    """
+    Convert samples from a model in numpy format to an xarray Dataset, including inverting any transformation applied to the target variables before modelling.
+    """
+    coords = {**dict(coords)}
+
+    pred_dims = ["ensemble_member", "time", "grid_latitude", "grid_longitude"]
+
+    data_vars = {**cf_data_vars}
+    for var_idx, var in enumerate(target_vars):
+        # add ensemble member axis to np samples and get just values for current variable
+        np_var_pred = np_samples[np.newaxis, :, var_idx, :]
+        pred_attrs = {
+            "grid_mapping": "rotated_latitude_longitude",
+            "standard_name": var.replace("target_", "pred_"),
+            # "units": "kg m-2 s-1",
+        }
+        pred_var = (pred_dims, np_var_pred, pred_attrs)
+        raw_pred_var = (
+            pred_dims,
+            np_var_pred,
+            {"grid_mapping": "rotated_latitude_longitude"},
+        )
+        data_vars.update(
+            {
+                var: pred_var,
+                var.replace("target_", "raw_pred_"): raw_pred_var,
+            }
+        )
+
+    samples_ds = target_transform.invert(
+        xr.Dataset(data_vars=data_vars, coords=coords, attrs={})
+    )
+    samples_ds = samples_ds.rename(
+        {var: var.replace("target_", "pred_") for var in target_vars}
+    )
+
+    for var_idx, var in enumerate(target_vars):
+        pred_attrs = {
+            "grid_mapping": "rotated_latitude_longitude",
+            "standard_name": var.replace("target_", "pred_"),
+            # "units": "kg m-2 s-1",
+        }
+        samples_ds[var.replace("target_", "pred_")] = samples_ds[
+            var.replace("target_", "pred_")
+        ].assign_attrs(pred_attrs)
+    return samples_ds
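For illustration only (not part of the commit): a sketch of how the relocated helper might be called with toy inputs. The identity transform, coordinate values, grid size and empty cf_data_vars are assumptions made so the example is self-contained, not values taken from the repository.

import numpy as np
import pandas as pd

from ml_downscaling_emulator.data import np_samples_to_xr


class IdentityTransform:
    """Stand-in for a fitted target transform whose invert is a no-op."""

    def invert(self, ds):
        return ds


# Toy coordinates matching the dims used by np_samples_to_xr (assumed values).
coords = {
    "ensemble_member": ["01"],
    "time": pd.date_range("1981-12-01", periods=4, freq="D"),
    "grid_latitude": np.arange(64, dtype=float),
    "grid_longitude": np.arange(64, dtype=float),
}

target_vars = ["target_pr"]
# One sample batch: (time, target-variable channel, grid_latitude, grid_longitude).
np_samples = np.random.rand(4, len(target_vars), 64, 64)

samples_ds = np_samples_to_xr(
    np_samples, IdentityTransform(), target_vars, coords, cf_data_vars={}
)
print(samples_ds["pred_pr"].dims)
# ('ensemble_member', 'time', 'grid_latitude', 'grid_longitude')

The returned dataset also carries a raw_pred_* variable per target variable, holding the samples before the target transform is inverted, mirroring the single-variable raw_pred behaviour removed from predict.py.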
