From 1df23fc6a21727016dd462a86b58d13d5662630a Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Fri, 1 Sep 2023 11:36:17 +0100 Subject: [PATCH] add domain-aware implementation of 1d qm and tests and settle on just the apply_ufunc approach --- src/ml_downscaling_emulator/postprocess.py | 59 +++++++--- .../test_postprocess.py | 106 ++++++++++++++++++ 2 files changed, 148 insertions(+), 17 deletions(-) create mode 100644 tests/ml_downscaling_emulator/test_postprocess.py diff --git a/src/ml_downscaling_emulator/postprocess.py b/src/ml_downscaling_emulator/postprocess.py index 81e59ada0..283c68f51 100644 --- a/src/ml_downscaling_emulator/postprocess.py +++ b/src/ml_downscaling_emulator/postprocess.py @@ -1,30 +1,55 @@ -from cmethods import CMethods +from typing import Callable import numpy as np import xarray as xr -def qm(sim_train_da, ml_train_da, ml_eval_da): - values = np.zeros(ml_eval_da.shape, float) - for ilat in range(len(ml_eval_da["grid_latitude"])): - for ilon in range(len(ml_eval_da["grid_longitude"])): - values[:, ilat, ilon] = CMethods.quantile_mapping( - sim_train_da.isel(grid_latitude=ilat, grid_longitude=ilon), - ml_train_da.isel(grid_latitude=ilat, grid_longitude=ilon), - ml_eval_da.isel(grid_latitude=ilat, grid_longitude=ilon), - n_quantiles=250, - kind="+", - ) +def _get_cdf(x, xbins): + pdf, _ = np.histogram(x, xbins) + return np.insert(np.cumsum(pdf), 0, 0.0) - qmapped = xr.zeros_like(ml_eval_da) - qmapped.data = values - return qmapped +def qm_1d_dom_aware( + obs: np.ndarray, + simh: np.ndarray, + simp: np.ndarray, + n_quantiles: int = 250, + kind: str = "+", +): + """ + A 1D quantile mapping function replacement for CMethods.quantile_mapping + Unlike the CMethods version it takes into account that the obs and simh may have different max and mins (i.e. their CDFs have different supported domains). The CMethods version just uses a domain between the min and max of both obs and simh. + """ + obs, simh, simp = np.array(obs), np.array(simh), np.array(simp) -def qm_vec(sim_train_da, ml_train_da, ml_eval_da): + obs_min = np.amin(obs) + obs_max = np.amax(obs) + wide = abs(obs_max - obs_min) / n_quantiles + xbins_obs = np.arange(obs_min, obs_max + wide, wide) + + simh_min = np.amin(simh) + simh_max = np.amax(simh) + wide = abs(simh_max - simh_min) / n_quantiles + xbins_simh = np.arange(simh_min, simh_max + wide, wide) + + cdf_obs = _get_cdf(obs, xbins_obs) + cdf_simh = _get_cdf(simh, xbins_simh) + + epsilon = np.interp(simp, xbins_simh, cdf_simh) + + return np.interp(epsilon, cdf_obs, xbins_obs) + + +def xrqm( + sim_train_da: xr.DataArray, + ml_train_da: xr.DataArray, + ml_eval_da: xr.DataArray, + qm_func: Callable = qm_1d_dom_aware, +): + """Apply a 1D quantile mapping function point-by-point to a multi-dimensional xarray DataArray.""" return ( xr.apply_ufunc( - CMethods.quantile_mapping, # first the function + qm_func, # first the function sim_train_da, # now arguments in the order expected by the function ml_train_da, ml_eval_da, diff --git a/tests/ml_downscaling_emulator/test_postprocess.py b/tests/ml_downscaling_emulator/test_postprocess.py new file mode 100644 index 000000000..187d01024 --- /dev/null +++ b/tests/ml_downscaling_emulator/test_postprocess.py @@ -0,0 +1,106 @@ +import numpy as np +import pytest +import xarray as xr + +from ml_downscaling_emulator.postprocess import xrqm, qm_1d_dom_aware + + +def test_qm_applies_qm_at_each_gridbox(sim_train_da, ml_train_da, ml_eval_da): + qm_ml_eval_da = xrqm(sim_train_da, ml_train_da, ml_eval_da) + + for ilat in range(len(ml_eval_da["grid_latitude"])): + for ilon in range(len(ml_eval_da["grid_longitude"])): + exp_value = qm_1d_dom_aware( + sim_train_da.isel(grid_latitude=ilat, grid_longitude=ilon), + ml_train_da.isel(grid_latitude=ilat, grid_longitude=ilon), + ml_eval_da.isel(grid_latitude=ilat, grid_longitude=ilon), + n_quantiles=250, + ) + + np.testing.assert_allclose( + exp_value, qm_ml_eval_da.isel(grid_latitude=ilat, grid_longitude=ilon) + ) + + +def test_all_train_qm_match_sim_quantiles(sim_train_da, ml_train_da): + qm_ml_train_da = xrqm(sim_train_da, ml_train_da, ml_train_da) + + np.testing.assert_allclose( + sim_train_da.quantile([0.1, 0.25, 0.5, 0.75, 0.9], dim="time"), + qm_ml_train_da.quantile([0.1, 0.25, 0.5, 0.75, 0.9], dim="time"), + rtol=5e-2, + ) + + +def test_all_train_qm_match_sim_histogram(sim_train_da, ml_train_da): + qm_ml_train_da = xrqm(sim_train_da, ml_train_da, ml_train_da) + + sim_ns, bins = np.histogram(sim_train_da, range=(-5.0, 5.0), bins=20) + qm_ml_ns, bins = np.histogram(qm_ml_train_da, bins=bins) + + np.testing.assert_allclose(sim_ns, qm_ml_ns, atol=200) + + np.testing.assert_allclose(np.abs(sim_ns - qm_ml_ns).sum(), 0.0, atol=500) + + +@pytest.fixture +def time_range(): + return np.linspace(-2, 2, 20000) + + +@pytest.fixture +def lat_range(): + return np.linspace(-2, 2, 7) + + +@pytest.fixture +def lon_range(): + return np.linspace(-2, 2, 3) + + +@pytest.fixture +def sim_train_da(time_range, lat_range, lon_range): + rng = np.random.default_rng() + return xr.DataArray( + data=rng.normal( + loc=1.0, size=(len(time_range), len(lat_range), len(lon_range)) + ), + dims=["time", "grid_latitude", "grid_longitude"], + name="target_pr", + coords=dict( + time=(["time"], time_range), + grid_latitude=(["grid_latitude"], lat_range), + grid_longitude=(["grid_longitude"], lon_range), + ), + ) + + +@pytest.fixture +def ml_train_da(time_range, lat_range, lon_range): + rng = np.random.default_rng() + return xr.DataArray( + data=rng.normal(size=(len(time_range), len(lat_range), len(lon_range))), + dims=["time", "grid_latitude", "grid_longitude"], + name="pred_pr", + coords=dict( + time=(["time"], time_range), + grid_latitude=(["grid_latitude"], lat_range), + grid_longitude=(["grid_longitude"], lon_range), + ), + ) + + +@pytest.fixture +def ml_eval_da(lat_range, lon_range): + eval_time_range = np.linspace(3, 4, 50) + rng = np.random.default_rng() + return xr.DataArray( + data=rng.normal(size=(len(eval_time_range), len(lat_range), len(lon_range))), + dims=["time", "grid_latitude", "grid_longitude"], + name="pred_pr", + coords=dict( + time=(["time"], eval_time_range), + grid_latitude=(["grid_latitude"], lat_range), + grid_longitude=(["grid_longitude"], lon_range), + ), + )