Skip to content

Commit

Permalink
docs linting
Browse files Browse the repository at this point in the history
  • Loading branch information
Tetracarbonylnickel committed Apr 10, 2024
1 parent 4e6ff51 commit f8c9cba
Show file tree
Hide file tree
Showing 15 changed files with 76 additions and 490 deletions.
16 changes: 14 additions & 2 deletions apax/cli/templates/md_config_minimal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,23 @@ ensemble:
name: nvt
dt: 0.5 # fs time step
temperature: <T> # K
thermostat_chain:
chain_length: 3
chain_steps: 2
sy_steps: 3
tau: 100

duration: <DURATION> # fs
n_inner: 1 # compiled innner steps
sampling_rate: 1 # dump interval
n_inner: 100 # compiled innner steps
sampling_rate: 10 # dump interval
buffer_size: 100
dr_threshold: 0.5 # Neighborlist skin
extra_capacity: 0

sim_dir: md
initial_structure: <INITIAL_STRUCTURE>
load_momenta: false
traj_name: md.h5
restart: true
checkpoint_interval: 50_000
disable_pbar: false
29 changes: 23 additions & 6 deletions apax/cli/templates/train_config_full.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
n_epochs: <NUMBER OF EPOCHS>
seed: 1
patience: null
n_models: 1
n_jitted_steps: 1
data_parallel: True

data:
directory: models/
Expand All @@ -11,6 +15,8 @@ data:
#train_data_path: <PATH>
#val_data_path: <PATH>
#test_data_path: <PATH>
additional_properties_info: {}
ds_type: cached

n_train: 1000
n_valid: 100
Expand All @@ -20,6 +26,10 @@ data:

shift_method: "per_element_regression_shift"
shift_options: {"energy_regularisation": 1.0}

scale_method: "per_element_force_rms_scale"
scale_options: {}

shuffle_buffer_size: 1000

pos_unit: Ang
Expand All @@ -28,25 +38,30 @@ data:
model:
n_basis: 7
n_radial: 5
n_contr: -1
nn: [512, 512]

r_max: 6.0
r_min: 0.5

calc_stress: true
use_zbl: false

b_init: normal
descriptor_dtype: fp32
descriptor_dtype: fp64
readout_dtype: fp32
scale_shift_dtype: fp32
emb_init: uniform

loss:
- loss_type: structures
name: energy
- name: energy
loss_type: mse
weight: 1.0
- loss_type: structures
name: forces
atoms_exponent: 1
- name: forces
loss_type: mse
weight: 4.0
atoms_exponent: 1

metrics:
- name: energy
Expand All @@ -66,7 +81,8 @@ optimizer:
shift_lr: 0.05
zbl_lr: 0.001
transition_begin: 0

sam_rho: 0.0

callbacks:
- name: csv

Expand All @@ -78,3 +94,4 @@ checkpoints:

progress_bar:
disable_epoch_pbar: false
disable_batch_pbar: true
20 changes: 10 additions & 10 deletions apax/config/md_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class NHCOptions(BaseModel, extra="forbid"):
tau : PositiveFloat, default = 100
Relaxation time parameter.
"""

chain_length: PositiveInt = 3
chain_steps: PositiveInt = 2
sy_steps: PositiveInt = 3
Expand All @@ -37,6 +38,7 @@ class Integrator(BaseModel, extra="forbid"):
dt : PositiveFloat, default = 0.5
Time step size in femtoseconds (fs).
"""

dt: PositiveFloat = 0.5 # fs


Expand All @@ -49,6 +51,7 @@ class NVEOptions(Integrator, extra="forbid"):
name : Literal["nve"]
Name of the ensemble.
"""

name: Literal["nve"]


Expand All @@ -65,6 +68,7 @@ class NVTOptions(Integrator, extra="forbid"):
thermostat_chain : NHCOptions, default = NHCOptions()
Thermostat chain options.
"""

