From 6ef6ab5f7512d3356c3de40f59954dba7e8e793b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 10 Nov 2023 21:51:59 +0100 Subject: [PATCH 01/10] set absl log level to warning --- apax/train/run.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/apax/train/run.py b/apax/train/run.py index 92922e43..27ff6aa0 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -1,6 +1,7 @@ import logging import os from pathlib import Path +import sys from typing import List import jax @@ -32,7 +33,18 @@ def setup_logging(log_file, log_level): while len(logging.root.handlers) > 0: logging.root.removeHandler(logging.root.handlers[-1]) - logging.basicConfig(filename=log_file, level=log_levels[log_level]) + # Remove uninformative checkpointing absl logs + logging.getLogger('absl').setLevel(logging.WARNING) + + logging.basicConfig( + level=log_levels[log_level], + format="%(levelname)s | %(asctime)s | %(message)s", + datefmt='%H:%M:%S', + handlers=[ + logging.FileHandler(log_file), + logging.StreamHandler(sys.stderr) + ] + ) def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection: @@ -44,25 +56,21 @@ def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection: def run(user_config, log_file="train.log", log_level="error"): - setup_logging(log_file, log_level) - log.info("Loading user config") config = parse_config(user_config) seed_py_np_tf(config.seed) rng_key = jax.random.PRNGKey(config.seed) - experiment = Path(config.data.experiment) - directory = Path(config.data.directory) - model_version_path = directory / experiment log.info("Initializing directories") - model_version_path.mkdir(parents=True, exist_ok=True) - config.dump_config(model_version_path) + config.data.model_version_path.mkdir(parents=True, exist_ok=True) + setup_logging(config.data.model_version_path / "train.log", log_level) + config.dump_config(config.data.model_version_path) - callbacks = 
initialize_callbacks(config.callbacks, model_version_path) + callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path) loss_fn = initialize_loss_fn(config.loss) Metrics = initialize_metrics(config.metrics) - train_raw_ds, val_raw_ds = load_data_files(config.data, model_version_path) + train_raw_ds, val_raw_ds = load_data_files(config.data, config.data.model_version_path) # remove path argument train_ds, ds_stats = initialize_dataset(config, train_raw_ds) val_ds = initialize_dataset(config, val_raw_ds, calc_stats=False) @@ -112,7 +120,7 @@ def run(user_config, log_file="train.log", log_level="error"): Metrics, callbacks, n_epochs, - ckpt_dir=os.path.join(config.data.directory, config.data.experiment), + ckpt_dir=config.data.model_version_path, ckpt_interval=config.checkpoints.ckpt_interval, val_ds=val_ds, sam_rho=config.optimizer.sam_rho, From 6772dfb24534ea62f5a1ea880790775e95c57b97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 10 Nov 2023 21:52:38 +0100 Subject: [PATCH 02/10] removed logging from config parsing --- apax/config/common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apax/config/common.py b/apax/config/common.py index ce700480..843830bb 100644 --- a/apax/config/common.py +++ b/apax/config/common.py @@ -18,7 +18,6 @@ def parse_config(config: Union[str, os.PathLike, dict], mode: str = "train") -> config: Path to the config file or a dictionary containing the config. 
""" - log.info("Loading user config") if isinstance(config, (str, os.PathLike)): with open(config, "r") as stream: config = yaml.safe_load(stream) From fb256b0053640a29761e72fd3cc7943c0c4b7446 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 10 Nov 2023 21:53:51 +0100 Subject: [PATCH 03/10] set log level to info, removed log file from cli and run --- apax/cli/apax_app.py | 5 ++--- apax/train/run.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/apax/cli/apax_app.py b/apax/cli/apax_app.py index d24f2420..c6338fc3 100644 --- a/apax/cli/apax_app.py +++ b/apax/cli/apax_app.py @@ -34,15 +34,14 @@ def train( train_config_path: Path = typer.Argument( ..., help="Training configuration YAML file." ), - log_level: str = typer.Option("error", help="Sets the training logging level."), - log_file: str = typer.Option("train.log", help="Specifies the name of the log file"), + log_level: str = typer.Option("info", help="Sets the training logging level."), ): """ Starts the training of a model with parameters provided by a configuration file. 
""" from apax.train.run import run - run(train_config_path, log_file, log_level) + run(train_config_path, log_level) @app.command() diff --git a/apax/train/run.py b/apax/train/run.py index 27ff6aa0..ff7ef77e 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -55,7 +55,7 @@ def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection: return LossCollection(loss_funcs) -def run(user_config, log_file="train.log", log_level="error"): +def run(user_config, log_level="error"): config = parse_config(user_config) seed_py_np_tf(config.seed) From 0800164b9dd38fc5fd7a53685645311a16d291ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 10 Nov 2023 21:54:12 +0100 Subject: [PATCH 04/10] turned model version path into property --- apax/config/train_config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/apax/config/train_config.py b/apax/config/train_config.py index cd9b6eb8..f383af08 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -112,12 +112,14 @@ def validate_shift_scale_methods(self): return self + @property def model_version_path(self): version_path = Path(self.directory) / self.experiment return version_path + @property def best_model_path(self): - return self.model_version_path() / "best" + return self.model_version_path / "best" class ModelConfig(BaseModel, extra="forbid"): From 3689443694bfb0068dbb6b39712fd796f9e31b7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 10 Nov 2023 21:56:00 +0100 Subject: [PATCH 05/10] adapted model version path property change across project --- apax/md/nvt.py | 2 +- apax/train/trainer.py | 4 ++-- tests/conftest.py | 4 ++-- tests/integration_tests/bal/test_api.py | 2 +- tests/integration_tests/md/test_md.py | 10 +++++----- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/apax/md/nvt.py b/apax/md/nvt.py index d26e60a8..9cac9217 100644 --- a/apax/md/nvt.py +++ b/apax/md/nvt.py @@ 
-372,7 +372,7 @@ def md_setup(model_config: Config, md_config: MDConfig): disable_cell_list=True, ) - _, params = restore_parameters(model_config.data.model_version_path()) + _, params = restore_parameters(model_config.data.model_version_path) params = canonicalize_energy_model_parameters(params) energy_fn = create_energy_fn( model.apply, params, system.atomic_numbers, system.box, model_config.n_models diff --git a/apax/train/trainer.py b/apax/train/trainer.py index d7faa83e..867ee50e 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -31,8 +31,8 @@ def fit( log.info("Beginning Training") callbacks.on_train_begin() - latest_dir = ckpt_dir + "/latest" - best_dir = ckpt_dir + "/best" + latest_dir = ckpt_dir / "latest" + best_dir = ckpt_dir / "best" ckpt_manager = CheckpointManager() train_step, val_step = make_step_fns( diff --git a/tests/conftest.py b/tests/conftest.py index 72807814..7df83945 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -147,6 +147,6 @@ def load_and_dump_config(config_path, dump_path): model_config_dict["data"]["directory"] = dump_path.as_posix() model_config = Config.model_validate(model_config_dict) - os.makedirs(model_config.data.model_version_path(), exist_ok=True) - model_config.dump_config(model_config.data.model_version_path()) + os.makedirs(model_config.data.model_version_path, exist_ok=True) + model_config.dump_config(model_config.data.model_version_path) return model_config diff --git a/tests/integration_tests/bal/test_api.py b/tests/integration_tests/bal/test_api.py index 59a7ad2c..9878dac6 100644 --- a/tests/integration_tests/bal/test_api.py +++ b/tests/integration_tests/bal/test_api.py @@ -41,7 +41,7 @@ def test_kernel_selection(example_atoms, get_tmp_path, get_sample_input): bs = 5 selected_indices = kernel_selection( - model_config.data.model_version_path(), + model_config.data.model_version_path, train_atoms, pool_atoms, base_fm_options, diff --git a/tests/integration_tests/md/test_md.py 
b/tests/integration_tests/md/test_md.py index 08c1df10..3bc2577e 100644 --- a/tests/integration_tests/md/test_md.py +++ b/tests/integration_tests/md/test_md.py @@ -34,8 +34,8 @@ def test_run_md(get_tmp_path): md_config_dict["initial_structure"] = get_tmp_path.as_posix() + "/atoms.extxyz" model_config = Config.model_validate(model_config_dict) - os.makedirs(model_config.data.model_version_path()) - model_config.dump_config(model_config.data.model_version_path()) + os.makedirs(model_config.data.model_version_path) + model_config.dump_config(model_config.data.model_version_path) md_config = MDConfig.model_validate(md_config_dict) positions = jnp.array( @@ -106,8 +106,8 @@ def test_ase_calc(get_tmp_path): model_config_dict["data"]["directory"] = get_tmp_path.as_posix() model_config = Config.model_validate(model_config_dict) - os.makedirs(model_config.data.model_version_path(), exist_ok=True) - model_config.dump_config(model_config.data.model_version_path()) + os.makedirs(model_config.data.model_version_path, exist_ok=True) + model_config.dump_config(model_config.data.model_version_path) cell_size = 10.0 positions = np.array( @@ -157,7 +157,7 @@ def test_ase_calc(get_tmp_path): atoms = read(initial_structure_path.as_posix()) calc = ASECalculator( - [model_config.data.model_version_path(), model_config.data.model_version_path()] + [model_config.data.model_version_path, model_config.data.model_version_path] ) atoms.calc = calc From da6946e33f2a8807b7c03af99fb212809e0b533f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 10 Nov 2023 21:58:06 +0100 Subject: [PATCH 06/10] added safety conversion of model_version_path to Path --- apax/train/checkpoints.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/apax/train/checkpoints.py b/apax/train/checkpoints.py index f1864f2e..380f32a4 100644 --- a/apax/train/checkpoints.py +++ b/apax/train/checkpoints.py @@ -122,6 +122,7 @@ def stack_parameters(param_list: List[FrozenDict]) -> 
FrozenDict: def load_params(model_version_path: Path, best=True) -> FrozenDict: + model_version_path = Path(model_version_path) if best: model_version_path = model_version_path / "best" log.info(f"loading checkpoint from {model_version_path}") @@ -142,7 +143,7 @@ def restore_single_parameters(model_dir: Path) -> Tuple[Config, FrozenDict]: """Load the config and parameters of a single model """ model_config = parse_config(Path(model_dir) / "config.yaml") - ckpt_dir = model_config.data.model_version_path() + ckpt_dir = model_config.data.model_version_path return model_config, load_params(ckpt_dir) From e3b6b97811b666dd8b9b14a7162ed02db6abb5f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 10 Nov 2023 22:03:19 +0100 Subject: [PATCH 07/10] removed model version path argument from load data files --- apax/data/initialization.py | 4 ++-- apax/train/run.py | 13 ++++--------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/apax/data/initialization.py b/apax/data/initialization.py index 0904238d..f7d9e468 100644 --- a/apax/data/initialization.py +++ b/apax/data/initialization.py @@ -19,7 +19,7 @@ class RawDataset: additional_labels: Optional[dict] = None -def load_data_files(data_config, model_version_path): +def load_data_files(data_config): log.info("Running Input Pipeline") if data_config.data_path is not None: log.info(f"Read data file {data_config.data_path}") @@ -32,7 +32,7 @@ def load_data_files(data_config, model_version_path): train_label_dict, val_label_dict = split_label(label_dict, train_idxs, val_idxs) np.savez( - os.path.join(model_version_path, "train_val_idxs"), + data_config.model_version_path / "train_val_idxs", train_idxs=train_idxs, val_idxs=val_idxs, ) diff --git a/apax/train/run.py b/apax/train/run.py index ff7ef77e..eb231924 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -1,6 +1,4 @@ import logging -import os -from pathlib import Path import sys from typing import List @@ -34,16 +32,13 @@ 
def setup_logging(log_file, log_level): logging.root.removeHandler(logging.root.handlers[-1]) # Remove uninformative checkpointing absl logs - logging.getLogger('absl').setLevel(logging.WARNING) + logging.getLogger("absl").setLevel(logging.WARNING) logging.basicConfig( level=log_levels[log_level], format="%(levelname)s | %(asctime)s | %(message)s", - datefmt='%H:%M:%S', - handlers=[ - logging.FileHandler(log_file), - logging.StreamHandler(sys.stderr) - ] + datefmt="%H:%M:%S", + handlers=[logging.FileHandler(log_file), logging.StreamHandler(sys.stderr)], ) @@ -70,7 +65,7 @@ def run(user_config, log_level="error"): loss_fn = initialize_loss_fn(config.loss) Metrics = initialize_metrics(config.metrics) - train_raw_ds, val_raw_ds = load_data_files(config.data, config.data.model_version_path) # remove path argument + train_raw_ds, val_raw_ds = load_data_files(config.data) train_ds, ds_stats = initialize_dataset(config, train_raw_ds) val_ds = initialize_dataset(config, val_raw_ds, calc_stats=False) From 19a8794e7634b8d7c6b6b3a25e9cd066dd80a628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 10 Nov 2023 22:03:42 +0100 Subject: [PATCH 08/10] removed unused os import --- apax/data/initialization.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apax/data/initialization.py b/apax/data/initialization.py index f7d9e468..2f7ad964 100644 --- a/apax/data/initialization.py +++ b/apax/data/initialization.py @@ -1,6 +1,5 @@ import dataclasses import logging -import os from typing import Optional import numpy as np From db80629b59322d83a0e10f3cbd6e45e5b992f63a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 10 Nov 2023 22:06:46 +0100 Subject: [PATCH 09/10] updated tests to use best model path as a property --- tests/integration_tests/bal/test_api.py | 2 +- tests/integration_tests/md/test_md.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration_tests/bal/test_api.py 
b/tests/integration_tests/bal/test_api.py index 9878dac6..6595d0eb 100644 --- a/tests/integration_tests/bal/test_api.py +++ b/tests/integration_tests/bal/test_api.py @@ -23,7 +23,7 @@ def test_kernel_selection(example_atoms, get_tmp_path, get_sample_input): _, params = initialize_model(model_config, inputs) ckpt = {"model": {"params": params}, "epoch": 0} - best_dir = model_config.data.best_model_path() + best_dir = model_config.data.best_model_path checkpoints.save_checkpoint( ckpt_dir=best_dir, target=ckpt, diff --git a/tests/integration_tests/md/test_md.py b/tests/integration_tests/md/test_md.py index 3bc2577e..abca6214 100644 --- a/tests/integration_tests/md/test_md.py +++ b/tests/integration_tests/md/test_md.py @@ -147,7 +147,7 @@ def test_ase_calc(get_tmp_path): ) ckpt = {"model": {"params": params}, "epoch": 0} - best_dir = model_config.data.best_model_path() + best_dir = model_config.data.best_model_path checkpoints.save_checkpoint( ckpt_dir=best_dir, target=ckpt, From a2c92c94d9f5a0dcec83b5468e5c029ede980578 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 10 Nov 2023 22:10:02 +0100 Subject: [PATCH 10/10] use best model dir directly in tests --- tests/integration_tests/md/test_md.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/integration_tests/md/test_md.py b/tests/integration_tests/md/test_md.py index abca6214..6a0249d1 100644 --- a/tests/integration_tests/md/test_md.py +++ b/tests/integration_tests/md/test_md.py @@ -80,11 +80,8 @@ def test_run_md(get_tmp_path): ) ckpt = {"model": {"params": params}, "epoch": 0} - best_dir = os.path.join( - model_config.data.directory, model_config.data.experiment, "best" - ) checkpoints.save_checkpoint( - ckpt_dir=best_dir, + ckpt_dir=model_config.data.best_model_path, target=ckpt, step=0, overwrite=True, @@ -147,9 +144,8 @@ def test_ase_calc(get_tmp_path): ) ckpt = {"model": {"params": params}, "epoch": 0} - best_dir = 
model_config.data.best_model_path checkpoints.save_checkpoint( - ckpt_dir=best_dir, + ckpt_dir=model_config.data.best_model_path, target=ckpt, step=0, overwrite=True,