diff --git a/bin/local-test-train b/bin/local-test-train index d67283c5e..0443c4769 100755 --- a/bin/local-test-train +++ b/bin/local-test-train @@ -2,7 +2,7 @@ set -euo pipefail -config_name="ukcp18_cunet_continuous" +config_name="ukcp_local_pr_cunet_continuous" dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic" sde="subvpsde" diff --git a/bin/queue-training b/bin/queue-training index 22f2fdab0..3fa5f75ca 100755 --- a/bin/queue-training +++ b/bin/queue-training @@ -10,7 +10,7 @@ import typer app = typer.Typer() -def train_cmd(sde, dataset, workdir, config, config_overrides=list): +def train_cmd(sde, workdir, config, config_overrides=list): train_basecmd = ["python", f"bin/main.py"] train_opts = { @@ -22,7 +22,6 @@ def train_cmd(sde, dataset, workdir, config, config_overrides=list): return ( train_basecmd + [arg for item in train_opts.items() for arg in item] - + [f"--config.data.dataset_name={dataset}"] + config_overrides ) @@ -51,9 +50,8 @@ def queue_cmd(duration, memory): def main( ctx: typer.Context, model_run_id: str, - cpm_dataset: str, sde: str, - config: str = "xarray_12em_cncsnpp_continuous", + config: str = "ukcp_local_pr_12em_cncsnpp_continuous", memory: int = 64, duration: int = 72, ): @@ -67,7 +65,7 @@ def main( full_cmd = ( queue_cmd(duration=duration, memory=memory) + ["--"] - + train_cmd(sde, cpm_dataset, workdir, config, ctx.args) + + train_cmd(sde, workdir, config, ctx.args) ) print(" ".join(full_cmd).strip(), file=sys.stderr) output = subprocess.run(full_cmd, capture_output=True) diff --git a/src/ml_downscaling_emulator/data.py b/src/ml_downscaling_emulator/data.py index 3fe0f64ba..e4060a262 100644 --- a/src/ml_downscaling_emulator/data.py +++ b/src/ml_downscaling_emulator/data.py @@ -1,4 +1,4 @@ -"""Loading UKCP18 data into PyTorch""" +"""Loading UKCP Local data into PyTorch""" import cftime import numpy as np @@ -13,7 +13,7 @@ ) -class UKCP18Dataset(Dataset): +class UKCPLocalDataset(Dataset): def __init__(self, ds, variables, target_variables, time_range): self.ds = ds self.variables = variables @@ -96,7 +96,7 @@ def custom_collate(batch): time_range = None if include_time_inputs: time_range = TIME_RANGE - xr_dataset = UKCP18Dataset(xr_data, variables, target_variables, time_range) + xr_dataset = UKCPLocalDataset(xr_data, variables, target_variables, time_range) data_loader = DataLoader( xr_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=custom_collate ) diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_12em_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_12em_configs.py similarity index 85% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_12em_configs.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_12em_configs.py index 5172f7bc1..9d8e44f22 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_12em_configs.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_12em_configs.py @@ -1,7 +1,7 @@ import ml_collections import torch -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs as get_base_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs as get_base_configs def get_default_configs(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_configs.py similarity index 98% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_configs.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_configs.py index bb6bb48d1..86cc24ae4 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_configs.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_configs.py @@ -40,7 +40,7 @@ def get_default_configs(): # data config.data = data = ml_collections.ConfigDict() - data.dataset = 'UKCP18' + data.dataset = 'UKCP_Local' data.dataset_name = 'bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season' data.image_size = 64 data.random_flip = False diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_12em_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py similarity index 96% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_12em_cncsnpp_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py index 8c75391b8..a498b5ce8 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_12em_cncsnpp_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training NCSN++ on precip data with sub-VP SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_12em_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_12em_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cncsnpp_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cncsnpp_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cncsnpp_continuous.py index d58d2227d..b8a095c9f 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cncsnpp_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cncsnpp_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training NCSN++ on precip data with sub-VP SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cunet_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cunet_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cunet_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cunet_continuous.py index 5059f077a..dfbc38e59 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cunet_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cunet_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training conditional U-Net on precip data with sub-VP SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cncsnpp_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py index 691000f2a..6a0184f3e 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cncsnpp_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training NCSN++ on precip data with VE SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cunet_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cunet_continuous.py similarity index 96% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cunet_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cunet_continuous.py index 0bb63f147..0ad9c6b6c 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cunet_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cunet_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training UNet on XArray with VE SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs def get_config():