Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.2.0 Release #213

Merged
merged 143 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
143 commits
Select commit Hold shift + click to select a range
9e0f827
removed unused DsStats fields
M-R-Schaefer Oct 7, 2023
5a00ae5
made n_species optional in builder
M-R-Schaefer Oct 7, 2023
c0b8539
moved callback handing to separate module
M-R-Schaefer Oct 7, 2023
a65ba81
removed redundant maximize l2cache option
M-R-Schaefer Oct 7, 2023
096024b
removed redundant `n_atoms` from NL precompute. Moved BS validation t…
M-R-Schaefer Oct 7, 2023
503e5d6
preprocessing type hints
M-R-Schaefer Oct 7, 2023
5e3c2e8
removed l2cache option from l2cache
M-R-Schaefer Oct 7, 2023
0fc63fb
renamed `TFPipeline` to `AtomsiticDataset`. Moved DS initialization t…
M-R-Schaefer Oct 7, 2023
f15f29d
fix type hints and tests
M-R-Schaefer Oct 7, 2023
da4f637
moved displacement calculation to model
M-R-Schaefer Oct 19, 2023
5051528
vmapped step functions
M-R-Schaefer Oct 20, 2023
d5b7681
added n_models option to config
M-R-Schaefer Oct 20, 2023
1123d4b
moved conversion of init_inputs to the `init_input` method
M-R-Schaefer Oct 20, 2023
28f9513
linting
M-R-Schaefer Oct 20, 2023
53768fa
moved detection of ensembles from ase calc to checkpoints
M-R-Schaefer Oct 20, 2023
6c6b2fc
added train step functions for ensembles, vmapped train state and par…
M-R-Schaefer Oct 20, 2023
606a6b1
made check for ensemble return number of models instead of bool
M-R-Schaefer Oct 24, 2023
8bacd0e
added ensemble check to eval
M-R-Schaefer Oct 24, 2023
557434a
sketch of ensemble energy fn
M-R-Schaefer Oct 25, 2023
ba862e9
added ensemble energy transformation to jaxmd
M-R-Schaefer Oct 27, 2023
71f0383
linting
M-R-Schaefer Oct 27, 2023
ce4d200
fixed single model param inti
M-R-Schaefer Oct 27, 2023
be20d64
fixed val step argument order
M-R-Schaefer Oct 27, 2023
17dd4a5
fixed loss, pred argument orders in trainer
M-R-Schaefer Oct 27, 2023
d9d3351
updated eval for ensembles
M-R-Schaefer Oct 27, 2023
2d5f98e
added logging for number of models
M-R-Schaefer Oct 27, 2023
0ef8962
ensemble support in jaxmd
M-R-Schaefer Oct 27, 2023
ec4935e
updated jaxmd test
M-R-Schaefer Oct 28, 2023
8d7dd08
Merge branch 'dev' into vmap_ensemble
M-R-Schaefer Oct 28, 2023
a29a010
linting
M-R-Schaefer Oct 28, 2023
f066fbc
fixed model call arguments
M-R-Schaefer Oct 28, 2023
f9eb604
fixed box and disp function args in builder
M-R-Schaefer Oct 28, 2023
331a4b6
fixed dr type converions in descriptor
M-R-Schaefer Oct 28, 2023
9caaac8
adjusted tests for model refactor
M-R-Schaefer Oct 28, 2023
d8c4ee9
fixed BAL computing DS stats
M-R-Schaefer Oct 29, 2023
934192d
Merge branch 'dev' into val_bs_fix
M-R-Schaefer Oct 29, 2023
acc84a4
fixed import in feature maps
M-R-Schaefer Oct 29, 2023
515c0db
removed duplicate function definition
M-R-Schaefer Oct 29, 2023
b02d15d
linting
M-R-Schaefer Oct 29, 2023
9fc9a01
corrected type hint
M-R-Schaefer Oct 31, 2023
3b95a43
added explanation for choice of 119 nspecies
M-R-Schaefer Oct 31, 2023
11571d0
allow sublcasses in canonicalize neighbors
M-R-Schaefer Oct 31, 2023
655db44
removed initialize directories function
M-R-Schaefer Oct 31, 2023
f632d22
Merge pull request #188 from apax-hub/val_bs_fix
M-R-Schaefer Oct 31, 2023
c49e3a2
Merge branch 'dev' into vmap_ensemble
M-R-Schaefer Nov 2, 2023
60acf44
removed unused numpy import
M-R-Schaefer Nov 2, 2023
2f464f0
added key word to initialize dataset
M-R-Schaefer Nov 9, 2023
cc3f1f0
made check for ensemble easier to understand
M-R-Schaefer Nov 9, 2023
b1a6cdd
better function name for inner in train state creation
M-R-Schaefer Nov 9, 2023
6d65dfb
better docs for vmap dimensions
M-R-Schaefer Nov 9, 2023
6e5f326
linting
M-R-Schaefer Nov 9, 2023
f856534
fixed run not creating parent folders
M-R-Schaefer Nov 9, 2023
536651b
Merge pull request #187 from apax-hub/vmap_ensemble
M-R-Schaefer Nov 9, 2023
49f864b
remove duplicate parameter initialization
M-R-Schaefer Nov 9, 2023
b0e3dcb
updated imports, switched to more robust ensemble detection, fixed in…
M-R-Schaefer Nov 9, 2023
9bcdd61
test linting
M-R-Schaefer Nov 9, 2023
54341de
added fixtures and utilities for creating smaple input, intializting …
M-R-Schaefer Nov 9, 2023
08668cf
added integration test for BAL
M-R-Schaefer Nov 9, 2023
3ab6675
linting
M-R-Schaefer Nov 9, 2023
baebe47
removed session scope from remove_tmp_path to avoid clashes between i…
M-R-Schaefer Nov 9, 2023
a25dfdc
added parameter canonicalizationb functions
M-R-Schaefer Nov 9, 2023
fd79280
canonocalize params in nvt and bal
M-R-Schaefer Nov 9, 2023
dd0a80e
adjust tests to cover parameter loading from different model types
M-R-Schaefer Nov 9, 2023
dea322f
Merge pull request #195 from apax-hub/bal_tests
M-R-Schaefer Nov 10, 2023
6ef6ab5
set absl log level to warning
M-R-Schaefer Nov 10, 2023
6772dfb
removed loging from config parsing
M-R-Schaefer Nov 10, 2023
fb256b0
set log level to info, removed log file from cli and run
M-R-Schaefer Nov 10, 2023
0800164
turned model version path into property
M-R-Schaefer Nov 10, 2023
3689443
adapted model verison path preopetry change across project
M-R-Schaefer Nov 10, 2023
da6946e
added safety conversion of model_version_path to Path
M-R-Schaefer Nov 10, 2023
e3b6b97
removed model version path argument from load data files
M-R-Schaefer Nov 10, 2023
19a8794
removed unused os import
M-R-Schaefer Nov 10, 2023
db80629
updated tests to use best model path as a property
M-R-Schaefer Nov 10, 2023
a2c92c9
use best model dir directly in tests
M-R-Schaefer Nov 10, 2023
1247162
added higher level param transfer function to clean up run.py
M-R-Schaefer Nov 11, 2023
e2efa12
added safety model verion path conversion to Path
M-R-Schaefer Nov 11, 2023
972a62d
Fixed bug where transfered parameters are not actually assigned to an…
M-R-Schaefer Nov 11, 2023
ca77454
added example dataset, moved load config and run training to conf test
M-R-Schaefer Nov 11, 2023
47d6962
added transfer learning integration test
M-R-Schaefer Nov 11, 2023
98c5b5f
linting
M-R-Schaefer Nov 11, 2023
fa733ce
added session scoped dataset tmp path
M-R-Schaefer Nov 11, 2023
990fe82
removed superfluous unused NeighborSpoof class
M-R-Schaefer Nov 11, 2023
ff0bed4
unpinned cuda jax version for Flax compatibility
M-R-Schaefer Nov 11, 2023
7814763
restructured readme, added DOI
M-R-Schaefer Nov 11, 2023
e13d122
updated doc dependencies
M-R-Schaefer Nov 12, 2023
b6e902f
fixed some of the doc file not being included
M-R-Schaefer Nov 12, 2023
5e185c6
use furo theme
M-R-Schaefer Nov 13, 2023
ccb903c
add furo theme to conf.py
M-R-Schaefer Nov 13, 2023
28b7857
updated the package description in index.rst
M-R-Schaefer Nov 13, 2023
7745f8e
removed submitted citation
M-R-Schaefer Nov 13, 2023
6649e3f
minor fixes to docstring compatibility in data
M-R-Schaefer Nov 13, 2023
9e2714c
added zenodo badge
M-R-Schaefer Nov 13, 2023
82ad0da
added readthedocs yaml
M-R-Schaefer Nov 13, 2023
17c7431
updated apax docs cli command with rtd link
M-R-Schaefer Nov 13, 2023
a4cd3e7
added docs badge to README and link to pyproject
M-R-Schaefer Nov 13, 2023
40a4961
made cell explicit in simulation loop
M-R-Schaefer Nov 13, 2023
8f7e3c6
cli now fails when trying to override config files
M-R-Schaefer Nov 13, 2023
f059279
added tests for basic cli functionality
M-R-Schaefer Nov 13, 2023
12d5485
updated MD config template
M-R-Schaefer Nov 14, 2023
061f943
removed if name equals main section from cli
M-R-Schaefer Nov 14, 2023
400677f
added an integration test for the CLI
M-R-Schaefer Nov 14, 2023
71e9029
h5trajhanderl now records actual sim time
M-R-Schaefer Nov 14, 2023
f9d5375
imported logging setup from trianing into MD
M-R-Schaefer Nov 14, 2023
797d5b4
pulled trajectory handelr out of `run_nvt`
M-R-Schaefer Nov 14, 2023
17570a5
removed old sampling interval check comment
M-R-Schaefer Nov 14, 2023
cf1f058
moved directory creation to top level
M-R-Schaefer Nov 14, 2023
e67c561
fixed nvt apply fn calling signature and sim dir creation
M-R-Schaefer Nov 14, 2023
b90fdb2
Merge pull request #202 from apax-hub/doi
M-R-Schaefer Nov 14, 2023
9d1860a
removed log gile CLI and MD option
M-R-Schaefer Nov 14, 2023
7f61d0d
Merge pull request #197 from apax-hub/absl_logging_fix
M-R-Schaefer Nov 14, 2023
6ca78bd
Merge branch 'dev' into docs
M-R-Schaefer Nov 14, 2023
aadd2b4
Merge pull request #201 from apax-hub/unpin_jax
M-R-Schaefer Nov 14, 2023
57e7ad1
Merge pull request #206 from apax-hub/cli_test
M-R-Schaefer Nov 14, 2023
6ef18aa
removed unused box arg from create energy fn
M-R-Schaefer Nov 14, 2023
378a234
Merge pull request #204 from apax-hub/docs
M-R-Schaefer Nov 14, 2023
37b7ebf
Merge branch 'dev' into npt_validation
M-R-Schaefer Nov 14, 2023
bd6cac3
added checkpoint interval to md config and md code
M-R-Schaefer Nov 14, 2023
684c97b
Merge branch 'dev' into tl_test
M-R-Schaefer Nov 14, 2023
5bd0b24
removed duplicate docs makefiles
M-R-Schaefer Nov 14, 2023
ea67abd
fix shrink_wrapped_cell
Tetracarbonylnickel Nov 14, 2023
5e6c597
Merge branch 'dev' into fix-neighbour-calc
Tetracarbonylnickel Nov 14, 2023
d5bd497
Update preprocessing.py
Tetracarbonylnickel Nov 14, 2023
5bcfb50
Merge pull request #207 from apax-hub/fix-neighbour-calc
Tetracarbonylnickel Nov 15, 2023
83607a9
implemented saving and loading of MD checkpoints
M-R-Schaefer Nov 15, 2023
c4a670e
pbar now correctly shows 100 percent on completion
M-R-Schaefer Nov 15, 2023
a0341ec
removed debug statements
M-R-Schaefer Nov 15, 2023
6e84d19
set default md log level to info
M-R-Schaefer Nov 15, 2023
94eb969
refactored checkpoint and momenta loading into separete function
M-R-Schaefer Nov 15, 2023
ccc39d1
moved System and SimFunctions to separate submodule
M-R-Schaefer Nov 15, 2023
b571222
utility logging redirect for logs during tqdm pbar
M-R-Schaefer Nov 15, 2023
81af8ee
H5TrajHandler now appends to trajectory if an existing one is found
M-R-Schaefer Nov 15, 2023
56d942d
DSTruncater now initializes chunked datasets, added TrajHandler type …
M-R-Schaefer Nov 15, 2023
1a61a7b
added trajectory truncation on checkpoint loading and time printing i…
M-R-Schaefer Nov 15, 2023
be9a134
linting
M-R-Schaefer Nov 15, 2023
b8bb72f
added explicit buffer size argument to TrajHandler
M-R-Schaefer Nov 15, 2023
2db1501
added checkpoint loading to MD integration test
M-R-Schaefer Nov 15, 2023
10b6061
linting
M-R-Schaefer Nov 15, 2023
39a9f8c
Merge pull request #200 from apax-hub/tl_test
M-R-Schaefer Nov 17, 2023
8a3c56b
Typo in the name README.md
Tetracarbonylnickel Nov 24, 2023
ff7e626
Merge pull request #208 from apax-hub/npt_validation
M-R-Schaefer Nov 29, 2023
f5c5373
implemented model loading from arbitray locations
M-R-Schaefer Dec 13, 2023
ebe2613
Merge pull request #212 from apax-hub/loading
M-R-Schaefer Dec 14, 2023
27c5305
Merge branch 'main' into dev
M-R-Schaefer Dec 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

