Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate module for output classes #123

Merged
merged 8 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions atomistics/calculators/lammps/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
LAMMPS_RUN,
LAMMPS_MINIMIZE_VOLUME,
)
from atomistics.calculators.lammps.helpers import quantities
from atomistics.calculators.lammps.output import LammpsMDOutput, LammpsStaticOutput

if TYPE_CHECKING:
from ase import Atoms
Expand Down Expand Up @@ -113,7 +113,7 @@ def calc_static_with_lammps(
structure,
potential_dataframe,
lmp=None,
quantities=("energy", "forces", "stress"),
quantities=LammpsStaticOutput.fields(),
**kwargs,
):
template_str = LAMMPS_THERMO_STYLE + "\n" + LAMMPS_THERMO + "\n" + LAMMPS_RUN
Expand All @@ -127,12 +127,7 @@ def calc_static_with_lammps(
lmp=lmp,
**kwargs,
)
interactive_getter_dict = {
"forces": lmp_instance.interactive_forces_getter,
"energy": lmp_instance.interactive_energy_pot_getter,
"stress": lmp_instance.interactive_pressures_getter,
}
result_dict = {q: interactive_getter_dict[q]() for q in quantities}
result_dict = LammpsStaticOutput.get(lmp_instance, *quantities)
lammps_shutdown(lmp_instance=lmp_instance, close_instance=lmp is None)
return result_dict

Expand All @@ -149,7 +144,7 @@ def calc_molecular_dynamics_nvt_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=quantities,
quantities=LammpsMDOutput.fields(),
**kwargs,
):
init_str = (
Expand Down Expand Up @@ -206,7 +201,7 @@ def calc_molecular_dynamics_npt_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=quantities,
quantities=LammpsMDOutput.fields(),
**kwargs,
):
init_str = (
Expand Down Expand Up @@ -264,7 +259,7 @@ def calc_molecular_dynamics_nph_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=quantities,
quantities=LammpsMDOutput.fields(),
**kwargs,
):
init_str = (
Expand Down
33 changes: 4 additions & 29 deletions atomistics/calculators/lammps/helpers.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,11 @@
from __future__ import annotations

import dataclasses

from jinja2 import Template
import numpy as np
from pylammpsmpi import LammpsASELibrary

from atomistics.calculators.lammps.potential import validate_potential_dataframe


@dataclasses.dataclass
class LammpsQuantityGetter:
positions: callable = LammpsASELibrary.interactive_positions_getter
cell: callable = LammpsASELibrary.interactive_cells_getter
forces: callable = LammpsASELibrary.interactive_forces_getter
temperature: callable = LammpsASELibrary.interactive_temperatures_getter
energy_pot: callable = LammpsASELibrary.interactive_energy_pot_getter
energy_tot: callable = LammpsASELibrary.interactive_energy_tot_getter
pressure: callable = LammpsASELibrary.interactive_pressures_getter
velocities: callable = LammpsASELibrary.interactive_velocities_getter

@classmethod
def fields(cls):
return tuple(field.name for field in dataclasses.fields(cls))

def __call__(self, engine: LammpsASELibrary, quantity: str):
return getattr(self, quantity)(engine)


quantity_getter = LammpsQuantityGetter()
quantities = quantity_getter.fields()
from atomistics.calculators.lammps.output import LammpsMDOutput


def lammps_run(structure, potential_dataframe, input_template=None, lmp=None, **kwargs):
Expand Down Expand Up @@ -65,20 +41,19 @@ def lammps_calc_md_step(
lmp_instance,
run_str,
run,
quantities=quantities,
quantities=LammpsMDOutput.fields(),
):
run_str_rendered = Template(run_str).render(run=run)
lmp_instance.interactive_lib_command(run_str_rendered)
# return {q: getattr(LammpsQuantityGetter, q)(lmp_instance) for q in quantities}
return {q: quantity_getter(lmp_instance, q) for q in quantities}
return LammpsMDOutput.get(lmp_instance, *quantities)


def lammps_calc_md(
lmp_instance,
run_str,
run,
thermo,
quantities=quantities,
quantities=LammpsMDOutput.fields(),
):
results_lst = [
lammps_calc_md_step(
Expand Down
33 changes: 33 additions & 0 deletions atomistics/calculators/lammps/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import dataclasses

from pylammpsmpi import LammpsASELibrary


@dataclasses.dataclass
class LammpsOutput:
@classmethod
def fields(cls):
return tuple(field.name for field in dataclasses.fields(cls))

@classmethod
def get(cls, engine: LammpsASELibrary, *quantities: str) -> dict:
return {q: getattr(cls, q)(engine) for q in quantities}


@dataclasses.dataclass
class LammpsMDOutput(LammpsOutput):
positions: callable = LammpsASELibrary.interactive_positions_getter
cell: callable = LammpsASELibrary.interactive_cells_getter
forces: callable = LammpsASELibrary.interactive_forces_getter
temperature: callable = LammpsASELibrary.interactive_temperatures_getter
energy_pot: callable = LammpsASELibrary.interactive_energy_pot_getter
energy_tot: callable = LammpsASELibrary.interactive_energy_tot_getter
pressure: callable = LammpsASELibrary.interactive_pressures_getter
velocities: callable = LammpsASELibrary.interactive_velocities_getter


@dataclasses.dataclass
class LammpsStaticOutput(LammpsOutput):
forces: callable = LammpsASELibrary.interactive_forces_getter
energy: callable = LammpsASELibrary.interactive_energy_pot_getter
stress: callable = LammpsASELibrary.interactive_pressures_getter