added sketch of multi-step JIT
M-R-Schaefer committed Jan 2, 2024
1 parent 1c3e2d5 commit 9f21feb
Showing 2 changed files with 45 additions and 20 deletions.
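
The idea behind the sketch: wrapping the training step in jax.lax.scan lets several optimizer updates run inside a single jitted call, amortizing Python dispatch and host-device round trips over n steps. A minimal, self-contained illustration of the pattern (the toy loss and all names here are made up, not from the repo):

import functools

import jax
import jax.numpy as jnp


def step(carry, batch):
    # One toy gradient-descent update; `carry` holds the parameters.
    params = carry
    loss_fn = lambda p, b: jnp.mean((p * b) ** 2)
    grads = jax.grad(loss_fn)(params, batch)
    params = params - 0.01 * grads
    return params, loss_fn(params, batch)


# Fold `step` over the leading axis of `batches`: n updates per jitted call.
multi_step = jax.jit(functools.partial(jax.lax.scan, step))

params = jnp.ones(3)
batches = jnp.ones((4, 3))  # 4 stacked batches, scanned one at a time
params, losses = multi_step(params, batches)
print(losses.shape)  # (4,): one loss per inner step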
apax/data/input_pipeline.py: 12 additions & 4 deletions
@@ -174,6 +174,7 @@ def __init__(
         """
         self.n_epoch = n_epoch
         self.batch_size = None
+        self.n_jit_steps = 1
         self.buffer_size = buffer_size
 
         max_atoms, max_nbrs = find_largest_system(inputs)
@@ -186,6 +187,9 @@ def __init__(
 
     def set_batch_size(self, batch_size: int):
         self.batch_size = self.validate_batch_size(batch_size)
 
+    def batch_multiple_steps(self, n_steps: int):
+        self.n_jit_steps = n_steps
+
     def _check_batch_size(self):
         if self.batch_size is None:
@@ -208,7 +212,7 @@ def steps_per_epoch(self) -> int:
         number of steps, and all batches have the same length. To do so, some training
         data are dropped in each epoch.
         """
-        return self.n_data // self.batch_size
+        return self.n_data // self.batch_size // self.n_jit_steps
 
     def init_input(self) -> Dict[str, np.ndarray]:
         """Returns first batch of inputs and labels to init the model."""
@@ -240,15 +244,19 @@ def shuffle_and_batch(self) -> Iterator[jax.Array]:
         Iterator that returns inputs and labels of one batch in each step.
         """
         self._check_batch_size()
-        shuffled_ds = (
+        ds = (
             self.ds.shuffle(buffer_size=self.buffer_size)
             .repeat(self.n_epoch)
             .batch(batch_size=self.batch_size)
             .map(PadToSpecificSize(self.max_atoms, self.max_nbrs))
         )
 
-        shuffled_ds = prefetch_to_single_device(shuffled_ds.as_numpy_iterator(), 2)
-        return shuffled_ds
+        if self.n_jit_steps > 1:
+            print("JIT BATCH")
+            ds = ds.batch(batch_size=self.n_jit_steps)
+
+        ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
+        return ds
 
     def batch(self) -> Iterator[jax.Array]:
         self._check_batch_size()
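
The second .batch call is what creates the extra leading axis that the scanned training step consumes; it is also why steps_per_epoch now divides by n_jit_steps, since each iterator element carries n_jit_steps optimizer steps. A small sketch of the shape effect, assuming the pipeline is tf.data (as the as_numpy_iterator call suggests) and using toy shapes:

import numpy as np
import tensorflow as tf

# 24 toy samples with 5 features each.
ds = tf.data.Dataset.from_tensor_slices(np.zeros((24, 5), np.float32))
ds = ds.batch(4)  # per-step batches: elements of shape (4, 5)
ds = ds.batch(3)  # n_jit_steps axis on top: elements of shape (3, 4, 5)

elem = next(ds.as_numpy_iterator())
print(elem.shape)  # (3, 4, 5) = [n_jit_steps, batch_size, n_features]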
apax/train/trainer.py: 33 additions & 16 deletions
@@ -1,12 +1,15 @@
+import functools
 import logging
 import time
 from functools import partial
-from typing import Callable
+from typing import Callable, Optional
 
 import jax
+import jax.numpy as jnp
 from clu import metrics
 import numpy as np
 from tqdm import trange
+from apax.data.input_pipeline import AtomisticDataset
 
 from apax.train.checkpoints import CheckpointManager, load_state
 
@@ -15,16 +18,16 @@
 
 def fit(
     state,
-    train_ds,
+    train_ds: AtomisticDataset,
     loss_fn,
-    Metrics,
-    callbacks,
-    n_epochs,
+    Metrics: metrics.Collection,
+    callbacks: list,
+    n_epochs: int,
     ckpt_dir,
     ckpt_interval: int = 1,
-    val_ds=None,
+    val_ds: Optional[AtomisticDataset] = None,
     sam_rho=0.0,
-    patience=None,
+    patience: Optional[int] = None,
     disable_pbar: bool = False,
     is_ensemble=False,
 ):
@@ -35,18 +38,26 @@ def fit(
     best_dir = ckpt_dir / "best"
     ckpt_manager = CheckpointManager()
 
+    n_jitted_steps = 1
+
     train_step, val_step = make_step_fns(
         loss_fn, Metrics, model=state.apply_fn, sam_rho=sam_rho, is_ensemble=is_ensemble
     )
+    if n_jitted_steps > 1:
+        train_step = jax.jit(functools.partial(jax.lax.scan, train_step))
 
     state, start_epoch = load_state(state, latest_dir)
     if start_epoch >= n_epochs:
         raise ValueError(
             f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})"
         )
 
+    train_ds.batch_multiple_steps(n_jitted_steps)
     train_steps_per_epoch = train_ds.steps_per_epoch()
     batch_train_ds = train_ds.shuffle_and_batch()
+    # inputs, labels = next(batch_train_ds)
+    # print(jax.tree_map(lambda x: x.shape, inputs))
+    # quit()
 
     if val_ds is not None:
         val_steps_per_epoch = val_ds.steps_per_epoch()
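
Since jax.lax.scan has the signature scan(f, init, xs), the partial application above pins f = train_step, so the wrapped function takes (carry, stacked_batches) exactly as the single-step version takes (carry, batch). Spelled out, the wrapper is equivalent to this sketch (illustrative, not part of the commit):

@jax.jit
def multi_train_step(carry, batches):
    # Each leaf of `batches` has a leading n_jitted_steps axis;
    # scan slices off one step's batch at a time and threads `carry`.
    carry, losses = jax.lax.scan(train_step, carry, batches)
    return carry, losses  # losses stacks one loss per inner step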
@@ -68,12 +79,13 @@ def fit(
         for batch_idx in range(train_steps_per_epoch):
             callbacks.on_train_batch_begin(batch=batch_idx)
 
-            inputs, labels = next(batch_train_ds)
-            batch_loss, train_batch_metrics, state = train_step(
-                state, inputs, labels, train_batch_metrics
+            batch = next(batch_train_ds)
+            (state, train_batch_metrics), batch_loss = train_step(
+                (state, train_batch_metrics), batch
             )
+            print(batch_loss)
 
-            epoch_loss["train_loss"] += batch_loss
+            epoch_loss["train_loss"] += jnp.mean(batch_loss)
             callbacks.on_train_batch_end(batch=batch_idx)
 
         epoch_loss["train_loss"] /= train_steps_per_epoch
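
A note on the jnp.mean here: with the plain step, batch_loss is a scalar, while the scanned step returns an array of length n_jitted_steps. jnp.mean reduces both to a scalar, so the accumulation works unchanged, e.g.:

import jax.numpy as jnp

single = jnp.asarray(0.37)                 # plain train_step: shape ()
scanned = jnp.asarray([0.40, 0.36, 0.33])  # scanned, n_jitted_steps=3: shape (3,)

assert jnp.mean(single).shape == jnp.mean(scanned).shape == ()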
@@ -88,10 +100,10 @@ def fit(
             epoch_loss.update({"val_loss": 0.0})
             val_batch_metrics = Metrics.empty()
             for batch_idx in range(val_steps_per_epoch):
-                inputs, labels = next(batch_val_ds)
+                batch = next(batch_val_ds)
 
                 batch_loss, val_batch_metrics = val_step(
-                    state.params, inputs, labels, val_batch_metrics
+                    state.params, batch, val_batch_metrics
                 )
                 epoch_loss["val_loss"] += batch_loss
 
@@ -213,17 +225,22 @@ def update_step(state, inputs, labels):
         eval_fn = loss_calculator
 
     @jax.jit
-    def train_step(state, inputs, labels, batch_metrics):
+    def train_step(carry, batch):
+        state, batch_metrics = carry
+        inputs, labels = batch
         loss, predictions, state = update_fn(state, inputs, labels)
 
         new_batch_metrics = Metrics.single_from_model_output(
             label=labels, prediction=predictions
         )
         batch_metrics = batch_metrics.merge(new_batch_metrics)
-        return loss, batch_metrics, state
+
+        new_carry = (state, batch_metrics)
+        return new_carry, loss
 
     @jax.jit
-    def val_step(params, inputs, labels, batch_metrics):
+    def val_step(params, batch, batch_metrics):
+        inputs, labels = batch
         loss, predictions = eval_fn(params, inputs, labels)
 
         new_batch_metrics = Metrics.single_from_model_output(
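
With the new (carry, batch) -> (new_carry, loss) signature, train_step satisfies jax.lax.scan's contract directly, so the same function serves both call styles. A usage sketch (shapes illustrative):

# Single step: `batch` is one (inputs, labels) pair.
carry = (state, Metrics.empty())
carry, loss = train_step(carry, batch)

# Multiple steps: every leaf of `stacked_batches` gains a leading
# n_jitted_steps axis and scan threads the carry through them.
carry, losses = jax.lax.scan(train_step, carry, stacked_batches)
state, train_batch_metrics = carry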
