diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_mv_12em_cncsnpp_continuous.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_mv_12em_cncsnpp_continuous.py new file mode 100644 index 000000000..b6b48e0cb --- /dev/null +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_mv_12em_cncsnpp_continuous.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Training NCSN++ on precip data with sub-VP SDE.""" +from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_12em_configs import get_default_configs + + +def get_config(): + config = get_default_configs() + # training + training = config.training + training.sde = 'subvpsde' + training.continuous = True + training.reduce_mean = True + + # sampling + sampling = config.sampling + sampling.method = 'pc' + sampling.predictor = 'euler_maruyama' + sampling.corrector = 'none' + + # data + data = config.data + data.centered = True + data.dataset_name = 'bham_gcmx-4x_2em_mv' + + # model + model = config.model + model.name = 'cncsnpp' + model.scale_by_sigma = False + model.ema_rate = 0.9999 + model.normalization = 'GroupNorm' + model.nonlinearity = 'swish' + model.nf = 128 + model.ch_mult = (1, 2, 2, 2) + model.num_res_blocks = 4 + model.attn_resolutions = (16,) + model.resamp_with_conv = True + model.conditional = True + model.fir = True + model.fir_kernel = [1, 3, 3, 1] + model.skip_rescale = True + model.resblock_type = 'biggan' + model.progressive = 'none' + model.progressive_input = 'residual' + model.progressive_combine = 'sum' + model.attention_type = 'ddpm' + model.embedding_type = 'positional' + model.init_scale = 0. + model.fourier_scale = 16 + model.conv_size = 3 + + # data + data = config.data + + return config