diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py index e04711244..fb9c72037 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py @@ -44,7 +44,7 @@ def get_default_configs(): training.n_epochs = 20 training.snapshot_freq = 5 training.eval_freq = 5000 -# training.sde = "" + training.sde = "" # sampling config.sampling = sampling = ml_collections.ConfigDict() diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_debug.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_debug.py new file mode 100644 index 000000000..2a8153fa4 --- /dev/null +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_debug.py @@ -0,0 +1,42 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# Modifications copyright 2024 Henry Addison +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Debug config for training in a deterministic fashion.""" + +from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs + +def get_config(): + config = get_default_configs() + + # training + training = config.training + training.n_epochs = 2 + training.snapshot_freq = 5 + training.eval_freq = 100 + training.log_freq = 50 + training.batch_size = 2 + + # data + data = config.data + data.dataset_name = 'debug-sample' + data.time_inputs=True + + # model + model = config.model + model.name = 'cunet' + + return config diff --git a/tests/smoke-test-det b/tests/smoke-test-det new file mode 100755 index 000000000..99498b505 --- /dev/null +++ b/tests/smoke-test-det @@ -0,0 +1,20 @@ +#! /usr/bin/env bash + +set -euo pipefail + +config_name="ukcp_local_pr_debug" + +workdir="output/test/deterministic/${config_name}/test-run" +config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/${config_name}.py" + +loc_spec_channels=0 + +rm -rf ${workdir} +WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.model.loc_spec_channels=${loc_spec_channels} + +num_samples=2 +eval_batch_size=32 +checkpoint="epoch_${epoch}" + +rm -rf "${workdir}/samples/${checkpoint}/${dataset}" +python bin/predict.py ${workdir} --dataset ${dataset} --checkpoint ${checkpoint} --batch-size ${eval_batch_size} --num-samples ${num_samples} --ensemble-member 01