Skip to content

Commit

Permalink
Merge pull request #208 from apax-hub/npt_validation
Browse files Browse the repository at this point in the history
NPT validation and MD checkpoints
  • Loading branch information
M-R-Schaefer authored Nov 29, 2023
2 parents 8a3c56b + 10b6061 commit ff7e626
Show file tree
Hide file tree
Showing 10 changed files with 243 additions and 186 deletions.
20 changes: 0 additions & 20 deletions Makefile

This file was deleted.

5 changes: 2 additions & 3 deletions apax/cli/apax_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,15 @@ def md(
..., help="Configuration YAML file that was used to train a model."
),
md_config_path: Path = typer.Argument(..., help="MD configuration YAML file."),
log_level: str = typer.Option("error", help="Sets the training logging level."),
log_file: str = typer.Option("md.log", help="Specifies the name of the log file"),
log_level: str = typer.Option("info", help="Sets the training logging level."),
):
"""
Starts performing a molecular dynamics simulation (currently only NHC thermostat)
with parameters provided by a configuration file.
"""
from apax.md import run_md

run_md(train_config_path, md_config_path, log_file, log_level)
run_md(train_config_path, md_config_path, log_level)


@app.command()
Expand Down
9 changes: 8 additions & 1 deletion apax/config/md_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
n_inner: Number of compiled simulation steps (i.e. number of iterations of the
`jax.lax.fori_loop` loop). Also determines atoms buffer size.
sampling_rate:
Trajectory dumping interval.
Interval between saving frames.
buffer_size:
Number of collected frames to be dumped at once.
dr_threshold: Skin of the neighborlist.
extra_capacity: JaxMD allocates a maximal number of neighbors.
This argument lets you add additional capacity to avoid recompilation.
Expand All @@ -56,6 +58,9 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
traj_name: Name of the trajectory file.
restart: Whether the simulation should restart from the latest configuration
in `traj_name`.
checkpoint_interval: Number of time steps between saving
full simulation state checkpoints. These will be loaded
with the `restart` option.
disable_pbar: Disables the MD progressbar.
"""

Expand All @@ -69,6 +74,7 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
duration: PositiveFloat
n_inner: PositiveInt = 100
sampling_rate: PositiveInt = 10
buffer_size: PositiveInt = 100
dr_threshold: PositiveFloat = 0.5
extra_capacity: PositiveInt = 0

Expand All @@ -77,6 +83,7 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
sim_dir: str = "."
traj_name: str = "md.h5"
restart: bool = True
checkpoint_interval: int = 50_000
disable_pbar: bool = False

def dump_config(self):
Expand Down
71 changes: 62 additions & 9 deletions apax/md/io.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
import logging
from pathlib import Path

import h5py
import numpy as np
import znh5md
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from jax_md.space import transform

from apax.md.sim_utils import System

log = logging.getLogger(__name__)


class TrajHandler:
def __init__(self) -> None:
self.system: System
self.sampling_rate: int
self.buffer_size: int
self.traj_path: Path
self.time_step: float

def step(self, state_and_energy, transform):
pass

Expand Down Expand Up @@ -41,37 +56,75 @@ def atoms_from_state(self, state, energy, nbr_kwargs):


class H5TrajHandler(TrajHandler):
def __init__(self, system, sampling_rate, traj_path) -> None:
def __init__(
self,
system: System,
sampling_rate: int,
buffer_size: int,
traj_path: Path,
time_step: float = 0.5,
) -> None:
self.atomic_numbers = system.atomic_numbers
self.box = system.box
self.fractional = np.any(self.box < 1e-6)
self.sampling_rate = sampling_rate
self.traj_path = traj_path
self.db = znh5md.io.DataWriter(self.traj_path)
self.db.initialize_database_groups()
if not self.traj_path.is_file():
log.info(f"Initializing new trajectory file at {self.traj_path}")
self.db.initialize_database_groups()
self.time_step = time_step

self.sampling_counter = 1
self.step_counter = 0
self.buffer = []
self.buffer_size = buffer_size

def reset_buffer(self):
self.buffer = []

def step(self, state, transform):
state, energy, nbr_kwargs = state

if self.sampling_counter < self.sampling_rate:
self.sampling_counter += 1
else:
if self.step_counter % self.sampling_rate == 0:
new_atoms = self.atoms_from_state(state, energy, nbr_kwargs)
self.buffer.append(new_atoms)
self.sampling_counter = 1
self.step_counter += 1

if len(self.buffer) >= self.buffer_size:
self.write()

def write(self, x=None, transform=None):
if len(self.buffer) > 0:
reader = znh5md.io.AtomsReader(
self.buffer,
step=1,
time=self.sampling_rate,
step=self.time_step,
time=self.time_step * self.step_counter,
frames_per_chunk=self.buffer_size,
)
self.db.add(reader)
self.reset_buffer()


class DSTruncator:
def __init__(self, length):
self.length = length
self.node_names = []

def __call__(self, name, node):
if isinstance(node, h5py.Dataset):
if len(node.shape) > 1 or name.endswith("energy/value"):
self.node_names.append(name)

def truncate(self, ds):
for name in self.node_names:
shape = tuple([None] + list(ds[name].shape[1:]))
truncated_data = ds[name][: self.length]
del ds[name]
ds.create_dataset(name, maxshape=shape, data=truncated_data, chunks=True)


def truncate_trajectory_to_checkpoint(traj_path, length):
truncator = DSTruncator(length=length)
with h5py.File(traj_path, "r+") as ds:
ds.visititems(truncator)
truncator.truncate(ds)
16 changes: 9 additions & 7 deletions apax/md/md_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import logging
from pathlib import Path

from flax.training import checkpoints
from jaxtyping import PyTree

log = logging.getLogger(__name__)


def load_md_state(sim_dir):
# TODO: not functional yet
def load_md_state(state: PyTree, ckpt_dir: Path) -> tuple[PyTree, int]:
try:
log.info("loading previous md state")
raw_restored = checkpoints.restore_checkpoint(sim_dir, target=None, step=None)
log.info(f"loading MD state from {ckpt_dir}")
target = {"state": state, "step": 0}
restored_ckpt = checkpoints.restore_checkpoint(ckpt_dir, target=target, step=None)
except FileNotFoundError:
print(f"No checkpoint found at {sim_dir}")
state = raw_restored["state"]
step = raw_restored["step"]
print(f"No checkpoint found at {ckpt_dir}")
state = restored_ckpt["state"]
step = restored_ckpt["step"]
return state, step
Loading

0 comments on commit ff7e626

Please sign in to comment.