Skip to content

Commit

Permalink
Add post-processing routines to MD
Browse files Browse the repository at this point in the history
  • Loading branch information
oerc0122 committed Jun 3, 2024
1 parent 2dcb1c8 commit e4fa4b3
Show file tree
Hide file tree
Showing 7 changed files with 363 additions and 9 deletions.
10 changes: 10 additions & 0 deletions docs/source/apidoc/janus_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,16 @@ janus\_core.helpers.descriptors module
:undoc-members:
:show-inheritance:

janus\_core.helpers.post_process module
---------------------------------------

.. automodule:: janus_core.helpers.post_process
:members:
:special-members:
:private-members:
:undoc-members:
:show-inheritance:

janus\_core.helpers.train module
--------------------------------

Expand Down
83 changes: 79 additions & 4 deletions janus_core/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from math import isclose
from pathlib import Path
import random
from typing import Any, Optional
from typing import Any, Optional, Union
from warnings import warn

from ase import Atoms, units
from ase.io import write
from ase.io import read, write
from ase.md.langevin import Langevin
from ase.md.npt import NPT as ASE_NPT
from ase.md.velocitydistribution import (
Expand All @@ -18,11 +18,21 @@
ZeroRotation,
)
from ase.md.verlet import VelocityVerlet

try:
from ase.geometry.analysis import Analysis

ASE_GEOMETRY = True
except ImportError:

ASE_GEOMETRY = False

import numpy as np

from janus_core.calculations.geom_opt import optimize
from janus_core.helpers.janus_types import Ensembles, PathLike
from janus_core.helpers.janus_types import Ensembles, PathLike, PostProcessKwargs
from janus_core.helpers.log import config_logger
from janus_core.helpers.post_process import compute_rdf, compute_vaf
from janus_core.helpers.utils import FileNameMixin

DENS_FACT = (units.m / 1.0e2) ** 3 / units.mol
Expand Down Expand Up @@ -97,6 +107,8 @@ class MolecularDynamics(FileNameMixin): # pylint: disable=too-many-instance-att
heating.
temp_time : Optional[float]
Time between heating steps, in fs. Default is None, which disables heating.
post_process_kwargs : Optional[PostProcessKwargs]
Keyword arguments to control post-processing operations.
log_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to log config. Default is None.
seed : Optional[int]
Expand Down Expand Up @@ -157,6 +169,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta
temp_end: Optional[float] = None,
temp_step: Optional[float] = None,
temp_time: Optional[float] = None,
post_process_kwargs: Optional[PostProcessKwargs] = None,
log_kwargs: Optional[dict[str, Any]] = None,
seed: Optional[int] = None,
) -> None:
Expand Down Expand Up @@ -231,6 +244,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta
disables heating.
temp_time : Optional[float]
Time between heating steps, in fs. Default is None, which disables heating.
post_process_kwargs : Optional[PostProcessKwargs]
Keyword arguments to control post-processing operations.
log_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to log config. Default is None.
seed : Optional[int]
Expand Down Expand Up @@ -262,6 +277,9 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta
self.temp_end = temp_end
self.temp_step = temp_step
self.temp_time = temp_time * units.fs if temp_time else None
self.post_process_kwargs = (
post_process_kwargs if post_process_kwargs is not None else {}
)
self.log_kwargs = log_kwargs
self.ensemble = ensemble
self.seed = seed
Expand Down Expand Up @@ -315,7 +333,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta

self.minimize_kwargs = minimize_kwargs if minimize_kwargs else {}
self.restart_files = []
self.dyn = None
self.dyn: Union[Langevin, VelocityVerlet, ASE_NPT]
self.n_atoms = len(self.struct)

self.stats_file = self._build_filename(
Expand Down Expand Up @@ -542,6 +560,60 @@ def _write_final_state(self) -> None:
columns=["symbols", "positions", "momenta", "masses"],
)

def _post_process(self) -> None:
"""Compute properties after MD run."""
# Nothing to do
if not any(
self.post_process_kwargs.get(kwarg, None)
for kwarg in ("rdf_compute", "vaf_compute")
):
return

data = read(self.traj_file, index=":")

if ASE_GEOMETRY:
ana = Analysis(data)
else:
ana = None

param_pref = self._parameter_prefix if self.file_prefix is None else ""

if self.post_process_kwargs.get("rdf_compute", False):
base_name = self.post_process_kwargs.get("rdf_output_file", None)
rdf_args = {
name: self.post_process_kwargs.get(key, default)
for name, (key, default) in (
("rmax", ("rdf_rmax", 2.5)),
("nbins", ("rdf_nbins", 50)),
("elements", ("rdf_elements", None)),
)
}
slice_ = (
self.post_process_kwargs.get("rdf_start", 0),
self.post_process_kwargs.get("rdf_stop", 1),
self.post_process_kwargs.get("rdf_step", 1),
)

