Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interface to quantile map samples #29

Merged
merged 5 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions src/ml_downscaling_emulator/bin/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import xarray as xr

from mlde_utils import samples_path, samples_glob, TIME_PERIODS
from mlde_utils.training.dataset import open_raw_dataset_split

from ml_downscaling_emulator.postprocess import xrqm

logging.basicConfig(
level=logging.INFO,
Expand Down Expand Up @@ -64,3 +67,78 @@ def filter(
samples_ds.sel(time=slice(*TIME_PERIODS[time_period])).to_netcdf(
filtered_samples_filepath
)


@app.command()
def qm(
    workdir: Path,
    checkpoint: str = typer.Option(...),
    train_dataset: str = typer.Option(...),
    train_input_xfm: str = "stan",
    eval_dataset: str = typer.Option(...),
    eval_input_xfm: str = "stan",
    split: str = "val",
    ensemble_member: str = typer.Option(...),
):
    """Quantile map the samples of an eval split using a train-split mapping.

    The mapping is learned from the train split: the simulation's "target_pr"
    values form the reference distribution and the emulator's train-split
    "pred_pr" samples form the modelled distribution. Each sample file of the
    eval split is then transformed with that mapping and written under
    ``workdir/postprocess/qm`` (original samples are left untouched).
    """
    # to compute the mapping, use train split data

    # open train split of dataset for the target_pr
    sim_train_da = open_raw_dataset_split(train_dataset, "train").sel(
        ensemble_member=ensemble_member
    )["target_pr"]

    # open sample of model from train split
    # NOTE(review): only the first matching samples file is used — assumes one
    # file (or that the first glob match is representative); confirm.
    ml_train_da = xr.open_dataset(
        list(
            samples_glob(
                samples_path(
                    workdir,
                    checkpoint=checkpoint,
                    input_xfm=train_input_xfm,
                    dataset=train_dataset,
                    split="train",
                    ensemble_member=ensemble_member,
                )
            )
        )[0]
    )["pred_pr"]

    ml_eval_samples_dirpath = samples_path(
        workdir,
        checkpoint=checkpoint,
        input_xfm=eval_input_xfm,
        dataset=eval_dataset,
        split=split,
        ensemble_member=ensemble_member,
    )
    # fix: log message previously read "QMapping samplesin ..."
    logger.info(f"QMapping samples in {ml_eval_samples_dirpath}")
    for sample_filepath in samples_glob(ml_eval_samples_dirpath):
        logger.info(f"Working on {sample_filepath}")
        # open the samples to be qmapped
        ml_eval_ds = xr.open_dataset(sample_filepath)

        # do the qmapping
        qmapped_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_ds["pred_pr"])

        qmapped_eval_ds = ml_eval_ds.copy()
        qmapped_eval_ds["pred_pr"] = qmapped_eval_da

        # save output under a separate workdir so the raw samples are kept
        new_workdir = workdir / "postprocess" / "qm"

        qmapped_sample_filepath = (
            samples_path(
                new_workdir,
                checkpoint=checkpoint,
                input_xfm=eval_input_xfm,
                dataset=eval_dataset,
                split=split,
                ensemble_member=ensemble_member,
            )
            / sample_filepath.name
        )

        logger.info(f"Saving to {qmapped_sample_filepath}")
        qmapped_sample_filepath.parent.mkdir(parents=True, exist_ok=True)
        qmapped_eval_ds.to_netcdf(qmapped_sample_filepath)
70 changes: 70 additions & 0 deletions src/ml_downscaling_emulator/postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Callable
import numpy as np
import xarray as xr


def _get_cdf(x, xbins):
pdf, _ = np.histogram(x, xbins)
return np.insert(np.cumsum(pdf), 0, 0.0)


def qm_1d_dom_aware(
    obs: np.ndarray,
    simh: np.ndarray,
    simp: np.ndarray,
    n_quantiles: int = 250,
    kind: str = "+",
):
    """A 1D quantile mapping function replacement for CMethods.quantile_mapping

    Unlike the CMethods version it takes into account that the obs and simh may
    have different max and mins (i.e. their CDFs have different supported
    domains). The CMethods version just uses a domain between the min and max
    of both obs and simh.

    Args:
        obs: reference (observed/simulated "truth") values.
        simh: modelled values over the same period as obs.
        simp: modelled values to be quantile mapped.
        n_quantiles: number of histogram bins used to discretise each CDF.
        kind: accepted for signature compatibility with CMethods; unused here.

    Returns:
        np.ndarray: simp mapped onto the distribution of obs.
    """
    obs, simh, simp = np.array(obs), np.array(simh), np.array(simp)

    def bin_edges(values):
        # Edges spanning exactly the domain of `values`, in n_quantiles steps.
        lo, hi = np.amin(values), np.amax(values)
        step = abs(hi - lo) / n_quantiles
        return np.arange(lo, hi + step, step)

    xbins_obs = bin_edges(obs)
    xbins_simh = bin_edges(simh)

    cdf_obs = _get_cdf(obs, xbins_obs)
    cdf_simh = _get_cdf(simh, xbins_simh)

    # Each simp value's (unnormalised) quantile under the simh CDF...
    epsilon = np.interp(simp, xbins_simh, cdf_simh)

    # ...inverted through the obs CDF to land on the obs domain.
    return np.interp(epsilon, cdf_obs, xbins_obs)


def xrqm(
    sim_train_da: xr.DataArray,
    ml_train_da: xr.DataArray,
    ml_eval_da: xr.DataArray,
    qm_func: Callable = qm_1d_dom_aware,
):
    """Apply a 1D quantile mapping function point-by-point to a multi-dimensional xarray DataArray."""
    # Vectorize qm_func over every grid box: each call receives the three
    # "time" series for a single point of the remaining (broadcast) dims.
    qmapped = xr.apply_ufunc(
        qm_func,  # first the function
        sim_train_da,  # now arguments in the order expected by the function
        ml_train_da,
        ml_eval_da,
        kwargs=dict(n_quantiles=250, kind="+"),
        input_core_dims=[["time"], ["time"], ["time"]],  # one entry per arg
        output_core_dims=[["time"]],
        # dimensions allowed to change size. Must be set!
        exclude_dims=set(("time",)),
        vectorize=True,
    )
    # apply_ufunc moves core dims to the end and drops excluded coords, so
    # restore the conventional dim order and re-attach the eval time coord.
    qmapped = qmapped.transpose(
        "ensemble_member", "time", "grid_latitude", "grid_longitude"
    )
    return qmapped.assign_coords(time=ml_eval_da["time"])
179 changes: 108 additions & 71 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Callable
import cftime
import numpy as np
import pytest
Expand All @@ -7,89 +8,125 @@

grid_longitude = xr.Variable(["grid_longitude"], np.linspace(-4, 4, 17), attrs={})

def build_time_dim(start_year: int = 1980, time_len: int = 10):
    """Build matching ``time`` and ``time_bnds`` variables for a dummy dataset.

    Args:
        start_year: year of the first daily (360-day calendar) timestamp.
        time_len: number of daily time steps.

    Returns:
        tuple: ``(time, time_bnds)`` xarray Variables; ``time_bnds`` pairs the
        bracketing midnights for each noon-centred time step.
    """
    time = xr.Variable(
        ["time"],
        xr.cftime_range(
            cftime.Datetime360Day(start_year, 12, 1, 12, 0, 0, 0, has_year_zero=True),
            periods=time_len,
            freq="D",
        ),
    )
    # n+1 midnights bracket the n noon timestamps; pair them up as the
    # (start, end) bounds of each day.
    time_bnds_values = xr.cftime_range(
        cftime.Datetime360Day(start_year, 12, 1, 0, 0, 0, 0, has_year_zero=True),
        periods=len(time) + 1,
        freq="D",
    ).values
    time_bnds_pairs = np.concatenate(
        [time_bnds_values[:-1, np.newaxis], time_bnds_values[1:, np.newaxis]], axis=1
    )
    time_bnds = xr.Variable(["time", "bnds"], time_bnds_pairs, attrs={})
    return time, time_bnds


@pytest.fixture
def samples_factory() -> Callable[[int, int], xr.Dataset]:
    """Create a factory function for creating dummy xarray Datasets that look like samples from the emulator."""

    def _samples_factory(start_year: int = 1980, time_len: int = 10) -> xr.Dataset:
        # Emulator samples carry a single ensemble member.
        ensemble_member = xr.Variable(["ensemble_member"], np.array(["01"]))
        time, time_bnds = build_time_dim(start_year=start_year, time_len=time_len)
        coords = {
            "ensemble_member": ensemble_member,
            "time": time,
            "grid_latitude": grid_latitude,
            "grid_longitude": grid_longitude,
        }

        data_vars = {
            # Random precipitation predictions on the full 4D grid.
            "pred_pr": xr.Variable(
                ["ensemble_member", "time", "grid_latitude", "grid_longitude"],
                np.random.rand(
                    len(ensemble_member),
                    len(time),
                    len(grid_latitude),
                    len(grid_longitude),
                ),
            ),
            "time_bnds": time_bnds,
        }

        ds = xr.Dataset(
            data_vars=data_vars,
            coords=coords,
        )

        return ds

    return _samples_factory


@pytest.fixture
def samples_set(samples_factory) -> xr.Dataset:
    """Create a dummy xarray Dataset that looks like a set of samples from the emulator."""
    # Use the factory's defaults (start_year=1980, 10 daily time steps).
    return samples_factory()


@pytest.fixture
def dataset_factory() -> Callable[[int, int], xr.Dataset]:
    """Create a factory function for creating dummy xarray Datasets that look like the training data."""

    def _dataset_factory(start_year: int = 1980, time_len: int = 10) -> xr.Dataset:
        # Training data has three ensemble members ("00", "01", "02"), unlike
        # the single-member samples datasets.
        ensemble_member = xr.Variable(
            ["ensemble_member"], np.array([f"{i:02}" for i in range(3)])
        )

        time, time_bnds = build_time_dim(start_year=start_year, time_len=time_len)

        coords = {
            "ensemble_member": ensemble_member,
            "time": time,
            "grid_latitude": grid_latitude,
            "grid_longitude": grid_longitude,
        }

        data_vars = {
            # Random coarse input ("linpr") and target ("target_pr") fields.
            "linpr": xr.Variable(
                ["ensemble_member", "time", "grid_latitude", "grid_longitude"],
                np.random.rand(
                    len(ensemble_member),
                    len(time),
                    len(grid_latitude),
                    len(grid_longitude),
                ),
            ),
            "target_pr": xr.Variable(
                ["ensemble_member", "time", "grid_latitude", "grid_longitude"],
                np.random.rand(
                    len(ensemble_member),
                    len(time),
                    len(grid_latitude),
                    len(grid_longitude),
                ),
            ),
            "time_bnds": time_bnds,
        }

        ds = xr.Dataset(
            data_vars=data_vars,
            coords=coords,
        )

        return ds

    return _dataset_factory


@pytest.fixture
def dataset(dataset_factory) -> xr.Dataset:
    """Create a dummy xarray Dataset representing a split of a set of data for training and sampling."""
    # Defer to the factory with its defaults (start_year=1980, 10 time steps).
    ds = dataset_factory()
    return ds
Loading