Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NPT validation and MD checkpoints #208

Merged
merged 26 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
40a4961
made cell explicit in simulation loop
M-R-Schaefer Nov 13, 2023
71e9029
h5trajhanderl now records actual sim time
M-R-Schaefer Nov 14, 2023
f9d5375
imported logging setup from trianing into MD
M-R-Schaefer Nov 14, 2023
797d5b4
pulled trajectory handelr out of `run_nvt`
M-R-Schaefer Nov 14, 2023
17570a5
removed old sampling interval check comment
M-R-Schaefer Nov 14, 2023
cf1f058
moved directory creation to top level
M-R-Schaefer Nov 14, 2023
e67c561
fixed nvt apply fn calling signature and sim dir creation
M-R-Schaefer Nov 14, 2023
9d1860a
removed log gile CLI and MD option
M-R-Schaefer Nov 14, 2023
6ef18aa
removed unused box arg from create energy fn
M-R-Schaefer Nov 14, 2023
37b7ebf
Merge branch 'dev' into npt_validation
M-R-Schaefer Nov 14, 2023
bd6cac3
added checkpoint interval to md config and md code
M-R-Schaefer Nov 14, 2023
5bd0b24
removed duplicate docs makefiles
M-R-Schaefer Nov 14, 2023
83607a9
implemented saving and loading of MD checkpoints
M-R-Schaefer Nov 15, 2023
c4a670e
pbar now correctly shows 100 percent on completion
M-R-Schaefer Nov 15, 2023
a0341ec
removed debug statements
M-R-Schaefer Nov 15, 2023
6e84d19
set default md log level to info
M-R-Schaefer Nov 15, 2023
94eb969
refactored checkpoint and momenta loading into separete function
M-R-Schaefer Nov 15, 2023
ccc39d1
moved System and SimFunctions to separate submodule
M-R-Schaefer Nov 15, 2023
b571222
utility logging redirect for logs during tqdm pbar
M-R-Schaefer Nov 15, 2023
81af8ee
H5TrajHandler now appends to trajectory if an existing one is found
M-R-Schaefer Nov 15, 2023
56d942d
DSTruncater now initializes chunked datasets, added TrajHandler type …
M-R-Schaefer Nov 15, 2023
1a61a7b
added trajectory truncation on checkpoint loading and time printing i…
M-R-Schaefer Nov 15, 2023
be9a134
linting
M-R-Schaefer Nov 15, 2023
b8bb72f
added explicit buffer size argument to TrajHandler
M-R-Schaefer Nov 15, 2023
2db1501
added checkpoint loading to MD integration test
M-R-Schaefer Nov 15, 2023
10b6061
linting
M-R-Schaefer Nov 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions Makefile

This file was deleted.

5 changes: 2 additions & 3 deletions apax/cli/apax_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,15 @@ def md(
..., help="Configuration YAML file that was used to train a model."
),
md_config_path: Path = typer.Argument(..., help="MD configuration YAML file."),
log_level: str = typer.Option("error", help="Sets the training logging level."),
log_file: str = typer.Option("md.log", help="Specifies the name of the log file"),
log_level: str = typer.Option("info", help="Sets the training logging level."),
):
"""
Starts performing a molecular dynamics simulation (currently only NHC thermostat)
with parameters provided by a configuration file.
"""
from apax.md import run_md

run_md(train_config_path, md_config_path, log_file, log_level)
run_md(train_config_path, md_config_path, log_level)


