Skip to content

Commit

Permalink
add debug config for testing diff emulators in mv setting
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Jun 12, 2024
1 parent 2846023 commit 3ff0ecc
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training conditional U-Net on precip data with sub-VP SDE.
DEBUGGING ONLY"""
from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs


def get_config():
config = get_default_configs()
# training
training = config.training
training.sde = 'subvpsde'
training.continuous = True
training.reduce_mean = True

# sampling
sampling = config.sampling
sampling.method = 'pc'
sampling.predictor = 'euler_maruyama'
sampling.corrector = 'none'

# data
data = config.data
data.centered = True
data.dataset_name = 'debug-sample-mv'

# model
model = config.model
model.name = 'cunet'
model.ema_rate = 0.9999

return config
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_config():
# data
data = config.data
data.centered = True
data.dataset_name = 'debug-sample'

# model
model = config.model
Expand Down
6 changes: 3 additions & 3 deletions tests/smoke-test
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
set -euo pipefail

sde="subvpsde"
config_name="ukcp_local_mv_12em_cncsnpp_continuous" # should switch out for a debug config
dataset="bham_gcmx-4x_2em_mv" # should switch out for a debugging small dataset
config_name="ukcp_local_mv_debug"
dataset="debug-sample-mv"

workdir="output/test/${sde}/${config_name}/test-run"
config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/${sde}/${config_name}.py"
Expand All @@ -14,7 +14,7 @@ train_batch_size=2
epoch=2

rm -rf ${workdir}
WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True --config.model.name=cunet --config.model.num_scales=10
WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True --config.model.num_scales=10

num_samples=2
eval_batch_size=32
Expand Down

0 comments on commit 3ff0ecc

Please sign in to comment.