Skip to content

Commit

Permalink
Merge pull request #123 from pyiron/output_modul
Browse files Browse the repository at this point in the history
Separate module for output classes
  • Loading branch information
jan-janssen authored Dec 12, 2023
2 parents 1d201a9 + c74980b commit 1999488
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 40 deletions.
17 changes: 6 additions & 11 deletions atomistics/calculators/lammps/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
LAMMPS_RUN,
LAMMPS_MINIMIZE_VOLUME,
)
from atomistics.calculators.lammps.helpers import quantities
from atomistics.calculators.lammps.output import LammpsMDOutput, LammpsStaticOutput

if TYPE_CHECKING:
from ase import Atoms
Expand Down Expand Up @@ -113,7 +113,7 @@ def calc_static_with_lammps(
structure,
potential_dataframe,
lmp=None,
quantities=("energy", "forces", "stress"),
quantities=LammpsStaticOutput.fields(),
**kwargs,
):
template_str = LAMMPS_THERMO_STYLE + "\n" + LAMMPS_THERMO + "\n" + LAMMPS_RUN
Expand All @@ -127,12 +127,7 @@ def calc_static_with_lammps(
lmp=lmp,
**kwargs,
)
interactive_getter_dict = {
"forces": lmp_instance.interactive_forces_getter,
"energy": lmp_instance.interactive_energy_pot_getter,
"stress": lmp_instance.interactive_pressures_getter,
}
result_dict = {q: interactive_getter_dict[q]() for q in quantities}
result_dict = LammpsStaticOutput.get(lmp_instance, *quantities)
lammps_shutdown(lmp_instance=lmp_instance, close_instance=lmp is None)
return result_dict

Expand All @@ -149,7 +144,7 @@ def calc_molecular_dynamics_nvt_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=quantities,
quantities=LammpsMDOutput.fields(),
**kwargs,
):
init_str = (
Expand Down Expand Up @@ -206,7 +201,7 @@ def calc_molecular_dynamics_npt_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=quantities,
quantities=LammpsMDOutput.fields(),
**kwargs,
):
init_str = (
Expand Down Expand Up @@ -264,7 +259,7 @@ def calc_molecular_dynamics_nph_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=quantities,
quantities=LammpsMDOutput.fields(),
**kwargs,
):
init_str = (
Expand Down
33 changes: 4 additions & 29 deletions atomistics/calculators/lammps/helpers.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,11 @@
from __future__ import annotations

import dataclasses

from jinja2 import Template
import numpy as np
from pylammpsmpi import LammpsASELibrary

from atomistics.calculators.lammps.potential import validate_potential_dataframe


@dataclasses.dataclass
class LammpsQuantityGetter:
positions: callable = LammpsASELibrary.interactive_positions_getter
cell: callable = LammpsASELibrary.interactive_cells_getter
forces: callable = LammpsASELibrary.interactive_forces_getter
temperature: callable = LammpsASELibrary.interactive_temperatures_getter
energy_pot: callable = LammpsASELibrary.interactive_energy_pot_getter
energy_tot: callable = LammpsASELibrary.interactive_energy_tot_getter
pressure: callable = LammpsASELibrary.interactive_pressures_getter
velocities: callable = LammpsASELibrary.interactive_velocities_getter

@classmethod
def fields(cls):
return tuple(field.name for field in dataclasses.fields(cls))

def __call__(self, engine: LammpsASELibrary, quantity: str):
return getattr(self, quantity)(engine)


quantity_getter = LammpsQuantityGetter()
quantities = quantity_getter.fields()
from atomistics.calculators.lammps.output import LammpsMDOutput


def lammps_run(structure, potential_dataframe, input_template=None, lmp=None, **kwargs):
Expand Down Expand Up @@ -65,20 +41,19 @@ def lammps_calc_md_step(
lmp_instance,
run_str,
run,
quantities=quantities,
quantities=LammpsMDOutput.fields(),
):
run_str_rendered = Template(run_str).render(run=run)
lmp_instance.interactive_lib_command(run_str_rendered)
# return {q: getattr(LammpsQuantityGetter, q)(lmp_instance) for q in quantities}
return {q: quantity_getter(lmp_instance, q) for q in quantities}
return LammpsMDOutput.get(lmp_instance, *quantities)


def lammps_calc_md(
lmp_instance,
run_str,
run,
thermo,
quantities=quantities,
quantities=LammpsMDOutput.fields(),
):
results_lst = [
lammps_calc_md_step(
Expand Down
33 changes: 33 additions & 0 deletions atomistics/calculators/lammps/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import dataclasses

from pylammpsmpi import LammpsASELibrary


@dataclasses.dataclass
class LammpsOutput:
@classmethod
def fields(cls):
return tuple(field.name for field in dataclasses.fields(cls))

@classmethod
def get(cls, engine: LammpsASELibrary, *quantities: str) -> dict:
return {q: getattr(cls, q)(engine) for q in quantities}


@dataclasses.dataclass
class LammpsMDOutput(LammpsOutput):
positions: callable = LammpsASELibrary.interactive_positions_getter
cell: callable = LammpsASELibrary.interactive_cells_getter
forces: callable = LammpsASELibrary.interactive_forces_getter
temperature: callable = LammpsASELibrary.interactive_temperatures_getter
energy_pot: callable = LammpsASELibrary.interactive_energy_pot_getter
energy_tot: callable = LammpsASELibrary.interactive_energy_tot_getter
pressure: callable = LammpsASELibrary.interactive_pressures_getter
velocities: callable = LammpsASELibrary.interactive_velocities_getter


@dataclasses.dataclass
class LammpsStaticOutput(LammpsOutput):
forces: callable = LammpsASELibrary.interactive_forces_getter
energy: callable = LammpsASELibrary.interactive_energy_pot_getter
stress: callable = LammpsASELibrary.interactive_pressures_getter

0 comments on commit 1999488

Please sign in to comment.