Skip to content

Commit

Permalink
do need sde in det configs
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Aug 3, 2024
1 parent 65cb649 commit a2fd23d
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_default_configs():
training.n_epochs = 20
training.snapshot_freq = 5
training.eval_freq = 5000
# training.sde = ""
training.sde = ""

# sampling
config.sampling = sampling = ml_collections.ConfigDict()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Modifications copyright 2024 Henry Addison
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Debug config for training in a deterministic fashion."""

from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs

def get_config():
config = get_default_configs()

# training
training = config.training
training.n_epochs = 2
training.snapshot_freq = 5
training.eval_freq = 100
training.log_freq = 50
training.batch_size = 2

# data
data = config.data
data.dataset_name = 'debug-sample'
data.time_inputs=True

# model
model = config.model
model.name = 'cunet'

return config
20 changes: 20 additions & 0 deletions tests/smoke-test-det
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#! /usr/bin/env bash

set -euo pipefail

config_name="ukcp_local_pr_debug"

workdir="output/test/deterministic/${config_name}/test-run"
config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/${config_name}.py"

loc_spec_channels=0

rm -rf ${workdir}
WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.model.loc_spec_channels=${loc_spec_channels}

num_samples=2
eval_batch_size=32
checkpoint="epoch_${epoch}"

rm -rf "${workdir}/samples/${checkpoint}/${dataset}"
python bin/predict.py ${workdir} --dataset ${dataset} --checkpoint ${checkpoint} --batch-size ${eval_batch_size} --num-samples ${num_samples} --ensemble-member 01

0 comments on commit a2fd23d

Please sign in to comment.