Skip to content

Commit

Permalink
update deterministic part so runs again
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Mar 7, 2024
1 parent 6faabcd commit 0bc707f
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
6 changes: 3 additions & 3 deletions bin/deterministic/local-test-train
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ cpm_dataset="bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season"
gcm_dataset="bham_60km-4x_12em_psl-temp4th-vort4th_eqvt_random-season"
workdir="output/test/unet/test-run"

config_path="src/ml_downscaling_emulator/deterministic/configs/default.py"
config_path="src/ml_downscaling_emulator/deterministic/configs/ukcp_local_pr_unet.py"

map_features=1
train_batch_size=32
map_features=0
train_batch_size=4

rm -rf ${workdir}
WANDB_EXPERIMENT_NAME="test" python ${SCRIPT_DIR}/main.py --mode train --workdir ${workdir} --config ${config_path} --config.data.dataset_name=${cpm_dataset} --config.training.batch_size=${train_batch_size} --config.training.n_epochs=2 --config.data.time_inputs=True --config.model.name=debug
Expand Down
9 changes: 5 additions & 4 deletions src/ml_downscaling_emulator/bin/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ def sample(

if batch_size is not None:
config.eval.batch_size = batch_size
if input_transform_dataset is not None:
config.data.input_transform_dataset = input_transform_dataset
else:
config.data.input_transform_dataset = dataset
with config.unlocked():
if input_transform_dataset is not None:
config.data.input_transform_dataset = input_transform_dataset
else:
config.data.input_transform_dataset = dataset
if input_transform_key is not None:
config.data.input_transform_key = input_transform_key

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def get_config():
data.dataset_name = ""
data.input_transform_key = "stan"
data.target_transform_key = "sqrturrecen"
data.input_transform_dataset = None
data.time_inputs = False

config.model = model = ml_collections.ConfigDict()
Expand Down
2 changes: 1 addition & 1 deletion src/ml_downscaling_emulator/deterministic/run_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from ..training import log_epoch, track_run
from .utils import restore_checkpoint, save_checkpoint, create_model
from ..torch import get_dataloader
from ..data import get_dataloader

FLAGS = flags.FLAGS
EXPERIMENT_NAME = os.getenv("WANDB_EXPERIMENT_NAME")
Expand Down

0 comments on commit 0bc707f

Please sign in to comment.