diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml
index 27ee44dd..708b7426 100644
--- a/.github/workflows/linting.yaml
+++ b/.github/workflows/linting.yaml
@@ -13,7 +13,7 @@ jobs:
         uses: psf/black@stable
         with:
           src: "./apax"
-          version: "22.10.0"
+          version: "22.12.0"

   isort:
     runs-on: ubuntu-latest
@@ -25,7 +25,7 @@ jobs:

       - name: Install isort
         run: |
-          pip install isort==5.10.1
+          pip install isort==5.12.0

       - name: run isort
         run: |
diff --git a/apax/config/train_config.py b/apax/config/train_config.py
index f383af08..d32a58be 100644
--- a/apax/config/train_config.py
+++ b/apax/config/train_config.py
@@ -279,7 +279,11 @@ class Config(BaseModel, frozen=True, extra="forbid"):
     ----------
     n_epochs: Number of training epochs.
+    patience: Number of epochs without improvement before training gets terminated.
     seed: Random seed.
+    n_models: Number of models to be trained at once.
+    n_jitted_steps: Number of train batches to be processed in a compiled loop.
+        Can yield significant speedups for small structures or small batch sizes.
     data: :class: `Data` configuration.
     model: :class: `Model` configuration.
     metrics: List of :class: `metric` configurations.
@@ -294,6 +298,7 @@ class Config(BaseModel, frozen=True, extra="forbid"):
     patience: Optional[PositiveInt] = None
     seed: int = 1
     n_models: int = 1
+    n_jitted_steps: int = 1

     data: DataConfig
     model: ModelConfig = ModelConfig()
diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py
index 34a9d31f..3e53fa33 100644
--- a/apax/data/input_pipeline.py
+++ b/apax/data/input_pipeline.py
@@ -174,6 +174,7 @@ def __init__(
         """
         self.n_epoch = n_epoch
         self.batch_size = None
+        self.n_jit_steps = 1
         self.buffer_size = buffer_size

         max_atoms, max_nbrs = find_largest_system(inputs)
@@ -187,6 +188,9 @@ def __init__(
     def set_batch_size(self, batch_size: int):
         self.batch_size = self.validate_batch_size(batch_size)

+    def batch_multiple_steps(self, n_steps: int):
+        self.n_jit_steps = n_steps
+
     def _check_batch_size(self):
         if self.batch_size is None:
             raise ValueError("Dataset Batch Size has not been set yet")
@@ -208,7 +212,7 @@ def steps_per_epoch(self) -> int:
         number of steps, and all batches have the same length.
         To do so, some training data are dropped in each epoch.
         """
-        return self.n_data // self.batch_size
+        return self.n_data // self.batch_size // self.n_jit_steps

     def init_input(self) -> Dict[str, np.ndarray]:
         """Returns first batch of inputs and labels to init the model."""
@@ -240,15 +244,18 @@ def shuffle_and_batch(self) -> Iterator[jax.Array]:
             Iterator that returns inputs and labels of one batch in each step.
""" self._check_batch_size() - shuffled_ds = ( + ds = ( self.ds.shuffle(buffer_size=self.buffer_size) .repeat(self.n_epoch) .batch(batch_size=self.batch_size) .map(PadToSpecificSize(self.max_atoms, self.max_nbrs)) ) - shuffled_ds = prefetch_to_single_device(shuffled_ds.as_numpy_iterator(), 2) - return shuffled_ds + if self.n_jit_steps > 1: + ds = ds.batch(batch_size=self.n_jit_steps) + + ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) + return ds def batch(self) -> Iterator[jax.Array]: self._check_batch_size() diff --git a/apax/train/metrics.py b/apax/train/metrics.py index 151cfe68..e5067b0f 100644 --- a/apax/train/metrics.py +++ b/apax/train/metrics.py @@ -9,7 +9,13 @@ log = logging.getLogger(__name__) -class RootAverage(metrics.Average): +class Averagefp64(metrics.Average): + @classmethod + def empty(cls) -> metrics.Metric: + return cls(total=jnp.array(0, jnp.float64), count=jnp.array(0, jnp.int64)) + + +class RootAverage(Averagefp64): """ Modifies the `compute` method of `metrics.Average` to obtain the root of the average. Meant to be used with `mse_fn`. @@ -59,7 +65,7 @@ def make_single_metric(key: str, reduction: str) -> metrics.Average: if reduction == "rmse": metric = RootAverage else: - metric = metrics.Average + metric = Averagefp64 reduction_fn = reduction_fns[reduction] reduction_fn = partial(reduction_fn, key=key) diff --git a/apax/train/run.py b/apax/train/run.py index 36e4e15f..a4015f9a 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -117,4 +117,5 @@ def run(user_config, log_level="error"): patience=config.patience, disable_pbar=config.progress_bar.disable_epoch_pbar, is_ensemble=config.n_models > 1, + n_jitted_steps=config.n_jitted_steps, ) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 867ee50e..684203eb 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -1,13 +1,16 @@ +import functools import logging import time from functools import partial -from typing import Callable +from typing import Callable, Optional import jax import jax.numpy as jnp import numpy as np +from clu import metrics from tqdm import trange +from apax.data.input_pipeline import AtomisticDataset from apax.train.checkpoints import CheckpointManager, load_state log = logging.getLogger(__name__) @@ -15,18 +18,19 @@ def fit( state, - train_ds, + train_ds: AtomisticDataset, loss_fn, - Metrics, - callbacks, - n_epochs, + Metrics: metrics.Collection, + callbacks: list, + n_epochs: int, ckpt_dir, ckpt_interval: int = 1, - val_ds=None, + val_ds: Optional[AtomisticDataset] = None, sam_rho=0.0, - patience=None, + patience: Optional[int] = None, disable_pbar: bool = False, is_ensemble=False, + n_jitted_steps=1, ): log.info("Beginning Training") callbacks.on_train_begin() @@ -38,6 +42,8 @@ def fit( train_step, val_step = make_step_fns( loss_fn, Metrics, model=state.apply_fn, sam_rho=sam_rho, is_ensemble=is_ensemble ) + if n_jitted_steps > 1: + train_step = jax.jit(functools.partial(jax.lax.scan, train_step)) state, start_epoch = load_state(state, latest_dir) if start_epoch >= n_epochs: @@ -45,6 +51,7 @@ def fit( f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})" ) + train_ds.batch_multiple_steps(n_jitted_steps) train_steps_per_epoch = train_ds.steps_per_epoch() batch_train_ds = train_ds.shuffle_and_batch() @@ -68,12 +75,16 @@ def fit( for batch_idx in range(train_steps_per_epoch): callbacks.on_train_batch_begin(batch=batch_idx) - inputs, labels = next(batch_train_ds) - batch_loss, train_batch_metrics, state = train_step( - state, inputs, 
labels, train_batch_metrics + batch = next(batch_train_ds) + ( + (state, train_batch_metrics), + batch_loss, + ) = train_step( + (state, train_batch_metrics), + batch, ) - epoch_loss["train_loss"] += batch_loss + epoch_loss["train_loss"] += jnp.mean(batch_loss) callbacks.on_train_batch_end(batch=batch_idx) epoch_loss["train_loss"] /= train_steps_per_epoch @@ -88,10 +99,10 @@ def fit( epoch_loss.update({"val_loss": 0.0}) val_batch_metrics = Metrics.empty() for batch_idx in range(val_steps_per_epoch): - inputs, labels = next(batch_val_ds) + batch = next(batch_val_ds) batch_loss, val_batch_metrics = val_step( - state.params, inputs, labels, val_batch_metrics + state.params, batch, val_batch_metrics ) epoch_loss["val_loss"] += batch_loss @@ -213,17 +224,22 @@ def update_step(state, inputs, labels): eval_fn = loss_calculator @jax.jit - def train_step(state, inputs, labels, batch_metrics): + def train_step(carry, batch): + state, batch_metrics = carry + inputs, labels = batch loss, predictions, state = update_fn(state, inputs, labels) new_batch_metrics = Metrics.single_from_model_output( label=labels, prediction=predictions ) batch_metrics = batch_metrics.merge(new_batch_metrics) - return loss, batch_metrics, state + + new_carry = (state, batch_metrics) + return new_carry, loss @jax.jit - def val_step(params, inputs, labels, batch_metrics): + def val_step(params, batch, batch_metrics): + inputs, labels = batch loss, predictions = eval_fn(params, inputs, labels) new_batch_metrics = Metrics.single_from_model_output(
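
Note on the n_jitted_steps change above: the dataset stacks n_jitted_steps batches along an extra leading axis (the second ds.batch(...) call), steps_per_epoch is divided by the same factor, and train_step is rewritten to the (carry, batch) -> (carry, loss) signature that jax.lax.scan expects, so several optimisation steps run inside one jitted call and the logged loss takes jnp.mean over the per-step losses. The snippet below is a minimal, self-contained sketch of that pattern, not apax code: the linear toy model, the 0.1 learning rate, and the random data are made up for illustration, and the carry here is bare parameters rather than (state, batch_metrics).

# Hypothetical sketch of the jax.lax.scan training-step pattern (not apax code).
import functools

import jax
import jax.numpy as jnp


def train_step(params, batch):
    # One SGD step on one batch; (carry, x) -> (carry, y), the signature scan expects.
    inputs, labels = batch

    def loss_fn(p):
        pred = inputs @ p
        return jnp.mean((pred - labels) ** 2)

    loss, grad = jax.value_and_grad(loss_fn)(params)
    return params - 0.1 * grad, loss


n_jitted_steps = 4
# Equivalent to the trainer change: scan the step over the leading batch axis
# and jit the whole loop once.
scanned_step = jax.jit(functools.partial(jax.lax.scan, train_step))

# One "super batch", as produced by the extra ds.batch(batch_size=n_jit_steps):
# n_jitted_steps inner batches of 8 samples with 3 features each.
key = jax.random.PRNGKey(0)
inputs = jax.random.normal(key, (n_jitted_steps, 8, 3))
labels = inputs.sum(axis=-1)

params = jnp.zeros(3)
params, losses = scanned_step(params, (inputs, labels))
print(losses.shape)  # (4,): one loss per inner step; jnp.mean(losses) gives the logged batch loss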