Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Batch Pbar #256

Merged
merged 6 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion apax/cli/templates/train_config_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,3 @@ checkpoints:

progress_bar:
disable_epoch_pbar: false
disable_nl_pbar: false
4 changes: 2 additions & 2 deletions apax/config/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,11 @@ class TrainProgressbarConfig(BaseModel, extra="forbid"):
Parameters
----------
disable_epoch_pbar: Set to True to disable the epoch progress bar.
disable_nl_pbar: Set to True to disable the NL precomputation progress bar.
disable_batch_pbar: Set to True to disable the batch progress bar.
"""

disable_epoch_pbar: bool = False
disable_nl_pbar: bool = False
disable_batch_pbar: bool = True


class CheckpointConfig(BaseModel, extra="forbid"):
Expand Down
4 changes: 3 additions & 1 deletion apax/train/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ def run(user_config, log_level="error"):
seed_py_np_tf(config.seed)
rng_key = jax.random.PRNGKey(config.seed)

log.info("Initializing directories")
config.data.model_version_path.mkdir(parents=True, exist_ok=True)
setup_logging(config.data.model_version_path / "train.log", log_level)
config.dump_config(config.data.model_version_path)
log.info(f"Running on {jax.devices()}")

callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path)
loss_fn = initialize_loss_fn(config.loss)
Expand Down Expand Up @@ -148,5 +148,7 @@ def run(user_config, log_level="error"):
sam_rho=config.optimizer.sam_rho,
patience=config.patience,
disable_pbar=config.progress_bar.disable_epoch_pbar,
disable_batch_pbar=config.progress_bar.disable_batch_pbar,
is_ensemble=config.n_models > 1,
)
log.info("Finished training")
23 changes: 23 additions & 0 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def fit(
sam_rho=0.0,
patience: Optional[int] = None,
disable_pbar: bool = False,
disable_batch_pbar: bool = True,
is_ensemble=False,
):
log.info("Beginning Training")
Expand Down Expand Up @@ -70,6 +71,16 @@ def fit(
epoch_loss.update({"train_loss": 0.0})
train_batch_metrics = Metrics.empty()

batch_pbar = trange(
0,
train_steps_per_epoch,
desc="Batches",
ncols=100,
mininterval=1.0,
disable=disable_batch_pbar,
leave=False,
)

for batch_idx in range(train_steps_per_epoch):
callbacks.on_train_batch_begin(batch=batch_idx)

Expand All @@ -84,6 +95,7 @@ def fit(

epoch_loss["train_loss"] += jnp.mean(batch_loss)
callbacks.on_train_batch_end(batch=batch_idx)
batch_pbar.update()

epoch_loss["train_loss"] /= train_steps_per_epoch
epoch_loss["train_loss"] = float(epoch_loss["train_loss"])
Expand All @@ -96,13 +108,24 @@ def fit(
if val_ds is not None:
epoch_loss.update({"val_loss": 0.0})
val_batch_metrics = Metrics.empty()

batch_pbar = trange(
0,
val_steps_per_epoch,
desc="Batches",
ncols=100,
mininterval=1.0,
disable=disable_batch_pbar,
leave=False,
)
for batch_idx in range(val_steps_per_epoch):
batch = next(batch_val_ds)

batch_loss, val_batch_metrics = val_step(
state.params, batch, val_batch_metrics
)
epoch_loss["val_loss"] += batch_loss
batch_pbar.update()

epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])
Expand Down
1 change: 0 additions & 1 deletion tests/regression_tests/apax_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,3 @@ checkpoints:

progress_bar:
disable_epoch_pbar: true
disable_nl_pbar: true
Loading