
Merge pull request #223 from apax-hub/dev
Version 0.3.0 changes
M-R-Schaefer authored Jan 17, 2024
2 parents a27da8e + 2b51835 commit 7d9e872
Showing 12 changed files with 1,119 additions and 1,006 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/linting.yaml
@@ -13,7 +13,7 @@ jobs:
uses: psf/black@stable
with:
src: "./apax"
version: "22.10.0"
version: "22.12.0"

isort:
runs-on: ubuntu-latest
@@ -25,7 +25,7 @@ jobs:

- name: Install isort
run: |
pip install isort==5.10.1
pip install isort==5.12.0
- name: run isort
run: |
5 changes: 5 additions & 0 deletions apax/config/train_config.py
@@ -279,7 +279,11 @@ class Config(BaseModel, frozen=True, extra="forbid"):
----------
n_epochs: Number of training epochs.
patience: Number of epochs without improvement before training is terminated.
seed: Random seed.
n_models: Number of models to be trained at once.
n_jitted_steps: Number of train batches to be processed in a compiled loop.
Can yield significant speedups for small structures or small batch sizes.
data: :class: `Data` <config.DataConfig> configuration.
model: :class: `Model` <config.ModelConfig> configuration.
metrics: List of :class: `metric` <config.MetricsConfig> configurations.
@@ -294,6 +298,7 @@ class Config(BaseModel, frozen=True, extra="forbid"):
patience: Optional[PositiveInt] = None
seed: int = 1
n_models: int = 1
n_jitted_steps: int = 1

data: DataConfig
model: ModelConfig = ModelConfig()
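The new n_jitted_steps option trades per-step Python overhead for a longer compiled loop: each optimizer call processes n_jitted_steps regular batches. A toy calculation of the resulting number of compiled steps per epoch (made-up numbers, not apax API; it mirrors the steps_per_epoch change in the data pipeline below):

import math

n_data, batch_size, n_jitted_steps = 5000, 32, 10
steps_per_epoch = n_data // batch_size // n_jitted_steps
print(steps_per_epoch)   # 15 compiled steps instead of 156 eager ones
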
15 changes: 11 additions & 4 deletions apax/data/input_pipeline.py
@@ -174,6 +174,7 @@ def __init__(
"""
self.n_epoch = n_epoch
self.batch_size = None
self.n_jit_steps = 1
self.buffer_size = buffer_size

max_atoms, max_nbrs = find_largest_system(inputs)
@@ -187,6 +188,9 @@
def set_batch_size(self, batch_size: int):
self.batch_size = self.validate_batch_size(batch_size)

def batch_multiple_steps(self, n_steps: int):
self.n_jit_steps = n_steps

def _check_batch_size(self):
if self.batch_size is None:
raise ValueError("Dataset Batch Size has not been set yet")
@@ -208,7 +212,7 @@ def steps_per_epoch(self) -> int:
number of steps, and all batches have the same length. To do so, some training
data are dropped in each epoch.
"""
return self.n_data // self.batch_size
return self.n_data // self.batch_size // self.n_jit_steps

def init_input(self) -> Dict[str, np.ndarray]:
"""Returns first batch of inputs and labels to init the model."""
@@ -240,15 +244,18 @@ def shuffle_and_batch(self) -> Iterator[jax.Array]:
Iterator that returns inputs and labels of one batch in each step.
"""
self._check_batch_size()
shuffled_ds = (
ds = (
self.ds.shuffle(buffer_size=self.buffer_size)
.repeat(self.n_epoch)
.batch(batch_size=self.batch_size)
.map(PadToSpecificSize(self.max_atoms, self.max_nbrs))
)

shuffled_ds = prefetch_to_single_device(shuffled_ds.as_numpy_iterator(), 2)
return shuffled_ds
if self.n_jit_steps > 1:
ds = ds.batch(batch_size=self.n_jit_steps)

ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
return ds

def batch(self) -> Iterator[jax.Array]:
self._check_batch_size()
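When n_jit_steps > 1 the dataset is batched a second time, so every element drawn from the iterator carries an extra leading axis of length n_jit_steps that the scanned training step can consume. A standalone numpy illustration of the resulting shapes (toy data, not the apax pipeline):

import numpy as np

n_data, batch_size, n_jit_steps = 64, 8, 2
positions = np.zeros((n_data, 16, 3))                    # toy per-structure array
batches = positions.reshape(-1, batch_size, 16, 3)       # (8, 8, 16, 3): plain batches
super_batches = batches.reshape(-1, n_jit_steps, batch_size, 16, 3)   # (4, 2, 8, 16, 3)
assert len(super_batches) == n_data // batch_size // n_jit_steps      # == steps_per_epoch
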
5 changes: 5 additions & 0 deletions apax/md/ase_calc.py
@@ -112,6 +112,7 @@ def __init__(
self.model_config, self.params = restore_parameters(model_dir)
self.n_models = check_for_ensemble(self.params)
self.padding_factor = padding_factor
self.padded_length = 0

if self.model_config.model.calc_stress:
self.implemented_properties.append("stress")
@@ -148,6 +149,10 @@ def initialize(self, atoms):
self.step = get_step_fn(model, atoms, self.neigbor_from_jax)
self.neighbor_fn = neighbor_fn

if self.neigbor_from_jax:
positions = jnp.asarray(atoms.positions, dtype=jnp.float64)
self.neighbors = self.neighbor_fn.allocate(positions)

def set_neighbours_and_offsets(self, atoms, box):
idxs_i, idxs_j, offsets = neighbour_list("ijS", atoms, self.r_max)

2 changes: 1 addition & 1 deletion apax/md/nvt.py
@@ -271,7 +271,7 @@ def body_fn(i, state):
)
ckpt = {"state": state, "step": step}
checkpoints.save_checkpoint(
ckpt_dir=ckpt_dir,
ckpt_dir=ckpt_dir.resolve(),
target=ckpt,
step=step,
overwrite=True,
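Here, and in apax/train/checkpoints.py below, the checkpoint directory is now resolved to an absolute path before being handed to flax's checkpointing, presumably because the underlying checkpointer rejects relative paths. A minimal illustration of what .resolve() does (hypothetical paths, not apax code):

from pathlib import Path

ckpt_dir = Path("models/run1/best")     # hypothetical relative checkpoint directory
print(ckpt_dir)                         # models/run1/best
print(ckpt_dir.resolve())               # e.g. /home/user/project/models/run1/best
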
10 changes: 7 additions & 3 deletions apax/train/checkpoints.py
@@ -89,9 +89,9 @@ class CheckpointManager:
def __init__(self) -> None:
self.async_manager = checkpoints.AsyncManager()

def save_checkpoint(self, ckpt, epoch: int, path: str) -> None:
def save_checkpoint(self, ckpt, epoch: int, path: Path) -> None:
checkpoints.save_checkpoint(
ckpt_dir=path,
ckpt_dir=path.resolve(),
target=ckpt,
step=epoch,
overwrite=True,
@@ -147,7 +147,11 @@ def restore_single_parameters(model_dir: Path) -> Tuple[Config, FrozenDict]:
"""
model_dir = Path(model_dir)
model_config = parse_config(model_dir / "config.yaml")
model_config.data.directory = model_dir.parent.resolve().as_posix()

if model_config.data.experiment == "":
model_config.data.directory = model_dir.resolve().as_posix()
else:
model_config.data.directory = model_dir.parent.resolve().as_posix()

ckpt_dir = model_config.data.model_version_path
return model_config, load_params(ckpt_dir)
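The restore logic now distinguishes between runs created with an experiment name (the data directory is one level above the model directory) and runs without one (the model directory itself is the data directory). A sketch of the two cases with hypothetical paths:

from pathlib import Path

def data_directory(model_dir: Path, experiment: str) -> str:
    # Hypothetical helper mirroring the branch above; not part of apax.
    if experiment == "":
        return model_dir.resolve().as_posix()
    return model_dir.parent.resolve().as_posix()

print(data_directory(Path("project/run1"), "run1"))  # .../project
print(data_directory(Path("project"), ""))           # .../project
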
9 changes: 9 additions & 0 deletions apax/train/loss.py
@@ -52,13 +52,21 @@ def force_angle_exponential_weight(
return (1.0 - dotp) * jnp.exp(-F_0_norm) / divisor


def stress_tril(label, prediction, divisor=1.0):
idxs = jnp.tril_indices(3)
label_tril = label[:, idxs[0], idxs[1]]
prediction_tril = prediction[:, idxs[0], idxs[1]]
return (label_tril - prediction_tril) ** 2 / divisor


loss_functions = {
"molecules": weighted_squared_error,
"structures": weighted_squared_error,
"vibrations": weighted_squared_error,
"cosine_sim": force_angle_loss,
"cosine_sim_div_magnitude": force_angle_div_force_label,
"cosine_sim_exp_magnitude": force_angle_exponential_weight,
"tril": stress_tril,
}
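The new "tril" loss compares only the six independent (lower-triangular) components of the 3x3 stress tensors. A quick standalone check of the index selection used above (toy values):

import jax.numpy as jnp

stress = jnp.arange(9.0).reshape(1, 3, 3)   # one structure in the batch
idxs = jnp.tril_indices(3)
print(stress[:, idxs[0], idxs[1]])          # [[0. 3. 4. 6. 7. 8.]]
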


@@ -101,6 +109,7 @@ def determine_divisor(self, n_atoms: jnp.array) -> jnp.array:
n_atoms, "batch -> batch 1 1"
),
"stress_structures": einops.repeat(n_atoms**2, "batch -> batch 1 1"),
"stress_tril": einops.repeat(n_atoms**2, "batch -> batch 1 1"),
"stress_vibrations": einops.repeat(n_atoms, "batch -> batch 1 1"),
}
divisor = divisor_dict.get(divisor_id, jnp.array(1.0))
10 changes: 8 additions & 2 deletions apax/train/metrics.py
@@ -9,7 +9,13 @@
log = logging.getLogger(__name__)


class RootAverage(metrics.Average):
class Averagefp64(metrics.Average):
@classmethod
def empty(cls) -> metrics.Metric:
return cls(total=jnp.array(0, jnp.float64), count=jnp.array(0, jnp.int64))


class RootAverage(Averagefp64):
"""
Modifies the `compute` method of `metrics.Average` to obtain the root of the average.
Meant to be used with `mse_fn`.
@@ -59,7 +65,7 @@ def make_single_metric(key: str, reduction: str) -> metrics.Average:
if reduction == "rmse":
metric = RootAverage
else:
metric = metrics.Average
metric = Averagefp64

reduction_fn = reduction_fns[reduction]
reduction_fn = partial(reduction_fn, key=key)
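Averagefp64 overrides empty() so the running total and count are accumulated in 64-bit precision rather than the 32-bit defaults of clu's metrics.Average (the float32 default is an assumption based on this override). A toy illustration of the rounding problem the change avoids:

import numpy as np

total32 = np.float32(2.0**24)                 # large float32 running total
print(total32 + np.float32(1.0) == total32)   # True: the increment is rounded away
total64 = np.float64(2.0**24)
print(total64 + 1.0 == total64)               # False: float64 keeps the increment
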
1 change: 1 addition & 0 deletions apax/train/run.py
@@ -117,4 +117,5 @@ def run(user_config, log_level="error"):
patience=config.patience,
disable_pbar=config.progress_bar.disable_epoch_pbar,
is_ensemble=config.n_models > 1,
n_jitted_steps=config.n_jitted_steps,
)
48 changes: 32 additions & 16 deletions apax/train/trainer.py
@@ -1,32 +1,36 @@
import functools
import logging
import time
from functools import partial
from typing import Callable
from typing import Callable, Optional

import jax
import jax.numpy as jnp
import numpy as np
from clu import metrics
from tqdm import trange

from apax.data.input_pipeline import AtomisticDataset
from apax.train.checkpoints import CheckpointManager, load_state

log = logging.getLogger(__name__)


def fit(
state,
train_ds,
train_ds: AtomisticDataset,
loss_fn,
Metrics,
callbacks,
n_epochs,
Metrics: metrics.Collection,
callbacks: list,
n_epochs: int,
ckpt_dir,
ckpt_interval: int = 1,
val_ds=None,
val_ds: Optional[AtomisticDataset] = None,
sam_rho=0.0,
patience=None,
patience: Optional[int] = None,
disable_pbar: bool = False,
is_ensemble=False,
n_jitted_steps=1,
):
log.info("Beginning Training")
callbacks.on_train_begin()
@@ -38,13 +42,16 @@ def fit(
train_step, val_step = make_step_fns(
loss_fn, Metrics, model=state.apply_fn, sam_rho=sam_rho, is_ensemble=is_ensemble
)
if n_jitted_steps > 1:
train_step = jax.jit(functools.partial(jax.lax.scan, train_step))

state, start_epoch = load_state(state, latest_dir)
if start_epoch >= n_epochs:
raise ValueError(
f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})"
)

train_ds.batch_multiple_steps(n_jitted_steps)
train_steps_per_epoch = train_ds.steps_per_epoch()
batch_train_ds = train_ds.shuffle_and_batch()

@@ -68,12 +75,16 @@
for batch_idx in range(train_steps_per_epoch):
callbacks.on_train_batch_begin(batch=batch_idx)

inputs, labels = next(batch_train_ds)
batch_loss, train_batch_metrics, state = train_step(
state, inputs, labels, train_batch_metrics
batch = next(batch_train_ds)
(
(state, train_batch_metrics),
batch_loss,
) = train_step(
(state, train_batch_metrics),
batch,
)

epoch_loss["train_loss"] += batch_loss
epoch_loss["train_loss"] += jnp.mean(batch_loss)
callbacks.on_train_batch_end(batch=batch_idx)

epoch_loss["train_loss"] /= train_steps_per_epoch
@@ -88,10 +99,10 @@
epoch_loss.update({"val_loss": 0.0})
val_batch_metrics = Metrics.empty()
for batch_idx in range(val_steps_per_epoch):
inputs, labels = next(batch_val_ds)
batch = next(batch_val_ds)

batch_loss, val_batch_metrics = val_step(
state.params, inputs, labels, val_batch_metrics
state.params, batch, val_batch_metrics
)
epoch_loss["val_loss"] += batch_loss

@@ -213,17 +224,22 @@ def update_step(state, inputs, labels):
eval_fn = loss_calculator

@jax.jit
def train_step(state, inputs, labels, batch_metrics):
def train_step(carry, batch):
state, batch_metrics = carry
inputs, labels = batch
loss, predictions, state = update_fn(state, inputs, labels)

new_batch_metrics = Metrics.single_from_model_output(
label=labels, prediction=predictions
)
batch_metrics = batch_metrics.merge(new_batch_metrics)
return loss, batch_metrics, state

new_carry = (state, batch_metrics)
return new_carry, loss

@jax.jit
def val_step(params, inputs, labels, batch_metrics):
def val_step(params, batch, batch_metrics):
inputs, labels = batch
loss, predictions = eval_fn(params, inputs, labels)

new_batch_metrics = Metrics.single_from_model_output(