Skip to content

Commit

Permalink
add a ci action
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Mar 5, 2024
1 parent bdd5028 commit 6faabcd
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 17 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: CI

on: [push]

jobs:
ci-checks:
runs-on: ubuntu-latest
strategy:
max-parallel: 5

steps:
- name: Clone repo
uses: actions/checkout@v4
- name: setup-micromamba
uses: mamba-org/[email protected]
with:
environment-file: environment.lock.yml
init-shell: bash
cache-environment: true
post-cleanup: 'all'
- name: Install package
run: |
pip install -e .
shell: micromamba-shell {0}
- name: Install unet
uses: actions/checkout@v4
with:
repository: henryaddison/Pytorch-UNet
path: src/ml_downscaling_emulator/unet
- name: Test with pytest
run: |
pytest
shell: micromamba-shell {0}
2 changes: 1 addition & 1 deletion src/ml_downscaling_emulator/bin/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mlde_utils.training.dataset import load_raw_dataset_split
from ..deterministic import sampling
from ..deterministic.utils import create_model, restore_checkpoint
from ..torch import get_dataloader
from ..data import get_dataloader


logging.basicConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import torch
import xarray as xr

from ml_downscaling_emulator.torch import XRDataset
from ml_downscaling_emulator.data import UKCPLocalDataset


def test_XRDataset_item_cond_var(xr_dataset, time_range):
pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
def test_UKCPLocalDataset_item_cond_var(xr_dataset, time_range):
pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
cond = pt_dataset[0][0][:2]
expected_cond = torch.stack(
[
Expand All @@ -20,8 +20,10 @@ def test_XRDataset_item_cond_var(xr_dataset, time_range):
assert torch.all(cond == expected_cond)


def test_XRDataset_item_cond_time(xr_dataset, time_range, earliest_time, latest_time):
pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
def test_UKCPLocalDataset_item_cond_time(
xr_dataset, time_range, earliest_time, latest_time
):
pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)

test_date = cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True)
expected_climate_time = (test_date - earliest_time) / (latest_time - earliest_time)
Expand Down Expand Up @@ -54,16 +56,16 @@ def test_XRDataset_item_cond_time(xr_dataset, time_range, earliest_time, latest_
)


def test_XRDataset_item_target(xr_dataset, time_range):
pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
def test_UKCPLocalDataset_item_target(xr_dataset, time_range):
pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
target = pt_dataset[0][1]
expected_target = torch.Tensor(np.arange(5 * 7).reshape(1, 5, 7))

assert torch.all(target == expected_target)


def test_XRDataset_item_time(xr_dataset, time_range):
pt_dataset = XRDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)
def test_UKCPLocalDataset_item_time(xr_dataset, time_range):
pt_dataset = UKCPLocalDataset(xr_dataset, ["var1", "var2"], ["target"], time_range)

time = pt_dataset[0][2]
expected_time = cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True)
Expand Down Expand Up @@ -108,28 +110,40 @@ def time_coords():
)


@pytest.fixture
def em_coords():
return ["01", "02"]


def values(shape):
return np.arange(np.prod(shape)).reshape(*shape)


@pytest.fixture
def xr_dataset(time_coords, lat_coords, lon_coords):
def xr_dataset(em_coords, time_coords, lat_coords, lon_coords):
ds = xr.Dataset(
data_vars={
"var1": (
["time", "grid_longitude", "grid_latitude"],
values((len(time_coords), len(lon_coords), len(lat_coords))),
["ensemble_member", "time", "grid_longitude", "grid_latitude"],
values(
(len(em_coords), len(time_coords), len(lon_coords), len(lat_coords))
),
),
"var2": (
["time", "grid_longitude", "grid_latitude"],
values((len(time_coords), len(lon_coords), len(lat_coords))),
["ensemble_member", "time", "grid_longitude", "grid_latitude"],
values(
(len(em_coords), len(time_coords), len(lon_coords), len(lat_coords))
),
),
"target": (
["time", "grid_longitude", "grid_latitude"],
values((len(time_coords), len(lon_coords), len(lat_coords))),
["ensemble_member", "time", "grid_longitude", "grid_latitude"],
values(
(len(em_coords), len(time_coords), len(lon_coords), len(lat_coords))
),
),
},
coords=dict(
ensemble_member=(["ensemble_member"], em_coords),
time=(["time"], time_coords),
grid_longitude=(["grid_longitude"], lon_coords),
grid_latitude=(["grid_latitude"], lat_coords),
Expand Down

0 comments on commit 6faabcd

Please sign in to comment.