From 6faabcdb7d45a6ccf015933b033c08972923aec6 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Tue, 5 Mar 2024 10:49:07 +0000 Subject: [PATCH] add a ci action --- .github/workflows/ci.yml | 33 +++++++++++++ src/ml_downscaling_emulator/bin/evaluate.py | 2 +- .../{test_torch.py => test_data.py} | 46 ++++++++++++------- 3 files changed, 64 insertions(+), 17 deletions(-) create mode 100644 .github/workflows/ci.yml rename tests/ml_downscaling_emulator/{test_torch.py => test_data.py} (66%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..9698dfa4e --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,33 @@ +name: CI + +on: [push] + +jobs: + ci-checks: + runs-on: ubuntu-latest + strategy: + max-parallel: 5 + + steps: + - name: Clone repo + uses: actions/checkout@v4 + - name: setup-micromamba + uses: mamba-org/setup-micromamba@v1.4.1 + with: + environment-file: environment.lock.yml + init-shell: bash + cache-environment: true + post-cleanup: 'all' + - name: Install package + run: | + pip install -e . + shell: micromamba-shell {0} + - name: Install unet + uses: actions/checkout@v4 + with: + repository: henryaddison/Pytorch-UNet + path: src/ml_downscaling_emulator/unet + - name: Test with pytest + run: | + pytest + shell: micromamba-shell {0} diff --git a/src/ml_downscaling_emulator/bin/evaluate.py b/src/ml_downscaling_emulator/bin/evaluate.py index 5915002cc..a5d4b3eea 100644 --- a/src/ml_downscaling_emulator/bin/evaluate.py +++ b/src/ml_downscaling_emulator/bin/evaluate.py @@ -13,7 +13,7 @@ from mlde_utils.training.dataset import load_raw_dataset_split from ..deterministic import sampling from ..deterministic.utils import create_model, restore_checkpoint -from ..torch import get_dataloader +from ..data import get_dataloader logging.basicConfig( diff --git a/tests/ml_downscaling_emulator/test_torch.py b/tests/ml_downscaling_emulator/test_data.py similarity index 66% rename from tests/ml_downscaling_emulator/test_torch.py rename to tests/ml_downscaling_emulator/test_data.py index 4a157b152..4e855ac13 100644 --- a/tests/ml_downscaling_emulator/test_torch.py +++ b/tests/ml_downscaling_emulator/test_data.py @@ -4,11 +4,11 @@ import torch import xarray as xr -from ml_downscaling_emulator.torch import XRDataset +from ml_downscaling_emulator.data import UKCPLocalDataset -def test_XRDataset_item_cond_var(xr_dataset, time_range): - pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) +def test_UKCPLocalDataset_item_cond_var(xr_dataset, time_range): + pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) cond = pt_dataset[0][0][:2] expected_cond = torch.stack( [ @@ -20,8 +20,10 @@ def test_XRDataset_item_cond_var(xr_dataset, time_range): assert torch.all(cond == expected_cond) -def test_XRDataset_item_cond_time(xr_dataset, time_range, earliest_time, latest_time): - pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) +def test_UKCPLocalDataset_item_cond_time( + xr_dataset, time_range, earliest_time, latest_time +): + pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) test_date = cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True) expected_climate_time = (test_date - earliest_time) / (latest_time - earliest_time) @@ -54,16 +56,16 @@ def test_XRDataset_item_cond_time(xr_dataset, time_range, earliest_time, latest_ ) -def test_XRDataset_item_target(xr_dataset, time_range): - pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) +def test_UKCPLocalDataset_item_target(xr_dataset, time_range): + pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) target = pt_dataset[0][1] expected_target = torch.Tensor(np.arange(5 * 7).reshape(1, 5, 7)) assert torch.all(target == expected_target) -def test_XRDataset_item_time(xr_dataset, time_range): - pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) +def test_UKCPLocalDataset_item_time(xr_dataset, time_range): + pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range) time = pt_dataset[0][2] expected_time = cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True) @@ -108,28 +110,40 @@ def time_coords(): ) +@pytest.fixture +def em_coords(): + return ["01", "02"] + + def values(shape): return np.arange(np.prod(shape)).reshape(*shape) @pytest.fixture -def xr_dataset(time_coords, lat_coords, lon_coords): +def xr_dataset(em_coords, time_coords, lat_coords, lon_coords): ds = xr.Dataset( data_vars={ "var1": ( - ["time", "grid_longitude", "grid_latitude"], - values((len(time_coords), len(lon_coords), len(lat_coords))), + ["ensemble_member", "time", "grid_longitude", "grid_latitude"], + values( + (len(em_coords), len(time_coords), len(lon_coords), len(lat_coords)) + ), ), "var2": ( - ["time", "grid_longitude", "grid_latitude"], - values((len(time_coords), len(lon_coords), len(lat_coords))), + ["ensemble_member", "time", "grid_longitude", "grid_latitude"], + values( + (len(em_coords), len(time_coords), len(lon_coords), len(lat_coords)) + ), ), "target": ( - ["time", "grid_longitude", "grid_latitude"], - values((len(time_coords), len(lon_coords), len(lat_coords))), + ["ensemble_member", "time", "grid_longitude", "grid_latitude"], + values( + (len(em_coords), len(time_coords), len(lon_coords), len(lat_coords)) + ), ), }, coords=dict( + ensemble_member=(["ensemble_member"], em_coords), time=(["time"], time_coords), grid_longitude=(["grid_longitude"], lon_coords), grid_latitude=(["grid_latitude"], lat_coords),