Skip to content

Commit

Permalink
add a script to update references to dataset names in workdir configs
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Jun 25, 2024
1 parent 9ffe893 commit ded681b
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions bin/rename-dataset-in-configs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/usr/bin/env python
# rename the dataset name in the config files of trained models

import glob
import logging
from ml_collections import config_dict
import os
import yaml

# Verbosity is taken from the LOG_LEVEL environment variable (default INFO);
# the value is upper-cased so that e.g. "debug" is accepted as well.
_log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    format="%(levelname)s - %(filename)s - %(asctime)s - %(message)s",
    level=_log_level,
)
# The script logs through the root logger configured above.
logger = logging.getLogger()


def load_config(config_path: str) -> config_dict.ConfigDict:
    """Read a YAML config file and wrap its contents in a ``ConfigDict``.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        A ``ConfigDict`` built from the parsed YAML document.

    Raises:
        OSError: If ``config_path`` cannot be opened for reading.
    """
    logger.info(f"Loading config from {config_path}")
    # YAML is UTF-8 by specification; do not depend on the locale's default
    # text encoding when reading it.
    with open(config_path, encoding="utf-8") as f:
        # SECURITY: yaml.unsafe_load can construct arbitrary Python objects
        # from tagged YAML, so only point this script at config files you
        # trust. It is kept here because the files are presumably written via
        # ConfigDict.to_yaml() and may carry Python object tags that the safe
        # loader rejects -- TODO confirm before switching to yaml.safe_load.
        config = config_dict.ConfigDict(yaml.unsafe_load(f))
    return config


def save_config(config: config_dict.ConfigDict, config_path: str) -> None:
    """Serialise ``config`` to YAML and write it to ``config_path``.

    Any existing file at ``config_path`` is overwritten.

    Args:
        config: Config to serialise with its ``to_yaml()`` method.
        config_path: Destination path of the YAML file.

    Raises:
        OSError: If ``config_path`` cannot be opened for writing.
    """
    logger.info(f"Saving config to {config_path}")
    # Serialise BEFORE opening the file: mode "w" truncates immediately, so a
    # failure inside to_yaml() would otherwise leave an empty config behind.
    yaml_text = config.to_yaml()
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(yaml_text)


def update_dataset_name(dataset_name: str) -> str:
    """Translate an old-style dataset name into the new naming scheme.

    Every occurrence of each old fragment is substituted; the substitutions
    are applied one after another in the order listed below.

    Args:
        dataset_name: Dataset name as currently stored in a config.

    Returns:
        The dataset name with all substitutions applied.
    """
    substitutions = (
        ("_eqvt", "_pr"),
        ("_random-season", ""),
        ("bham_gcmx", "bham64_ccpm"),
        ("bham_60km", "bham64_gcm"),
    )
    renamed = dataset_name
    for old_fragment, new_fragment in substitutions:
        renamed = renamed.replace(old_fragment, new_fragment)
    return renamed


def update_config(config: config_dict.ConfigDict) -> config_dict.ConfigDict:
    """Rewrite the dataset names referenced by ``config.data`` in place.

    ``data.dataset_name`` is always renamed. ``data.input_transform_dataset``
    is renamed only when it is present and not ``None``. Each rename is
    logged, even when the name comes out unchanged.

    Args:
        config: Config whose ``data`` section is updated.

    Returns:
        The same ``config`` object that was passed in, after mutation.
    """
    with config.unlocked():
        data = config.data

        dataset_name = data.dataset_name
        new_dataset_name = update_dataset_name(dataset_name)
        logger.info(f"Changing {dataset_name} to {new_dataset_name}")
        data.dataset_name = new_dataset_name

        # Use .get() rather than attribute access so that a config without
        # this optional key is skipped instead of raising AttributeError.
        transform_dataset = data.get("input_transform_dataset")
        if transform_dataset is not None:
            new_transform_dataset = update_dataset_name(transform_dataset)
            logger.info(f"Changing {transform_dataset} to {new_transform_dataset}")
            data.input_transform_dataset = new_transform_dataset

    return config


# Root of the derived data tree; fail early with a clear message instead of
# letting os.path.join raise a TypeError on None when the variable is unset.
derived_data = os.getenv("DERIVED_DATA")
if derived_data is None:
    raise SystemExit("DERIVED_DATA environment variable must be set")

# Glob patterns, relative to $DERIVED_DATA/workdirs, of the model workdirs
# whose configs should be updated.
workdir_globs = ["score-sde/subvpsde/*/*", "u-net/*", "u-net/ukcp_local_pr_unet"]

# "u-net/*" already matches "u-net/ukcp_local_pr_unet", so de-duplicate the
# matches (dict.fromkeys keeps first-seen order) to visit each workdir once.
workdirs = list(
    dict.fromkeys(
        path
        for g in workdir_globs
        for path in glob.glob(os.path.join(derived_data, "workdirs", g))
    )
)

for workdir in workdirs:
    config_path = os.path.join(workdir, "config.yml")
    config = load_config(config_path)
    config = update_config(config)
    # NOTE: saving is disabled, so the script currently only logs the renames
    # it would make (a dry run). Uncomment the next line to write the updated
    # configs back to disk.
    # save_config(config, config_path)

0 comments on commit ded681b

Please sign in to comment.