version: 2

# Set the version of Python and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.11"
jobs:
post_install:
- pip install poetry
- poetry config virtualenvs.create false
- poetry install --with=docs

# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py
26 changes: 18 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# `apax`: Atomistic learned Potentials in JAX!
[![Read the Docs](https://readthedocs.org/projects/apax/badge/)](https://apax.readthedocs.io/en/latest/)
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10040711.svg)](https://doi.org/10.5281/zenodo.10040711)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

`apax` is a high-performance, extendable package for training of and inference with atomistic neural networks.

`apax`[1] is a high-performance, extendable package for training of and inference with atomistic neural networks.
It implements the Gaussian Moment Neural Network model [2, 3].
It is based on [JAX](https://jax.readthedocs.io/en/latest/) and uses [JaxMD](https://github.com/jax-md/jax-md) as a molecular dynamics engine.

Expand Down Expand Up @@ -39,12 +42,12 @@ pip install --upgrade pip

CUDA 12 installation. Wheels only available on linux.
```bash
pip install --upgrade "jax[cuda12_pip]==0.4.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

CUDA 11 installation. Wheels only available on linux.
```bash
pip install --upgrade "jax[cuda11_pip]==0.4.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

See the [Jax installation instructions](https://github.com/google/jax#installation) for more details.
Expand Down Expand Up @@ -91,14 +94,21 @@ The second way is to use the ASE calculator provided in `apax.md`.

Under the supervion of Johannes Kästner

## References
* [1] DOI PLACEHOLDER
* [2] V. Zaverkin and J. Kästner, [“Gaussian Moments as Physically Inspired Molecular Descriptors for Accurate and Scalable Machine Learning Potentials,”](https://doi.org/10.1021/acs.jctc.0c00347) J. Chem. Theory Comput. **16**, 5410–5421 (2020).
* [3] V. Zaverkin, D. Holzmüller, I. Steinwart, and J. Kästner, [“Fast and Sample-Efficient Interatomic Neural Network Potentials for Molecules and Materials Based on Gaussian Moments,”](https://pubs.acs.org/doi/10.1021/acs.jctc.1c00527) J. Chem. Theory Comput. **17**, 6658–6670 (2021).


## Contributing

We are happy to receive your issues and pull requests!

Do not hesitate to contact any of the authors above if you have any further questions.


## Acknowledgements

The creation of Apax was supported by the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) in the framework of the priority program SPP 2363, “Utilization and Development of Machine Learning for Molecular Applications - Molecular Machine Learning” Project No. 497249646 and the Ministry of Science, Research and the Arts Baden-Württemberg in the Artificial Intelligence Software Academy (AISA).
Further funding though the DFG under Germany's Excellence Strategy - EXC 2075 - 390740016 and the Stuttgart Center for Simulation Science (SimTech) was provided.


## References
* [1] 10.5281/zenodo.10040711
* [2] V. Zaverkin and J. Kästner, [“Gaussian Moments as Physically Inspired Molecular Descriptors for Accurate and Scalable Machine Learning Potentials,”](https://doi.org/10.1021/acs.jctc.0c00347) J. Chem. Theory Comput. **16**, 5410–5421 (2020).
* [3] V. Zaverkin, D. Holzmüller, I. Steinwart, and J. Kästner, [“Fast and Sample-Efficient Interatomic Neural Network Potentials for Molecules and Materials Based on Gaussian Moments,”](https://pubs.acs.org/doi/10.1021/acs.jctc.1c00527) J. Chem. Theory Comput. **17**, 6658–6670 (2021).
30 changes: 19 additions & 11 deletions apax/bal/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@
from tqdm import trange

from apax.bal import feature_maps, kernel, selection, transforms
from apax.data.input_pipeline import TFPipeline
from apax.data.initialization import RawDataset
from apax.data.input_pipeline import AtomisticDataset
from apax.model.builder import ModelBuilder
from apax.model.gmnn import EnergyModel
from apax.train.checkpoints import restore_parameters
from apax.train.run import RawDataset, initialize_dataset
from apax.train.checkpoints import (
canonicalize_energy_model_parameters,
check_for_ensemble,
restore_parameters,
)
from apax.train.run import initialize_dataset


def create_feature_fn(
Expand Down Expand Up @@ -43,11 +48,11 @@ def create_feature_fn(
return feature_fn


def compute_features(feature_fn, dataset: TFPipeline, processing_batch_size: int):
def compute_features(feature_fn, dataset: AtomisticDataset):
"""Compute the features of a dataset."""
features = []
n_data = dataset.n_data
ds = dataset.batch(processing_batch_size)
ds = dataset.batch()

pbar = trange(n_data, desc="Computing features", ncols=100, leave=True)
for i, (inputs, _) in enumerate(ds):
Expand All @@ -70,29 +75,32 @@ def kernel_selection(
selection_batch_size: int = 10,
processing_batch_size: int = 64,
):
n_models = 1 if isinstance(model_dir, (Path, str)) else len(model_dir)
is_ensemble = n_models > 1

selection_fn = {
"max_dist": selection.max_dist_selection,
}[selection_method]

base_feature_map = feature_maps.FeatureMapOptions(base_fm_options)

config, params = restore_parameters(model_dir)
params = canonicalize_energy_model_parameters(params)
n_models = check_for_ensemble(params)
is_ensemble = n_models > 1

n_train = len(train_atoms)
dataset = initialize_dataset(config, RawDataset(atoms_list=train_atoms + pool_atoms))
dataset = initialize_dataset(
config, RawDataset(atoms_list=train_atoms + pool_atoms), calc_stats=False
)
dataset.set_batch_size(processing_batch_size)

init_box = dataset.init_input()["box"][0]
_, init_box = dataset.init_input()

builder = ModelBuilder(config.model.get_dict(), n_species=119)
model = builder.build_energy_model(apply_mask=True, init_box=init_box)

feature_fn = create_feature_fn(
model, params, base_feature_map, feature_transforms, is_ensemble
)
g = compute_features(feature_fn, dataset, processing_batch_size)
g = compute_features(feature_fn, dataset)
km = kernel.KernelMatrix(g, n_train)
new_indices = selection_fn(km, selection_batch_size)

Expand Down
21 changes: 9 additions & 12 deletions apax/cli/apax_app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib.metadata
import importlib.resources as pkg_resources
import sys
from pathlib import Path

import typer
Expand Down Expand Up @@ -34,15 +35,14 @@ def train(
train_config_path: Path = typer.Argument(
..., help="Training configuration YAML file."
),
log_level: str = typer.Option("error", help="Sets the training logging level."),
log_file: str = typer.Option("train.log", help="Specifies the name of the log file"),
log_level: str = typer.Option("info", help="Sets the training logging level."),
):
"""
Starts the training of a model with parameters provided by a configuration file.
"""
from apax.train.run import run

run(train_config_path, log_file, log_level)
run(train_config_path, log_level)


@app.command()
Expand All @@ -51,16 +51,15 @@ def md(
..., help="Configuration YAML file that was used to train a model."
),
md_config_path: Path = typer.Argument(..., help="MD configuration YAML file."),
log_level: str = typer.Option("error", help="Sets the training logging level."),
log_file: str = typer.Option("md.log", help="Specifies the name of the log file"),
log_level: str = typer.Option("info", help="Sets the training logging level."),
):
"""
Starts performing a molecular dynamics simulation (currently only NHC thermostat)
with parameters provided by a configuration file.
"""
from apax.md import run_md

run_md(train_config_path, md_config_path, log_file, log_level)
run_md(train_config_path, md_config_path, log_level)


@app.command()
Expand Down Expand Up @@ -90,8 +89,8 @@ def docs():
"""
Opens the documentation website in your browser.
"""
console.print("Opening apax's docs at https://github.com/apax-hub/apax")
typer.launch("https://github.com/apax-hub/apax")
console.print("Opening apax's docs at https://apax.readthedocs.io/en/latest/")
typer.launch("https://apax.readthedocs.io/en/latest/")


@validate_app.command("train")
Expand Down Expand Up @@ -211,6 +210,7 @@ def template_train_config(

if Path(config_path).is_file():
console.print("There is already a config file in the working directory.")
sys.exit(1)
else:
with open(config_path, "w") as config:
config.write(template_content)
Expand All @@ -229,6 +229,7 @@ def template_md_config():

if Path(config_path).is_file():
console.print("There is already a config file in the working directory.")
sys.exit(1)
else:
with open(config_path, "w") as config:
config.write(template_content)
Expand All @@ -249,7 +250,3 @@ def main(
):
# Taken from https://github.com/zincware/dask4dvc/blob/main/dask4dvc/cli/main.py
_ = version


if __name__ == "__main__":
app()
13 changes: 8 additions & 5 deletions apax/cli/templates/md_config_minimal.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
temperature: <T> # K
duration: <DURATION> #fs
n_inner: 100 # compiled innner steps
sampling_rate: 10 # dump interval
dt: 0.5 # fs time step
ensemble:
name: nvt
dt: 0.5 # fs time step
temperature: <T> # K

duration: <DURATION> # fs
n_inner: 1 # compiled innner steps
sampling_rate: 1 # dump interval
dr_threshold: 0.5 # Neighborlist skin

sim_dir: md
Expand Down
2 changes: 0 additions & 2 deletions apax/cli/templates/train_config_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,3 @@ checkpoints:
progress_bar:
disable_epoch_pbar: false
disable_nl_pbar: false

maximize_l2_cache: true
1 change: 0 additions & 1 deletion apax/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def parse_config(config: Union[str, os.PathLike, dict], mode: str = "train") ->
config: Path to the config file or a dictionary
containing the config.
"""
log.info("Loading user config")
if isinstance(config, (str, os.PathLike)):
with open(config, "r") as stream:
config = yaml.safe_load(stream)
Expand Down
9 changes: 8 additions & 1 deletion apax/config/md_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
n_inner: Number of compiled simulation steps (i.e. number of iterations of the
`jax.lax.fori_loop` loop). Also determines atoms buffer size.
sampling_rate:
Trajectory dumping interval.
Interval between saving frames.
buffer_size:
Number of collected frames to be dumped at once.
dr_threshold: Skin of the neighborlist.
extra_capacity: JaxMD allocates a maximal number of neighbors.
This argument lets you add additional capacity to avoid recompilation.
Expand All @@ -56,6 +58,9 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
traj_name: Name of the trajectory file.
restart: Whether the simulation should restart from the latest configuration
in `traj_name`.
checkpoint_interval: Number of time steps between saving
full simulation state checkpoints. These will be loaded
with the `restart` option.
disable_pbar: Disables the MD progressbar.
"""

Expand All @@ -69,6 +74,7 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
duration: PositiveFloat
n_inner: PositiveInt = 100
sampling_rate: PositiveInt = 10
buffer_size: PositiveInt = 100
dr_threshold: PositiveFloat = 0.5
extra_capacity: PositiveInt = 0

Expand All @@ -77,6 +83,7 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
sim_dir: str = "."
traj_name: str = "md.h5"
restart: bool = True
checkpoint_interval: int = 50_000
disable_pbar: bool = False

def dump_config(self):
Expand Down
7 changes: 4 additions & 3 deletions apax/config/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,14 @@ def validate_shift_scale_methods(self):

return self

@property
def model_version_path(self):
version_path = Path(self.directory) / self.experiment
return version_path

@property
def best_model_path(self):
return self.model_version_path() / "best"
return self.model_version_path / "best"


class ModelConfig(BaseModel, extra="forbid"):
Expand Down Expand Up @@ -286,12 +288,12 @@ class Config(BaseModel, frozen=True, extra="forbid"):
callbacks: List of :class: `callback` <config.CallbackConfig> configurations.
progress_bar: Progressbar configuration.
checkpoints: Checkpoint configuration.
maximize_l2_cache: Whether or not to maximize GPU L2 cache.
"""

n_epochs: PositiveInt
patience: Optional[PositiveInt] = None
seed: int = 1
n_models: int = 1

data: DataConfig
model: ModelConfig = ModelConfig()
Expand All @@ -301,7 +303,6 @@ class Config(BaseModel, frozen=True, extra="forbid"):
callbacks: List[CallbackConfig] = [CallbackConfig(name="csv")]
progress_bar: TrainProgressbarConfig = TrainProgressbarConfig()
checkpoints: CheckpointConfig = CheckpointConfig()
maximize_l2_cache: bool = False

def dump_config(self, save_path):
"""
Expand Down
Loading