Skip to content

Commit

Permalink
fix tests now running in CI
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Mar 5, 2024
1 parent 454c605 commit 5c83b6e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/ml_downscaling_emulator/bin/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mlde_utils.training.dataset import load_raw_dataset_split
from ..deterministic import sampling
from ..deterministic.utils import create_model, restore_checkpoint
from ..torch import get_dataloader
from ..data import get_dataloader


logging.basicConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import torch
import xarray as xr

from ml_downscaling_emulator.torch import XRDataset
from ml_downscaling_emulator.data import UKCPLocalDataset


def test_XRDataset_item_cond_var(xr_dataset, time_range):
pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
def test_UKCPLocalDataset_item_cond_var(xr_dataset, time_range):
pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
cond = pt_dataset[0][0][:2]
expected_cond = torch.stack(
[
Expand All @@ -20,8 +20,10 @@ def test_XRDataset_item_cond_var(xr_dataset, time_range):
assert torch.all(cond == expected_cond)


def test_XRDataset_item_cond_time(xr_dataset, time_range, earliest_time, latest_time):
pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
def test_UKCPLocalDataset_item_cond_time(
xr_dataset, time_range, earliest_time, latest_time
):
pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)

test_date = cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True)
expected_climate_time = (test_date - earliest_time) / (latest_time - earliest_time)
Expand Down Expand Up @@ -54,16 +56,16 @@ def test_XRDataset_item_cond_time(xr_dataset, time_range, earliest_time, latest_
)


def test_XRDataset_item_target(xr_dataset, time_range):
pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
def test_UKCPLocalDataset_item_target(xr_dataset, time_range):
pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
target = pt_dataset[0][1]
expected_target = torch.Tensor(np.arange(5 * 7).reshape(1, 5, 7))

assert torch.all(target == expected_target)


def test_XRDataset_item_time(xr_dataset, time_range):
pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
def test_UKCPLocalDataset_item_time(xr_dataset, time_range):
pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)

time = pt_dataset[0][2]
expected_time = cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True)
Expand Down Expand Up @@ -108,28 +110,40 @@ def time_coords():
)


@pytest.fixture
def em_coords():
return ["01", "02"]


def values(shape):
return np.arange(np.prod(shape)).reshape(*shape)


@pytest.fixture
def xr_dataset(time_coords, lat_coords, lon_coords):
def xr_dataset(em_coords, time_coords, lat_coords, lon_coords):
ds = xr.Dataset(
data_vars={
"var1": (
["time", "grid_longitude", "grid_latitude"],
values((len(time_coords), len(lon_coords), len(lat_coords))),
["ensemble_member", "time", "grid_longitude", "grid_latitude"],
values(
(len(em_coords), len(time_coords), len(lon_coords), len(lat_coords))
),
),
"var2": (
["time", "grid_longitude", "grid_latitude"],
values((len(time_coords), len(lon_coords), len(lat_coords))),
["ensemble_member", "time", "grid_longitude", "grid_latitude"],
values(
(len(em_coords), len(time_coords), len(lon_coords), len(lat_coords))
),
),
"target": (
["time", "grid_longitude", "grid_latitude"],
values((len(time_coords), len(lon_coords), len(lat_coords))),
["ensemble_member", "time", "grid_longitude", "grid_latitude"],
values(
(len(em_coords), len(time_coords), len(lon_coords), len(lat_coords))
),
),
},
coords=dict(
ensemble_member=(["ensemble_member"], em_coords),
time=(["time"], time_coords),
grid_longitude=(["grid_longitude"], lon_coords),
grid_latitude=(["grid_latitude"], lat_coords),
Expand Down

0 comments on commit 5c83b6e

Please sign in to comment.