diff --git a/bin/local-test-train b/bin/local-test-train index 16e7824cb..b5a16816f 100755 --- a/bin/local-test-train +++ b/bin/local-test-train @@ -2,7 +2,7 @@ set -euo pipefail -config_name="xarray_cunet_continuous" +config_name="ukcp18_cunet_continuous" dataset="bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season" sde="subvpsde" diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_12em_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_12em_configs.py similarity index 88% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_12em_configs.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_12em_configs.py index fe5b0184a..abd461ff4 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_12em_configs.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_12em_configs.py @@ -1,7 +1,7 @@ import ml_collections import torch -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_xarray_configs import get_default_configs as get_base_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs as get_base_configs def get_default_configs(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_configs.py similarity index 100% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_xarray_configs.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/default_ukcp18_configs.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_12em_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_12em_cncsnpp_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_12em_cncsnpp_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_12em_cncsnpp_continuous.py index 3b11f75cb..8c75391b8 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_12em_cncsnpp_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_12em_cncsnpp_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training NCSN++ on precip data with sub-VP SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_xarray_12em_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_12em_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cncsnpp_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_cncsnpp_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cncsnpp_continuous.py index c3a2b22b7..d58d2227d 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_cncsnpp_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cncsnpp_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training NCSN++ on precip data with sub-VP SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_xarray_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_cunet_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cunet_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_cunet_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cunet_continuous.py index eda2dd25f..5059f077a 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_cunet_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp18_cunet_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training conditional U-Net on precip data with sub-VP SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_xarray_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_ncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_ncsnpp_continuous.py deleted file mode 100644 index 5a539f062..000000000 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/xarray_ncsnpp_continuous.py +++ /dev/null @@ -1,65 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The Google Research Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Lint as: python3 -"""Training NCSN++ on precip data with sub-VP SDE.""" -from configs.default_xarray_configs import get_default_configs - - -def get_config(): - config = get_default_configs() - # training - training = config.training - training.sde = 'subvpsde' - training.continuous = True - training.reduce_mean = True - - # sampling - sampling = config.sampling - sampling.method = 'pc' - sampling.predictor = 'euler_maruyama' - sampling.corrector = 'none' - - # data - data = config.data - data.centered = True - - # model - model = config.model - model.name = 'ncsnpp' - model.scale_by_sigma = False - model.ema_rate = 0.9999 - model.normalization = 'GroupNorm' - model.nonlinearity = 'swish' - model.nf = 128 - model.ch_mult = (1, 2, 2, 2) - model.num_res_blocks = 4 - model.attn_resolutions = (16,) - model.resamp_with_conv = True - model.conditional = True - model.fir = True - model.fir_kernel = [1, 3, 3, 1] - model.skip_rescale = True - model.resblock_type = 'biggan' - model.progressive = 'none' - model.progressive_input = 'residual' - model.progressive_combine = 'sum' - model.attention_type = 'ddpm' - model.embedding_type = 'positional' - model.init_scale = 0. - model.fourier_scale = 16 - model.conv_size = 3 - - return config diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cncsnpp_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_cncsnpp_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cncsnpp_continuous.py index e21422a76..691000f2a 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_cncsnpp_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cncsnpp_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training NCSN++ on precip data with VE SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_xarray_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_cunet_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cunet_continuous.py similarity index 97% rename from src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_cunet_continuous.py rename to src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cunet_continuous.py index d29614cac..0bb63f147 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_cunet_continuous.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/ukcp18_cunet_continuous.py @@ -15,7 +15,7 @@ # Lint as: python3 """Training UNet on XArray with VE SDE.""" -from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_xarray_configs import get_default_configs +from ml_downscaling_emulator.score_sde_pytorch_hja22.configs.default_ukcp18_configs import get_default_configs def get_config(): diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_ncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_ncsnpp_continuous.py deleted file mode 100644 index 78e0fd9e7..000000000 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_ncsnpp_continuous.py +++ /dev/null @@ -1,59 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The Google Research Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Lint as: python3 -"""Training NCSN++ on precip data with VE SDE.""" -from configs.default_xarray_configs import get_default_configs - - -def get_config(): - config = get_default_configs() - # training - training = config.training - training.sde = 'vesde' - training.continuous = True - - # sampling - sampling = config.sampling - sampling.method = 'pc' - sampling.predictor = 'reverse_diffusion' - sampling.corrector = 'langevin' - - # model - model = config.model - model.name = 'ncsnpp' - model.scale_by_sigma = True - model.ema_rate = 0.999 - model.normalization = 'GroupNorm' - model.nonlinearity = 'swish' - model.nf = 128 - model.ch_mult = (1, 2, 2, 2) - model.num_res_blocks = 4 - model.attn_resolutions = (16,) - model.resamp_with_conv = True - model.conditional = True - model.fir = True - model.fir_kernel = [1, 3, 3, 1] - model.skip_rescale = True - model.resblock_type = 'biggan' - model.progressive = 'none' - model.progressive_input = 'residual' - model.progressive_combine = 'sum' - model.attention_type = 'ddpm' - model.init_scale = 0. - model.fourier_scale = 16 - model.conv_size = 3 - - return config diff --git a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_unet_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_unet_continuous.py deleted file mode 100644 index 468d84aa7..000000000 --- a/src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/vesde/xarray_unet_continuous.py +++ /dev/null @@ -1,63 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The Google Research Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Lint as: python3 -"""Training UNet on XArray with VE SDE.""" -from configs.default_xarray_configs import get_default_configs - - -def get_config(): - config = get_default_configs() - # training - training = config.training - training.sde = 'vesde' - training.continuous = True - - # sampling - sampling = config.sampling - sampling.method = 'pc' - sampling.predictor = 'reverse_diffusion' - sampling.corrector = 'langevin' - - # data - data = config.data - data.image_size = 28 # u-net architechture currently designed to work with 28x28 images - - # model - model = config.model - model.name = 'unet' - model.scale_by_sigma = True - model.ema_rate = 0.999 - model.normalization = 'GroupNorm' - model.nonlinearity = 'swish' - model.nf = 128 - model.ch_mult = (1, 2, 2, 2) - model.num_res_blocks = 4 - model.attn_resolutions = (16,) - model.resamp_with_conv = True - model.conditional = True - model.fir = True - model.fir_kernel = [1, 3, 3, 1] - model.skip_rescale = True - model.resblock_type = 'biggan' - model.progressive = 'none' - model.progressive_input = 'residual' - model.progressive_combine = 'sum' - model.attention_type = 'ddpm' - model.init_scale = 0. - model.fourier_scale = 16 - model.conv_size = 3 - - return config