@app.command()
Expand Down
9 changes: 8 additions & 1 deletion apax/config/md_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
n_inner: Number of compiled simulation steps (i.e. number of iterations of the
`jax.lax.fori_loop` loop). Also determines atoms buffer size.
sampling_rate:
Trajectory dumping interval.
Interval between saving frames.
buffer_size:
Number of collected frames to be dumped at once.
dr_threshold: Skin of the neighborlist.
extra_capacity: JaxMD allocates a maximal number of neighbors.
This argument lets you add additional capacity to avoid recompilation.
Expand All @@ -56,6 +58,9 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
traj_name: Name of the trajectory file.
restart: Whether the simulation should restart from the latest configuration
in `traj_name`.
checkpoint_interval: Number of time steps between saving
full simulation state checkpoints. These will be loaded
with the `restart` option.
disable_pbar: Disables the MD progressbar.
"""

Expand All @@ -69,6 +74,7 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
duration: PositiveFloat
n_inner: PositiveInt = 100
sampling_rate: PositiveInt = 10
buffer_size: PositiveInt = 100
dr_threshold: PositiveFloat = 0.5
extra_capacity: PositiveInt = 0

Expand All @@ -77,6 +83,7 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
sim_dir: str = "."
traj_name: str = "md.h5"
restart: bool = True
checkpoint_interval: int = 50_000
disable_pbar: bool = False

def dump_config(self):
Expand Down
71 changes: 62 additions & 9 deletions apax/md/io.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
import logging
from pathlib import Path

import h5py
import numpy as np
import znh5md
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from jax_md.space import transform

from apax.md.sim_utils import System

log = logging.getLogger(__name__)


class TrajHandler:
def __init__(self) -> None:
self.system: System
self.sampling_rate: int
self.buffer_size: int
self.traj_path: Path
self.time_step: float

def step(self, state_and_energy, transform):
pass

Expand Down Expand Up @@ -41,37 +56,75 @@ def atoms_from_state(self, state, energy, nbr_kwargs):


class H5TrajHandler(TrajHandler):
def __init__(self, system, sampling_rate, traj_path) -> None:
def __init__(
self,
system: System,
sampling_rate: int,
buffer_size: int,
traj_path: Path,
time_step: float = 0.5,
) -> None:
self.atomic_numbers = system.atomic_numbers
self.box = system.box
self.fractional = np.any(self.box < 1e-6)
self.sampling_rate = sampling_rate
self.traj_path = traj_path
self.db = znh5md.io.DataWriter(self.traj_path)
self.db.initialize_database_groups()
if not self.traj_path.is_file():
log.info(f"Initializing new trajectory file at {self.traj_path}")
self.db.initialize_database_groups()
self.time_step = time_step

self.sampling_counter = 1
self.step_counter = 0
self.buffer = []
self.buffer_size = buffer_size

def reset_buffer(self):
self.buffer = []

def step(self, state, transform):
state, energy, nbr_kwargs = state

if self.sampling_counter < self.sampling_rate:
self.sampling_counter += 1
else:
if self.step_counter % self.sampling_rate == 0:
new_atoms = self.atoms_from_state(state, energy, nbr_kwargs)
self.buffer.append(new_atoms)
self.sampling_counter = 1
self.step_counter += 1

if len(self.buffer) >= self.buffer_size:
self.write()

def write(self, x=None, transform=None):
if len(self.buffer) > 0:
reader = znh5md.io.AtomsReader(
self.buffer,
step=1,
time=self.sampling_rate,
step=self.time_step,
time=self.time_step * self.step_counter,
frames_per_chunk=self.buffer_size,
)
self.db.add(reader)
self.reset_buffer()


class DSTruncator:
def __init__(self, length):
self.length = length
self.node_names = []

def __call__(self, name, node):
if isinstance(node, h5py.Dataset):
if len(node.shape) > 1 or name.endswith("energy/value"):
self.node_names.append(name)

def truncate(self, ds):
for name in self.node_names:
shape = tuple([None] + list(ds[name].shape[1:]))
truncated_data = ds[name][: self.length]
del ds[name]
ds.create_dataset(name, maxshape=shape, data=truncated_data, chunks=True)


def truncate_trajectory_to_checkpoint(traj_path, length):
truncator = DSTruncator(length=length)
with h5py.File(traj_path, "r+") as ds:
ds.visititems(truncator)
truncator.truncate(ds)
16 changes: 9 additions & 7 deletions apax/md/md_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import logging
from pathlib import Path

from flax.training import checkpoints
from jaxtyping import PyTree

log = logging.getLogger(__name__)


def load_md_state(sim_dir):
# TODO: not functional yet
def load_md_state(state: PyTree, ckpt_dir: Path) -> tuple[PyTree, int]:
try:
log.info("loading previous md state")
raw_restored = checkpoints.restore_checkpoint(sim_dir, target=None, step=None)
log.info(f"loading MD state from {ckpt_dir}")
target = {"state": state, "step": 0}
restored_ckpt = checkpoints.restore_checkpoint(ckpt_dir, target=target, step=None)
except FileNotFoundError:
print(f"No checkpoint found at {sim_dir}")
state = raw_restored["state"]
step = raw_restored["step"]
print(f"No checkpoint found at {ckpt_dir}")
state = restored_ckpt["state"]
step = restored_ckpt["step"]
return state, step
Loading