diff --git a/apax/train/run.py b/apax/train/run.py index 7ea9c949..bc1bdf68 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -90,10 +90,10 @@ def run(user_config, log_level="error"): seed_py_np_tf(config.seed) rng_key = jax.random.PRNGKey(config.seed) - log.info("Initializing directories") config.data.model_version_path.mkdir(parents=True, exist_ok=True) setup_logging(config.data.model_version_path / "train.log", log_level) config.dump_config(config.data.model_version_path) + log.info(f"Running on {jax.devices()}") callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path) loss_fn = initialize_loss_fn(config.loss)