Skip to content

Commit

Permalink
Merge pull request #31 from henryaddison/rename-configs-to-local
Browse files Browse the repository at this point in the history
rename configs to UKCP Local pr
  • Loading branch information
henryaddison authored Mar 4, 2024
2 parents 03ebfec + afa9f45 commit 9158ed6
Show file tree
Hide file tree
Showing 10 changed files with 14 additions and 16 deletions.
2 changes: 1 addition & 1 deletion bin/local-test-train
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

set -euo pipefail

config_name="ukcp18_cunet_continuous"
config_name="ukcp_local_pr_cunet_continuous"
dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"

sde="subvpsde"
Expand Down
8 changes: 3 additions & 5 deletions bin/queue-training
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import typer
app = typer.Typer()


def train_cmd(sde, dataset, workdir, config, config_overrides=list):
def train_cmd(sde, workdir, config, config_overrides=list):
train_basecmd = ["python", f"bin/main.py"]

train_opts = {
Expand All @@ -22,7 +22,6 @@ def train_cmd(sde, dataset, workdir, config, config_overrides=list):
return (
train_basecmd
+ [arg for item in train_opts.items() for arg in item]
+ [f"--config.data.dataset_name={dataset}"]
+ config_overrides
)

Expand Down Expand Up @@ -51,9 +50,8 @@ def queue_cmd(duration, memory):
def main(
ctx: typer.Context,
model_run_id: str,
cpm_dataset: str,
sde: str,
config: str = "xarray_12em_cncsnpp_continuous",
config: str = "ukcp_local_pr_12em_cncsnpp_continuous",
memory: int = 64,
duration: int = 72,
):
Expand All @@ -67,7 +65,7 @@ def main(
full_cmd = (
queue_cmd(duration=duration, memory=memory)
+ ["--"]
+ train_cmd(sde, cpm_dataset, workdir, config, ctx.args)
+ train_cmd(sde, workdir, config, ctx.args)
)
print(" ".join(full_cmd).strip(), file=sys.stderr)
output = subprocess.run(full_cmd, capture_output=True)
Expand Down
6 changes: 3 additions & 3 deletions src/ml_downscaling_emulator/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Loading UKCP18 data into PyTorch"""
"""Loading UKCP Local data into PyTorch"""

import cftime
import numpy as np
Expand All @@ -13,7 +13,7 @@
)


class UKCP18Dataset(Dataset):
class UKCPLocalDataset(Dataset):
def __init__(self, ds, variables, target_variables, time_range):
self.ds = ds
self.variables = variables
Expand Down Expand Up @@ -96,7 +96,7 @@ def custom_collate(batch):
time_range = None
if include_time_inputs:
time_range = TIME_RANGE
xr_dataset = UKCP18Dataset(xr_data, variables, target_variables, time_range)
xr_dataset = UKCPLocalDataset(xr_data, variables, target_variables, time_range)
data_loader = DataLoader(
xr_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=custom_collate
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ml_collections
import torch

from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs as get_base_configs
from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs as get_base_configs


def get_default_configs():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_default_configs():

# data
config.data = data = ml_collections.ConfigDict()
data.dataset = 'UKCP18'
data.dataset = 'UKCP_Local'
data.dataset_name = 'bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season'
data.image_size = 64
data.random_flip = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training NCSN++ on precip data with sub-VP SDE."""
from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_12em_configs import get_default_configs
from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_12em_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training NCSN++ on precip data with sub-VP SDE."""
from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs
from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training conditional U-Net on precip data with sub-VP SDE."""
from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs
from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training NCSN++ on precip data with VE SDE."""
from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs
from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training UNet on XArray with VE SDE."""
from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs
from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs


def get_config():
Expand Down

0 comments on commit 9158ed6

Please sign in to comment.