Skip to content

Commit

Permalink
update fixtures so can have different samples
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Sep 12, 2023
1 parent 4483f78 commit a1cfde5
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 111 deletions.
179 changes: 108 additions & 71 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Callable
import cftime
import numpy as np
import pytest
Expand All @@ -7,89 +8,125 @@

grid_longitude = xr.Variable(["grid_longitude"], np.linspace(-4, 4, 17), attrs={})

time = xr.Variable(
["time"],
xr.cftime_range(
cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True),
periods=10,

def build_time_dim(start_year: int = 1980, time_len: int = 10):
time = xr.Variable(
["time"],
xr.cftime_range(
cftime.Datetime360Day(start_year, 12, 1, 12, 0, 0, 0, has_year_zero=True),
periods=time_len,
freq="D",
),
)
time_bnds_values = xr.cftime_range(
cftime.Datetime360Day(start_year, 12, 1, 0, 0, 0, 0, has_year_zero=True),
periods=len(time) + 1,
freq="D",
),
)
time_bnds_values = xr.cftime_range(
cftime.Datetime360Day(1980, 12, 1, 0, 0, 0, 0, has_year_zero=True),
periods=len(time) + 1,
freq="D",
).values
time_bnds_pairs = np.concatenate(
[time_bnds_values[:-1, np.newaxis], time_bnds_values[1:, np.newaxis]], axis=1
)
).values
time_bnds_pairs = np.concatenate(
[time_bnds_values[:-1, np.newaxis], time_bnds_values[1:, np.newaxis]], axis=1
)

time_bnds = xr.Variable(["time", "bnds"], time_bnds_pairs, attrs={})

time_bnds = xr.Variable(["time", "bnds"], time_bnds_pairs, attrs={})
return time, time_bnds


@pytest.fixture
def samples_set() -> xr.Dataset:
"""Create a dummy Dataset that looks like a set of samples from the emulator."""

ensemble_member = xr.Variable(["ensemble_member"], np.array(["01"]))

coords = {
"ensemble_member": ensemble_member,
"time": time,
"grid_latitude": grid_latitude,
"grid_longitude": grid_longitude,
}

