From 5c83b6e60cc2a37626b1939c815f48f77f528999 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Tue, 5 Mar 2024 10:55:06 +0000 Subject: [PATCH] fix tests now running in CI --- src/ml_downscaling_emulator/bin/evaluate.py | 2 +- .../{test_torch.py => test_data.py} | 46 ++++++++++++------- 2 files changed, 31 insertions(+), 17 deletions(-) rename tests/ml_downscaling_emulator/{test_torch.py => test_data.py} (66%) diff --git a/src/ml_downscaling_emulator/bin/evaluate.py b/src/ml_downscaling_emulator/bin/evaluate.py index 5915002cc..a5d4b3eea 100644 --- a/src/ml_downscaling_emulator/bin/evaluate.py +++ b/src/ml_downscaling_emulator/bin/evaluate.py @@ -13,7 +13,7 @@ from mlde_utils.training.dataset import load_raw_dataset_split from ..deterministic import sampling from ..deterministic.utils import create_model, restore_checkpoint -from ..torch import get_dataloader +from ..data import get_dataloader logging.basicConfig( diff --git a/tests/ml_downscaling_emulator/test_torch.py b/tests/ml_downscaling_emulator/test_data.py similarity index 66% rename from tests/ml_downscaling_emulator/test_torch.py rename to tests/ml_downscaling_emulator/test_data.py index 4a157b152..4e855ac13 100644 --- a/tests/ml_downscaling_emulator/test_torch.py +++ b/tests/ml_downscaling_emulator/test_data.py @@ -4,11 +4,11 @@ import torch import xarray as xr -from ml_downscaling_emulator.torch import XRDataset +from ml_downscaling_emulator.data import UKCPLocalDataset -def test_XRDataset_item_cond_var(xr_dataset, time_range): - pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) +def test_UKCPLocalDataset_item_cond_var(xr_dataset, time_range): + pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) cond = pt_dataset[0][0][:2] expected_cond = torch.stack( [ @@ -20,8 +20,10 @@ def test_XRDataset_item_cond_var(xr_dataset, time_range): assert torch.all(cond == expected_cond) -def test_XRDataset_item_cond_time(xr_dataset, time_range, earliest_time, latest_time): - pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) +def test_UKCPLocalDataset_item_cond_time( + xr_dataset, time_range, earliest_time, latest_time +): + pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) test_date = cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True) expected_climate_time = (test_date - earliest_time) / (latest_time - earliest_time) @@ -54,16 +56,16 @@ def test_XRDataset_item_cond_time(xr_dataset, time_range, earliest_time, latest_ ) -def test_XRDataset_item_target(xr_dataset, time_range): - pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) +def test_UKCPLocalDataset_item_target(xr_dataset, time_range): + pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) target = pt_dataset[0][1] expected_target = torch.Tensor(np.arange(5 * 7).reshape(1, 5, 7)) assert torch.all(target == expected_target) -def test_XRDataset_item_time(xr_dataset, time_range): - pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) +def test_UKCPLocalDataset_item_time(xr_dataset, time_range): + pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) time = pt_dataset[0][2] expected_time = cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True) @@ -108,28 +110,40 @@ def time_coords(): ) +@pytest.fixture +def em_coords(): + return ["01", "02"] + + def values(shape): return np.arange(np.prod(shape)).reshape(*shape) @pytest.fixture -def xr_dataset(time_coords, lat_coords, lon_coords): +def xr_dataset(em_coords, time_coords, lat_coords, lon_coords): ds = xr.Dataset( data_vars={ "var1": ( - ["time", "grid_longitude", "grid_latitude"], - values((len(time_coords), len(lon_coords), len(lat_coords))), + ["ensemble_member", "time", "grid_longitude", "grid_latitude"], + values( + (len(em_coords), len(time_coords), len(lon_coords), len(lat_coords)) + ), ), "var2": ( - ["time", "grid_longitude", "grid_latitude"], - values((len(time_coords), len(lon_coords), len(lat_coords))), + ["ensemble_member", "time", "grid_longitude", "grid_latitude"], + values( + (len(em_coords), len(time_coords), len(lon_coords), len(lat_coords)) + ), ), "target": ( - ["time", "grid_longitude", "grid_latitude"], - values((len(time_coords), len(lon_coords), len(lat_coords))), + ["ensemble_member", "time", "grid_longitude", "grid_latitude"], + values( + (len(em_coords), len(time_coords), len(lon_coords), len(lat_coords)) + ), ), }, coords=dict( + ensemble_member=(["ensemble_member"], em_coords), time=(["time"], time_coords), grid_longitude=(["grid_longitude"], lon_coords), grid_latitude=(["grid_latitude"], lat_coords),