name: Literal["nvt"]
temperature: PositiveFloat = 298.15 # K
thermostat_chain: NHCOptions = NHCOptions()
Expand All @@ -83,6 +87,7 @@ class NPTOptions(NVTOptions, extra="forbid"):
barostat_chain : NHCOptions, default = NHCOptions(tau=1000)
Barostat chain options.
"""

name: Literal["npt"]
pressure: PositiveFloat = 1.01325 # bar
barostat_chain: NHCOptions = NHCOptions(tau=1000)
Expand All @@ -103,31 +108,25 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
duration : float, required
| Total simulation time in fs.
n_inner : int, default = 100
| Number of compiled simulation steps (i.e. number of iterations of the
`jax.lax.fori_loop` loop). Also determines atoms buffer size.
| Number of compiled simulation steps (i.e. number of iterations of the `jax.lax.fori_loop` loop). Also determines atoms buffer size.
sampling_rate : int, default = 10
| Interval between saving frames.
buffer_size : int, default = 100
| Number of collected frames to be dumped at once.
dr_threshold : float, default = 0.5
| Skin of the neighborlist.
extra_capacity : int, default = 0
| JaxMD allocates a maximal number of neighbors.
This argument lets you add additional capacity to avoid recompilation.
The default is usually fine.
| JaxMD allocates a maximal number of neighbors. This argument lets you add additional capacity to avoid recompilation. The default is usually fine.
initial_structure : str, required
| Path to the starting structure of the simulation.
sim_dir : str, default = "."
| Directory where simulation file will be stored.
traj_name : str, default = "md.h5"
| Name of the trajectory file.
restart : bool, default = True
| Whether the simulation should restart from the latest configuration
in `traj_name`.
| Whether the simulation should restart from the latest configuration in `traj_name`.
checkpoint_interval : int, default = 50_000
| Number of time steps between saving
full simulation state checkpoints. These will be loaded
with the `restart` option.
| Number of time steps between saving full simulation state checkpoints. These will be loaded with the `restart` option.
disable_pbar : bool, False
| Disables the MD progressbar.
"""
Expand Down Expand Up @@ -157,6 +156,7 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
def dump_config(self):
"""
Writes the current config file to the MD directory.
"""
with open(os.path.join(self.sim_dir, "md_config.yaml"), "w") as conf:
yaml.dump(self.model_dump(), conf, default_flow_style=False)
15 changes: 7 additions & 8 deletions apax/config/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class DataConfig(BaseModel, extra="forbid"):
directory : str, required
| Path to directory where training results and checkpoints are written.
experiment : str, required
| Model name distinguishing from others in directory.
data_path : str, required if train_ and val_data_path is not specified
| Model name distinguishing from others in directory.
data_path : str, required if train_data_path and val_data_path is not specified
| Path to single dataset file.
train_data_path : str, required if data_path is not specified
| Path to training dataset.
Expand All @@ -51,8 +51,9 @@ class DataConfig(BaseModel, extra="forbid"):
| Size of the `tf.data` shuffle buffer.
additional_properties_info : dict, optional
| dict of property name, shape (ragged or fixed) pairs
energy_regularisation :
energy_regularisation :
| Magnitude of the regularization in the per-element energy regression.
"""

directory: str
Expand Down Expand Up @@ -200,7 +201,7 @@ class OptimizerConfig(BaseModel, frozen=True, extra="forbid"):
"""
Configuration of the optimizer.
Learning rates of 0 will freeze the respective parameters.
Parameters
----------
opt_name : str, default = "adam"
Expand Down Expand Up @@ -380,8 +381,7 @@ class Config(BaseModel, frozen=True, extra="forbid"):
n_models : int, default = 1
| Number of models to be trained at once.
n_jitted_steps : int, default = 1
| Number of train batches to be processed in a compiled loop.
Can yield singificant speedups for small structures or small batch sizes.
| Number of train batches to be processed in a compiled loop. Can yield singificant speedups for small structures or small batch sizes.
data : :class:`.DataConfig`
| Data configuration.
model : :class:`.ModelConfig`
Expand All @@ -399,8 +399,7 @@ class Config(BaseModel, frozen=True, extra="forbid"):
checkpoints : :class:`.CheckpointConfig`
| Checkpoint configuration.
data_parallel : bool, default = True
| Automatically uses all available GPUs for data parallel training.
Set to false to force single device training.
| Automatically uses all available GPUs for data parallel training. Set to false to force single device training.
"""

n_epochs: PositiveInt
Expand Down
2 changes: 1 addition & 1 deletion apax/md/function_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class UncertaintyDrivenDynamics(FunctionTransformation):
up to some maximum bias energy.
https://doi.org/10.1038/s43588-023-00406-5
Parameters
----------
height : float
Expand Down
9 changes: 4 additions & 5 deletions apax/train/run.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import logging
import sys
from typing import List
from typing import Union
import os
import sys
from typing import List, Union

import jax

Expand Down Expand Up @@ -139,9 +138,9 @@ def run(user_config: Union[str, os.PathLike, dict], log_level="error"):
Parameters
----------
user_config : str | os.PathLike | dict
training config full exmaple can be finde :ref:`here <config>`:
training config full exmaple can be finde :ref:`here <train_config>`:
"""
config = parse_config(user_config)