out_paths = [
self._build_filename(
"rdf.dat", param_pref, str(ind), prefix_override=base_name
)
for ind in range(*slice_)
]

rdf_args["index"] = slice_
compute_rdf(data, ana, filename=out_paths, **rdf_args)

if self.post_process_kwargs.get("vaf_compute", False):

file_name = self.post_process_kwargs.get("vaf_output_file", None)
use_vel = self.post_process_kwargs.get("vaf_velocities", False)
fft = self.post_process_kwargs.get("vaf_fft", False)

out_path = self._build_filename("vaf.dat", param_pref, filename=file_name)

compute_vaf(data, out_path, use_velocities=use_vel, fft=fft)

def _write_restart(self) -> None:
"""Write restart file and (optionally) rotate files saved."""
step = self.offset + self.dyn.nsteps
Expand Down Expand Up @@ -594,6 +666,9 @@ def run(self) -> None:
self.struct.info["real_time"] = datetime.datetime.now()
self._run_dynamics()

if self.post_process_kwargs:
self._post_process()

def _run_dynamics(self) -> None:
"""Run dynamics and/or temperature ramp."""
# Store temperature for final MD
Expand Down
11 changes: 9 additions & 2 deletions janus_core/cli/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Device,
LogPath,
MinimizeKwargs,
PostProcessKwargs,
ReadKwargs,
StructPath,
Summary,
Expand Down Expand Up @@ -168,6 +169,7 @@ def md(
temp_time: Annotated[
float, Option(help="Time between heating steps, in fs.")
] = None,
post_process_kwargs: PostProcessKwargs = None,
log: LogPath = "md.log",
seed: Annotated[
Optional[int],
Expand Down Expand Up @@ -267,6 +269,8 @@ def md(
temp_time : Optional[float]
Time between heating steps, in fs. Default is None, which disables
heating.
post_process_kwargs : Optional[PostProcessKwargs]
Kwargs to pass to post-processing.
log : Optional[Path]
Path to write logs to. Default is "md.log".
seed : Optional[int]
Expand All @@ -280,8 +284,10 @@ def md(
# Check options from configuration file are all valid
check_config(ctx)

[read_kwargs, calc_kwargs, minimize_kwargs] = parse_typer_dicts(
[read_kwargs, calc_kwargs, minimize_kwargs]
[read_kwargs, calc_kwargs, minimize_kwargs, post_process_kwargs] = (
parse_typer_dicts(
[read_kwargs, calc_kwargs, minimize_kwargs, post_process_kwargs]
)
)

if not ensemble in get_args(Ensembles):
Expand Down Expand Up @@ -334,6 +340,7 @@ def md(
"temp_end": temp_end,
"temp_step": temp_step,
"temp_time": temp_time,
"post_process_kwargs": post_process_kwargs,
"log_kwargs": log_kwargs,
"seed": seed,
}
Expand Down
15 changes: 15 additions & 0 deletions janus_core/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,21 @@ def __str__(self):
),
]

PostProcessKwargs = Annotated[
TyperDict,
Option(
parser=parse_dict_class,
help=(
"""
Keyword arguments to pass to post-processer. Must be passed as a dictionary
wrapped in quotes, e.g. "{'key' : value}".
"""
),
metavar="DICT",
),
]


LogPath = Annotated[Path, Option(help="Path to save logs to.")]

Summary = Annotated[
Expand Down
21 changes: 20 additions & 1 deletion janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,32 @@ class ASEWriteArgs(TypedDict, total=False):


class ASEOptArgs(TypedDict, total=False):
"""Main arugments for ase optimisers."""
"""Main arguments for ase optimisers."""

restart: Optional[bool]
logfile: Optional[PathLike]
trajectory: Optional[str]


class PostProcessKwargs(TypedDict, total=False):
"""Main arguments for MD post-processing."""

# RDF
rdf_compute: bool
rdf_rmax: float
rdf_nbins: int
rdf_elements: MaybeSequence[Union[str, int]]
rdf_start: int
rdf_stop: Optional[int]
rdf_step: int
rdf_output_file: Optional[str]
# VAF
vaf_compute: bool
vaf_velocities: bool
vaf_fft: bool
vaf_output_file: Optional[PathLike]


# eos_names from ase.eos
EoSNames = Literal[
"sj",
Expand Down
Loading

0 comments on commit e4fa4b3

Please sign in to comment.