From fc4ed246a6b99915a423ede64f3dbeaffbd1fede Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Mon, 18 Mar 2024 15:00:57 +0000 Subject: [PATCH] add 1em to non-12em configs and add a low data config --- bin/local-test-train | 2 +- .../default_ukcp_local_pr_12em_configs.py | 2 +- ...y => default_ukcp_local_pr_1em_configs.py} | 0 ...> ukcp_local_pr_1em_cncsnpp_continuous.py} | 2 +- ...ukcp_local_pr_1em_cncsnpp_continuous_ld.py | 72 +++++++++++++++++++ ... => ukcp_local_pr_1em_cunet_continuous.py} | 2 +- .../vesde/ukcp_local_pr_cncsnpp_continuous.py | 2 +- .../vesde/ukcp_local_pr_cunet_continuous.py | 2 +- 8 files changed, 78 insertions(+), 6 deletions(-) rename src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/{default_ukcp_local_pr_configs.py => default_ukcp_local_pr_1em_configs.py} (100%) rename src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/{ukcp_local_pr_cncsnpp_continuous.py => ukcp_local_pr_1em_cncsnpp_continuous.py} (97%) create mode 100644 src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous_ld.py rename src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/{ukcp_local_pr_cunet_continuous.py => ukcp_local_pr_1em_cunet_continuous.py} (96%) diff --git a/bin/local-test-train b/bin/local-test-train index b4fe29730..2c2ac091d 100755 --- a/bin/local-test-train +++ b/bin/local-test-train @@ -2,7 +2,7 @@ set -euo pipefail -config_name="ukcp_local_pr_cunet_continuous" +config_name="ukcp_local_pr_1em_cunet_continuous" dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic" sde="subvpsde" diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_12em_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_12em_configs.py index 9d8e44f22..54a4ff3ef 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_12em_configs.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_12em_configs.py @@ -1,7 +1,7 @@ import ml_collections import torch -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs as get_base_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_1em_configs import get_default_configs as get_base_configs def get_default_configs(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_1em_configs.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_configs.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp_local_pr_1em_configs.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cncsnpp_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous.py index b8a095c9f..b180b88e2 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cncsnpp_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training NCSN++ on precip data with sub-VP SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_1em_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous_ld.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous_ld.py new file mode 100644 index 000000000..4700f229a --- /dev/null +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_1em_cncsnpp_continuous_ld.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Training NCSN++ on precip data with sub-VP SDE.""" +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_1em_configs import get_default_configs + + +def get_config(): + config = get_default_configs() + # training + training = config.training + training.sde = 'subvpsde' + training.continuous = True + training.reduce_mean = True + training.n_epochs = 300 + + # sampling + sampling = config.sampling + sampling.method = 'pc' + sampling.predictor = 'euler_maruyama' + sampling.corrector = 'none' + + # data + data = config.data + data.centered = True + data.dataset_name = 'bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic' + + # model + model = config.model + model.name = 'cncsnpp' + model.scale_by_sigma = False + model.ema_rate = 0.9999 + model.normalization = 'GroupNorm' + model.nonlinearity = 'swish' + model.nf = 128 + model.ch_mult = (1, 2, 2, 2) + model.num_res_blocks = 4 + model.attn_resolutions = (16,) + model.resamp_with_conv = True + model.conditional = True + model.fir = True + model.fir_kernel = [1, 3, 3, 1] + model.skip_rescale = True + model.resblock_type = 'biggan' + model.progressive = 'none' + model.progressive_input = 'residual' + model.progressive_combine = 'sum' + model.attention_type = 'ddpm' + model.embedding_type = 'positional' + model.init_scale = 0. + model.fourier_scale = 16 + model.conv_size = 3 + + # data + data = config.data + data.input_transform_key = "stan" + data.target_transform_key = "sqrturrecen" + + return config diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cunet_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_1em_cunet_continuous.py similarity index 96% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cunet_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_1em_cunet_continuous.py index dfbc38e59..caba144d1 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_cunet_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_1em_cunet_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training conditional U-Net on precip data with sub-VP SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_1em_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py index 6a0184f3e..5dea3c41c 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cncsnpp_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training NCSN++ on precip data with VE SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_1em_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cunet_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cunet_continuous.py index 0ad9c6b6c..a5b114539 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cunet_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp_local_pr_cunet_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training UNet on XArray with VE SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp_local_pr_1em_configs import get_default_configs def get_config():