Skip to content

Commit

Permalink
add domain-aware implementation of 1d qm
Browse files Browse the repository at this point in the history
and tests and settle on just the apply_ufunc approach
  • Loading branch information
henryaddison committed Sep 11, 2023
1 parent b21ccc9 commit 1df23fc
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 17 deletions.
59 changes: 42 additions & 17 deletions src/ml_downscaling_emulator/postprocess.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,55 @@
from cmethods import CMethods
from typing import Callable
import numpy as np
import xarray as xr


def qm(sim_train_da, ml_train_da, ml_eval_da):
values = np.zeros(ml_eval_da.shape, float)
for ilat in range(len(ml_eval_da["grid_latitude"])):
for ilon in range(len(ml_eval_da["grid_longitude"])):
values[:, ilat, ilon] = CMethods.quantile_mapping(
sim_train_da.isel(grid_latitude=ilat, grid_longitude=ilon),
ml_train_da.isel(grid_latitude=ilat, grid_longitude=ilon),
ml_eval_da.isel(grid_latitude=ilat, grid_longitude=ilon),
n_quantiles=250,
kind="+",
)
def _get_cdf(x, xbins):
pdf, _ = np.histogram(x, xbins)
return np.insert(np.cumsum(pdf), 0, 0.0)

qmapped = xr.zeros_like(ml_eval_da)
qmapped.data = values

return qmapped
def qm_1d_dom_aware(
obs: np.ndarray,
simh: np.ndarray,
simp: np.ndarray,
n_quantiles: int = 250,
kind: str = "+",
):
"""
A 1D quantile mapping function replacement for CMethods.quantile_mapping
Unlike the CMethods version it takes into account that the obs and simh may have different max and mins (i.e. their CDFs have different supported domains). The CMethods version just uses a domain between the min and max of both obs and simh.
"""
obs, simh, simp = np.array(obs), np.array(simh), np.array(simp)

def qm_vec(sim_train_da, ml_train_da, ml_eval_da):
obs_min = np.amin(obs)
obs_max = np.amax(obs)
wide = abs(obs_max - obs_min) / n_quantiles
xbins_obs = np.arange(obs_min, obs_max + wide, wide)

simh_min = np.amin(simh)
simh_max = np.amax(simh)
wide = abs(simh_max - simh_min) / n_quantiles
xbins_simh = np.arange(simh_min, simh_max + wide, wide)

cdf_obs = _get_cdf(obs, xbins_obs)
cdf_simh = _get_cdf(simh, xbins_simh)

epsilon = np.interp(simp, xbins_simh, cdf_simh)

return np.interp(epsilon, cdf_obs, xbins_obs)


def xrqm(
sim_train_da: xr.DataArray,
ml_train_da: xr.DataArray,
ml_eval_da: xr.DataArray,
qm_func: Callable = qm_1d_dom_aware,
):
"""Apply a 1D quantile mapping function point-by-point to a multi-dimensional xarray DataArray."""
return (
xr.apply_ufunc(
CMethods.quantile_mapping, # first the function
qm_func, # first the function
sim_train_da, # now arguments in the order expected by the function
ml_train_da,
ml_eval_da,
Expand Down
106 changes: 106 additions & 0 deletions tests/ml_downscaling_emulator/test_postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import numpy as np
import pytest
import xarray as xr

from ml_downscaling_emulator.postprocess import xrqm, qm_1d_dom_aware


def test_qm_applies_qm_at_each_gridbox(sim_train_da, ml_train_da, ml_eval_da):
qm_ml_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_da)

for ilat in range(len(ml_eval_da["grid_latitude"])):
for ilon in range(len(ml_eval_da["grid_longitude"])):
exp_value = qm_1d_dom_aware(
sim_train_da.isel(grid_latitude=ilat, grid_longitude=ilon),
ml_train_da.isel(grid_latitude=ilat, grid_longitude=ilon),
ml_eval_da.isel(grid_latitude=ilat, grid_longitude=ilon),
n_quantiles=250,
)

np.testing.assert_allclose(
exp_value, qm_ml_eval_da.isel(grid_latitude=ilat, grid_longitude=ilon)
)


def test_all_train_qm_match_sim_quantiles(sim_train_da, ml_train_da):
qm_ml_train_da = xrqm(sim_train_da, ml_train_da, ml_train_da)

np.testing.assert_allclose(
sim_train_da.quantile([0.1, 0.25, 0.5, 0.75, 0.9], dim="time"),
qm_ml_train_da.quantile([0.1, 0.25, 0.5, 0.75, 0.9], dim="time"),
rtol=5e-2,
)


def test_all_train_qm_match_sim_histogram(sim_train_da, ml_train_da):
qm_ml_train_da = xrqm(sim_train_da, ml_train_da, ml_train_da)

sim_ns, bins = np.histogram(sim_train_da, range=(-5.0, 5.0), bins=20)
qm_ml_ns, bins = np.histogram(qm_ml_train_da, bins=bins)

np.testing.assert_allclose(sim_ns, qm_ml_ns, atol=200)

np.testing.assert_allclose(np.abs(sim_ns - qm_ml_ns).sum(), 0.0, atol=500)


@pytest.fixture
def time_range():
return np.linspace(-2, 2, 20000)


@pytest.fixture
def lat_range():
return np.linspace(-2, 2, 7)


@pytest.fixture
def lon_range():
return np.linspace(-2, 2, 3)


@pytest.fixture
def sim_train_da(time_range, lat_range, lon_range):
rng = np.random.default_rng()
return xr.DataArray(
data=rng.normal(
loc=1.0, size=(len(time_range), len(lat_range), len(lon_range))
),
dims=["time", "grid_latitude", "grid_longitude"],
name="target_pr",
coords=dict(
time=(["time"], time_range),
grid_latitude=(["grid_latitude"], lat_range),
grid_longitude=(["grid_longitude"], lon_range),
),
)


@pytest.fixture
def ml_train_da(time_range, lat_range, lon_range):
rng = np.random.default_rng()
return xr.DataArray(
data=rng.normal(size=(len(time_range), len(lat_range), len(lon_range))),
dims=["time", "grid_latitude", "grid_longitude"],
name="pred_pr",
coords=dict(
time=(["time"], time_range),
grid_latitude=(["grid_latitude"], lat_range),
grid_longitude=(["grid_longitude"], lon_range),
),
)


@pytest.fixture
def ml_eval_da(lat_range, lon_range):
eval_time_range = np.linspace(3, 4, 50)
rng = np.random.default_rng()
return xr.DataArray(
data=rng.normal(size=(len(eval_time_range), len(lat_range), len(lon_range))),
dims=["time", "grid_latitude", "grid_longitude"],
name="pred_pr",
coords=dict(
time=(["time"], eval_time_range),
grid_latitude=(["grid_latitude"], lat_range),
grid_longitude=(["grid_longitude"], lon_range),
),
)

0 comments on commit 1df23fc

Please sign in to comment.