From 893dc484b6493a7115e6e4930d6247a6802269e4 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Mon, 4 Mar 2024 13:13:57 +0000 Subject: [PATCH 1/2] rename configs to UKCP Local pr to make room for multivar and inform using CPM from UKCP Local (not 18) also updated queue train helper to not require dataset now have them set to good versions in config by default --- bin/local-test-train | 2 +- bin/queue-training | 8 +++----- ...m_configs.py => default_ukcp_local_pr_12em_configs.py} | 2 +- ...ukcp18_configs.py => default_ukcp_local_pr_configs.py} | 2 +- ...inuous.py => ukcp_local_pr_12em_cncsnpp_continuous.py} | 2 +- ..._continuous.py => ukcp_local_pr_cncsnpp_continuous.py} | 2 +- ...et_continuous.py => ukcp_local_pr_cunet_continuous.py} | 2 +- ..._continuous.py => ukcp_local_pr_cncsnpp_continuous.py} | 2 +- ...et_continuous.py => ukcp_local_pr_cunet_continuous.py} | 2 +- 9 files changed, 11 insertions(+), 13 deletions(-) rename src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/{default_ukcp18_12em_configs.py => default_ukcp_local_pr_12em_configs.py} (85%) rename src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/{default_ukcp18_configs.py => default_ukcp_local_pr_configs.py} (98%) rename src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/{ukcp18_12em_cncsnpp_continuous.py => ukcp_local_pr_12em_cncsnpp_continuous.py} (96%) rename src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/{ukcp18_cncsnpp_continuous.py => ukcp_local_pr_cncsnpp_continuous.py} (97%) rename src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/{ukcp18_cunet_continuous.py => ukcp_local_pr_cunet_continuous.py} (97%) rename src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/{ukcp18_cncsnpp_continuous.py => ukcp_local_pr_cncsnpp_continuous.py} (97%) rename src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/{ukcp18_cunet_continuous.py => ukcp_local_pr_cunet_continuous.py} (96%) diff --git a/bin/local-test-train b/bin/local-test-train index d67283c5e..0443c4769 100755 --- a/bin/local-test-train +++ b/bin/local-test-train @@ -2,7 +2,7 @@ set -euo pipefail -config_name="ukcp18_cunet_continuous" +config_name="ukcp_local_pr_cunet_continuous" dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic" sde="subvpsde" diff --git a/bin/queue-training b/bin/queue-training index 22f2fdab0..3fa5f75ca 100755 --- a/bin/queue-training +++ b/bin/queue-training @@ -10,7 +10,7 @@ import typer app = typer.Typer() -def train_cmd(sde, dataset, workdir, config, config_overrides=list): +def train_cmd(sde, workdir, config, config_overrides=list): train_basecmd = ["python", f"bin/main.py"] train_opts = { @@ -22,7 +22,6 @@ def train_cmd(sde, dataset, workdir, config, config_overrides=list): return ( train_basecmd + [arg for item in train_opts.items() for arg in item] - + [f"--config.data.dataset_name={dataset}"] + config_overrides ) @@ -51,9 +50,8 @@ def queue_cmd(duration, memory): def main( ctx: typer.Context, model_run_id: str, - cpm_dataset: str, sde: str, - config: str = "xarray_12em_cncsnpp_continuous", + config: str = "ukcp_local_pr_12em_cncsnpp_continuous", memory: int = 64, duration: int = 72, ): @@ -67,7 +65,7 @@ def main( full_cmd = ( queue_cmd(duration=duration, memory=memory) + ["--"] - + train_cmd(sde, cpm_dataset, workdir, config, ctx.args) + + train_cmd(sde, workdir, config, ctx.args) ) print(" ".join(full_cmd).strip(), file=sys.stderr) output = subprocess.run(full_cmd, capture_output=True) diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_12em_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_12em_configs.py similarity index 85% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_12em_configs.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_12em_configs.py index 5172f7bc1..9d8e44f22 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_12em_configs.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_12em_configs.py @@ -1,7 +1,7 @@ import ml_collections import torch -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs as get_base_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs as get_base_configs def get_default_configs(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_configs.py similarity index 98% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_configs.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_configs.py index bb6bb48d1..86cc24ae4 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_configs.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_configs.py @@ -40,7 +40,7 @@ def get_default_configs(): # data config.data = data = ml_collections.ConfigDict() - data.dataset = 'UKCP18' + data.dataset = 'UKCP_Local' data.dataset_name = 'bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season' data.image_size = 64 data.random_flip = False diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_12em_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py similarity index 96% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_12em_cncsnpp_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py index 8c75391b8..a498b5ce8 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_12em_cncsnpp_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training NCSN++ on precip data with sub-VP SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_12em_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_12em_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cncsnpp_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cncsnpp_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cncsnpp_continuous.py index d58d2227d..b8a095c9f 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cncsnpp_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cncsnpp_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training NCSN++ on precip data with sub-VP SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cunet_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cunet_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cunet_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cunet_continuous.py index 5059f077a..dfbc38e59 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cunet_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cunet_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training conditional U-Net on precip data with sub-VP SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cncsnpp_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py index 691000f2a..6a0184f3e 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cncsnpp_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training NCSN++ on precip data with VE SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cunet_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cunet_continuous.py similarity index 96% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cunet_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cunet_continuous.py index 0bb63f147..0ad9c6b6c 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cunet_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cunet_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training UNet on XArray with VE SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs def get_config(): From afa9f45830a667a33f684c62bff7c7be415e5c27 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Mon, 4 Mar 2024 13:17:58 +0000 Subject: [PATCH 2/2] rename torch Dataset to Local from 18 --- src/ml_downscaling_emulator/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ml_downscaling_emulator/data.py b/src/ml_downscaling_emulator/data.py index 3fe0f64ba..e4060a262 100644 --- a/src/ml_downscaling_emulator/data.py +++ b/src/ml_downscaling_emulator/data.py @@ -1,4 +1,4 @@ -"""Loading UKCP18 data into PyTorch""" +"""Loading UKCP Local data into PyTorch""" import cftime import numpy as np @@ -13,7 +13,7 @@ ) -class UKCP18Dataset(Dataset): +class UKCPLocalDataset(Dataset): def __init__(self, ds, variables, target_variables, time_range): self.ds = ds self.variables = variables @@ -96,7 +96,7 @@ def custom_collate(batch): time_range = None if include_time_inputs: time_range = TIME_RANGE - xr_dataset = UKCP18Dataset(xr_data, variables, target_variables, time_range) + xr_dataset = UKCPLocalDataset(xr_data, variables, target_variables, time_range) data_loader = DataLoader( xr_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=custom_collate )