data_vars = {
"pred_pr": xr.Variable(
["ensemble_member", "time", "grid_latitude", "grid_longitude"],
np.random.rand(
len(ensemble_member), len(time), len(grid_latitude), len(grid_longitude)
def samples_factory() -> Callable[[int, int], xr.Dataset]:
"""Create a factory function for creating dummy xarray Datasets that look like samples from the emulator."""

def _samples_factory(start_year: int = 1980, time_len: int = 10) -> xr.Dataset:
ensemble_member = xr.Variable(["ensemble_member"], np.array(["01"]))
time, time_bnds = build_time_dim(start_year=start_year, time_len=time_len)
coords = {
"ensemble_member": ensemble_member,
"time": time,
"grid_latitude": grid_latitude,
"grid_longitude": grid_longitude,
}

data_vars = {
"pred_pr": xr.Variable(
["ensemble_member", "time", "grid_latitude", "grid_longitude"],
np.random.rand(
len(ensemble_member),
len(time),
len(grid_latitude),
len(grid_longitude),
),
),
),
"time_bnds": time_bnds,
}
"time_bnds": time_bnds,
}

ds = xr.Dataset(
data_vars=data_vars,
coords=coords,
)
ds = xr.Dataset(
data_vars=data_vars,
coords=coords,
)

return ds

return ds
return _samples_factory


@pytest.fixture
def dataset() -> xr.Dataset:
"""Create a dummy Dataset representing a split of a set of data for training and sampling."""

ensemble_member = xr.Variable(["ensemble_member"], np.array(["01", "02", "03"]))

coords = {
"ensemble_member": ensemble_member,
"time": time,
"grid_latitude": grid_latitude,
"grid_longitude": grid_longitude,
}

data_vars = {
"linpr": xr.Variable(
["ensemble_member", "time", "grid_latitude", "grid_longitude"],
np.random.rand(
len(ensemble_member), len(time), len(grid_latitude), len(grid_longitude)
def samples_set(samples_factory) -> xr.Dataset:
"""Create a dummy xarray Dataset that looks like a set of samples from the emulator."""

return samples_factory()


@pytest.fixture
def dataset_factory() -> Callable[[int, int], xr.Dataset]:
"""Create a factory function for creating dummy xarray Datasets that look like the training data."""

def _dataset_factory(start_year: int = 1980, time_len: int = 10) -> xr.Dataset:
ensemble_member = xr.Variable(
["ensemble_member"], np.array([f"{i:02}" for i in range(3)])
)

time, time_bnds = build_time_dim(start_year=start_year, time_len=time_len)

coords = {
"ensemble_member": ensemble_member,
"time": time,
"grid_latitude": grid_latitude,
"grid_longitude": grid_longitude,
}

data_vars = {
"linpr": xr.Variable(
["ensemble_member", "time", "grid_latitude", "grid_longitude"],
np.random.rand(
len(ensemble_member),
len(time),
len(grid_latitude),
len(grid_longitude),
),
),
),
"target_pr": xr.Variable(
["ensemble_member", "time", "grid_latitude", "grid_longitude"],
np.random.rand(
len(ensemble_member), len(time), len(grid_latitude), len(grid_longitude)
"target_pr": xr.Variable(
["ensemble_member", "time", "grid_latitude", "grid_longitude"],
np.random.rand(
len(ensemble_member),
len(time),
len(grid_latitude),
len(grid_longitude),
),
),
),
"time_bnds": time_bnds,
}
"time_bnds": time_bnds,
}

ds = xr.Dataset(
data_vars=data_vars,
coords=coords,
)
ds = xr.Dataset(
data_vars=data_vars,
coords=coords,
)

return ds
return ds

return _dataset_factory


@pytest.fixture
def dataset(dataset_factory) -> xr.Dataset:
"""Create a dummy xarray Dataset representing a split of a set of data for training and sampling."""
return dataset_factory()
46 changes: 6 additions & 40 deletions tests/ml_downscaling_emulator/test_postprocess.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import pytest
import xarray as xr

from ml_downscaling_emulator.postprocess import xrqm, qm_1d_dom_aware

Expand Down Expand Up @@ -59,48 +58,15 @@ def lon_range():


@pytest.fixture
def sim_train_da(time_range, lat_range, lon_range):
rng = np.random.default_rng()
return xr.DataArray(
data=rng.normal(
loc=1.0, size=(len(time_range), len(lat_range), len(lon_range))
),
dims=["time", "grid_latitude", "grid_longitude"],
name="target_pr",
coords=dict(
time=(["time"], time_range),
grid_latitude=(["grid_latitude"], lat_range),
grid_longitude=(["grid_longitude"], lon_range),
),
)
def sim_train_da(dataset_factory):
return dataset_factory(time_len=2000).sel(ensemble_member=["01"])["target_pr"]


@pytest.fixture
def ml_train_da(time_range, lat_range, lon_range):
rng = np.random.default_rng()
return xr.DataArray(
data=rng.normal(size=(len(time_range), len(lat_range), len(lon_range))),
dims=["time", "grid_latitude", "grid_longitude"],
name="pred_pr",
coords=dict(
time=(["time"], time_range),
grid_latitude=(["grid_latitude"], lat_range),
grid_longitude=(["grid_longitude"], lon_range),
),
)
def ml_train_da(samples_factory):
return samples_factory(time_len=2000)["pred_pr"]


@pytest.fixture
def ml_eval_da(lat_range, lon_range):
eval_time_range = np.linspace(3, 4, 50)
rng = np.random.default_rng()
return xr.DataArray(
data=rng.normal(size=(len(eval_time_range), len(lat_range), len(lon_range))),
dims=["time", "grid_latitude", "grid_longitude"],
name="pred_pr",
coords=dict(
time=(["time"], eval_time_range),
grid_latitude=(["grid_latitude"], lat_range),
grid_longitude=(["grid_longitude"], lon_range),
),
)
def ml_eval_da(samples_factory):
return samples_factory(start_year=2060, time_len=1000)["pred_pr"]

0 comments on commit a1cfde5

Please sign in to comment.