Expand Down
4 changes: 2 additions & 2 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ def fit(
Parameters
----------
state :
state :
The initial state of the model.
train_ds : InMemoryDataset
The training dataset.
loss_fn :
The loss function to be minimized.
Metrics metrics.Collection :
Collection of metrics to evaluate during training.
callbacks : list
callbacks : list
List of callback functions to be executed during training.
n_epochs : int
Number of epochs for training.
Expand Down
2 changes: 1 addition & 1 deletion apax/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,4 @@ def split_atoms(atoms_list, train_idxs, val_idxs=None):
else:
val_atoms_list = []

return train_atoms_list, val_atoms_list
return train_atoms_list, val_atoms_list
2 changes: 1 addition & 1 deletion apax/utils/jax_md_reduced/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def neighbor_fn(position_and_error, max_occupancy=None):
if not is_sparse(format):
capacity_limit = N - 1 if mask_self else N
elif format is NeighborListFormat.Sparse:
capacity_limit = N * (N - 1) if mask_self else N**2
capacity_limit = N * (N - 1) if mask_self else N ** 2
else:
capacity_limit = N * (N - 1) // 2
if max_occupancy > capacity_limit:
Expand Down
11 changes: 6 additions & 5 deletions docs/source/configs/full_configs.rst
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@

.. _train_config:
======

===============
Training Config
======
===============

Full config can be downloaded :download:`here <../../../apax/cli/templates/train_config_full.yaml>`.

.. include:: ../../../apax/cli/templates/train_config_full.yaml
:literal:

.. _md_config:
======

=========================
Molecular Dynamics Config
======
=========================

Full config can be downloaded :download:`here <../../../apax/cli/templates/md_config_minimal.yaml>`.

Expand Down
2 changes: 1 addition & 1 deletion docs/source/configs/index.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Input Parameter
=======
===============

.. toctree::
:maxdepth: 2
Expand Down
3 changes: 1 addition & 2 deletions docs/source/getting_started/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@ Getting Started
.. toctree::
:maxdepth: 2

install
full_config
install
7 changes: 2 additions & 5 deletions examples/01_Model_Training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The following command create a minimal configuration file in the working directory. Full configuration file with descriptiond of the prameter can be found [here](./05_Full_Config.ipynb)."
"The following command create a minimal configuration file in the working directory. Full configuration file with descriptiond of the prameter can be found [here](https://github.com/apax-hub/apax/blob/main/apax/cli/templates/train_config_full.yaml)."
]
},
{
Expand Down Expand Up @@ -280,7 +280,7 @@
"\n",
"\n",
"During training, apax displays a progress bar to keep track of the validation loss.\n",
"This progress bar is optional however and can be turned off in the config. LINK\n",
"This progress bar is optional however and can be turned off in the config.\n",
"The default configuration writes training metrics to a CSV file, but TensorBoard is also supported.\n",
"One can specify which to use by adding the following section to the input file:\n",
"\n",
Expand Down Expand Up @@ -454,9 +454,6 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"TODO pretty print results to the terminal\n",
"\n",
"Congratulations, you have successfully trained and evaluated your first apax model!"
]
},
Expand Down
5 changes: 3 additions & 2 deletions examples/02_Molecular_Dynamics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@
"The CLI provides easy access to standard NVT and NPT simulations.\n",
"More complex simulation loops are relatively easy to build yourself in JaxMD (see their colab notebooks for examples). \n",
"Trained apax models can of course be used as `energy_fn` in such custom simulations.\n",
"If you have a suggestion for adding some MD feature or thermostat to the core of `apax`, feel free to open up an issue on [Github]{https://github.com/apax-hub/apax}.\n"
"If you have a suggestion for adding some MD feature or thermostat to the core of `apax`, feel free to open up an issue on [Github](https://github.com/apax-hub/apax).\n"
]
},
{
Expand Down Expand Up @@ -243,7 +243,8 @@
" \n",
"duration: 20_000 # fs\n",
"initial_structure: project/benzene_mod.xyz\n",
"```\n"
"```\n",
"Full configuration file with descriptiond of the prameter can be found [here](https://github.com/apax-hub/apax/blob/main/apax/cli/templates/md_config_minimal.yaml)."
]
},
{
Expand Down
Loading

0 comments on commit f8c9cba

Please sign in to comment.