From a1cfde5f82c8b7e6ac86760aefe28dcc03898fd4 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Tue, 12 Sep 2023 11:06:14 +0100 Subject: [PATCH] update fixtures so can have different samples --- tests/conftest.py | 179 +++++++++++------- .../test_postprocess.py | 46 +---- 2 files changed, 114 insertions(+), 111 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 321703576..3b0cfb344 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from typing import Callable import cftime import numpy as np import pytest @@ -7,89 +8,125 @@ grid_longitude = xr.Variable(["grid_longitude"], np.linspace(-4, 4, 17), attrs={}) -time = xr.Variable( - ["time"], - xr.cftime_range( - cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True), - periods=10, + +def build_time_dim(start_year: int = 1980, time_len: int = 10): + time = xr.Variable( + ["time"], + xr.cftime_range( + cftime.Datetime360Day(start_year, 12, 1, 12, 0, 0, 0, has_year_zero=True), + periods=time_len, + freq="D", + ), + ) + time_bnds_values = xr.cftime_range( + cftime.Datetime360Day(start_year, 12, 1, 0, 0, 0, 0, has_year_zero=True), + periods=len(time) + 1, freq="D", - ), -) -time_bnds_values = xr.cftime_range( - cftime.Datetime360Day(1980, 12, 1, 0, 0, 0, 0, has_year_zero=True), - periods=len(time) + 1, - freq="D", -).values -time_bnds_pairs = np.concatenate( - [time_bnds_values[:-1, np.newaxis], time_bnds_values[1:, np.newaxis]], axis=1 -) + ).values + time_bnds_pairs = np.concatenate( + [time_bnds_values[:-1, np.newaxis], time_bnds_values[1:, np.newaxis]], axis=1 + ) + + time_bnds = xr.Variable(["time", "bnds"], time_bnds_pairs, attrs={}) -time_bnds = xr.Variable(["time", "bnds"], time_bnds_pairs, attrs={}) + return time, time_bnds @pytest.fixture -def samples_set() -> xr.Dataset: - """Create a dummy Dataset that looks like a set of samples from the emulator.""" - - ensemble_member = xr.Variable(["ensemble_member"], np.array(["01"])) - - coords = { - "ensemble_member": ensemble_member, - "time": time, - "grid_latitude": grid_latitude, - "grid_longitude": grid_longitude, - } - - data_vars = { - "pred_pr": xr.Variable( - ["ensemble_member", "time", "grid_latitude", "grid_longitude"], - np.random.rand( - len(ensemble_member), len(time), len(grid_latitude), len(grid_longitude) +def samples_factory() -> Callable[[int, int], xr.Dataset]: + """Create a factory function for creating dummy xarray Datasets that look like samples from the emulator.""" + + def _samples_factory(start_year: int = 1980, time_len: int = 10) -> xr.Dataset: + ensemble_member = xr.Variable(["ensemble_member"], np.array(["01"])) + time, time_bnds = build_time_dim(start_year=start_year, time_len=time_len) + coords = { + "ensemble_member": ensemble_member, + "time": time, + "grid_latitude": grid_latitude, + "grid_longitude": grid_longitude, + } + + data_vars = { + "pred_pr": xr.Variable( + ["ensemble_member", "time", "grid_latitude", "grid_longitude"], + np.random.rand( + len(ensemble_member), + len(time), + len(grid_latitude), + len(grid_longitude), + ), ), - ), - "time_bnds": time_bnds, - } + "time_bnds": time_bnds, + } - ds = xr.Dataset( - data_vars=data_vars, - coords=coords, - ) + ds = xr.Dataset( + data_vars=data_vars, + coords=coords, + ) + + return ds - return ds + return _samples_factory @pytest.fixture -def dataset() -> xr.Dataset: - """Create a dummy Dataset representing a split of a set of data for training and sampling.""" - - ensemble_member = xr.Variable(["ensemble_member"], np.array(["01", "02", "03"])) - - coords = { - "ensemble_member": ensemble_member, - "time": time, - "grid_latitude": grid_latitude, - "grid_longitude": grid_longitude, - } - - data_vars = { - "linpr": xr.Variable( - ["ensemble_member", "time", "grid_latitude", "grid_longitude"], - np.random.rand( - len(ensemble_member), len(time), len(grid_latitude), len(grid_longitude) +def samples_set(samples_factory) -> xr.Dataset: + """Create a dummy xarray Dataset that looks like a set of samples from the emulator.""" + + return samples_factory() + + +@pytest.fixture +def dataset_factory() -> Callable[[int, int], xr.Dataset]: + """Create a factory function for creating dummy xarray Datasets that look like the training data.""" + + def _dataset_factory(start_year: int = 1980, time_len: int = 10) -> xr.Dataset: + ensemble_member = xr.Variable( + ["ensemble_member"], np.array([f"{i:02}" for i in range(3)]) + ) + + time, time_bnds = build_time_dim(start_year=start_year, time_len=time_len) + + coords = { + "ensemble_member": ensemble_member, + "time": time, + "grid_latitude": grid_latitude, + "grid_longitude": grid_longitude, + } + + data_vars = { + "linpr": xr.Variable( + ["ensemble_member", "time", "grid_latitude", "grid_longitude"], + np.random.rand( + len(ensemble_member), + len(time), + len(grid_latitude), + len(grid_longitude), + ), ), - ), - "target_pr": xr.Variable( - ["ensemble_member", "time", "grid_latitude", "grid_longitude"], - np.random.rand( - len(ensemble_member), len(time), len(grid_latitude), len(grid_longitude) + "target_pr": xr.Variable( + ["ensemble_member", "time", "grid_latitude", "grid_longitude"], + np.random.rand( + len(ensemble_member), + len(time), + len(grid_latitude), + len(grid_longitude), + ), ), - ), - "time_bnds": time_bnds, - } + "time_bnds": time_bnds, + } - ds = xr.Dataset( - data_vars=data_vars, - coords=coords, - ) + ds = xr.Dataset( + data_vars=data_vars, + coords=coords, + ) - return ds + return ds + + return _dataset_factory + + +@pytest.fixture +def dataset(dataset_factory) -> xr.Dataset: + """Create a dummy xarray Dataset representing a split of a set of data for training and sampling.""" + return dataset_factory() diff --git a/tests/ml_downscaling_emulator/test_postprocess.py b/tests/ml_downscaling_emulator/test_postprocess.py index 187d01024..5a9ef1acf 100644 --- a/tests/ml_downscaling_emulator/test_postprocess.py +++ b/tests/ml_downscaling_emulator/test_postprocess.py @@ -1,6 +1,5 @@ import numpy as np import pytest -import xarray as xr from ml_downscaling_emulator.postprocess import xrqm, qm_1d_dom_aware @@ -59,48 +58,15 @@ def lon_range(): @pytest.fixture -def sim_train_da(time_range, lat_range, lon_range): - rng = np.random.default_rng() - return xr.DataArray( - data=rng.normal( - loc=1.0, size=(len(time_range), len(lat_range), len(lon_range)) - ), - dims=["time", "grid_latitude", "grid_longitude"], - name="target_pr", - coords=dict( - time=(["time"], time_range), - grid_latitude=(["grid_latitude"], lat_range), - grid_longitude=(["grid_longitude"], lon_range), - ), - ) +def sim_train_da(dataset_factory): + return dataset_factory(time_len=2000).sel(ensemble_member=["01"])["target_pr"] @pytest.fixture -def ml_train_da(time_range, lat_range, lon_range): - rng = np.random.default_rng() - return xr.DataArray( - data=rng.normal(size=(len(time_range), len(lat_range), len(lon_range))), - dims=["time", "grid_latitude", "grid_longitude"], - name="pred_pr", - coords=dict( - time=(["time"], time_range), - grid_latitude=(["grid_latitude"], lat_range), - grid_longitude=(["grid_longitude"], lon_range), - ), - ) +def ml_train_da(samples_factory): + return samples_factory(time_len=2000)["pred_pr"] @pytest.fixture -def ml_eval_da(lat_range, lon_range): - eval_time_range = np.linspace(3, 4, 50) - rng = np.random.default_rng() - return xr.DataArray( - data=rng.normal(size=(len(eval_time_range), len(lat_range), len(lon_range))), - dims=["time", "grid_latitude", "grid_longitude"], - name="pred_pr", - coords=dict( - time=(["time"], eval_time_range), - grid_latitude=(["grid_latitude"], lat_range), - grid_longitude=(["grid_longitude"], lon_range), - ), - ) +def ml_eval_da(samples_factory): + return samples_factory(start_year=2060, time_len=1000)["pred_pr"]