Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
esoteric-ephemera committed Nov 6, 2024
1 parent da743f7 commit 4da5d94
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 47 deletions.
92 changes: 54 additions & 38 deletions emmet-core/emmet/core/neb.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,16 @@ class NebTaskDoc(BaseModel, extra="allow"):
None,
description="The initial and final configurations (reactants and products) of the barrier.",
)
endpoint_energies : Optional[Sequence[float]] = Field(
None,
description="Energies of the endpoint structures."
endpoint_energies: Optional[Sequence[float]] = Field(
None, description="Energies of the endpoint structures."
)
endpoint_calculations : Optional[list[Calculation]] = Field(
None,
description = "Calculation information for the endpoint structures"
endpoint_calculations: Optional[list[Calculation]] = Field(
None, description="Calculation information for the endpoint structures"
)
endpoint_objects : Optional[list[dict]] = Field(
endpoint_objects: Optional[list[dict]] = Field(
None, description="VASP objects for each endpoint calculation."
)
endpoint_directories : Optional[list[str]] = Field(
endpoint_directories: Optional[list[str]] = Field(
None, description="List of the directories for the endpoint calculations."
)

Expand Down Expand Up @@ -140,25 +138,33 @@ def set_barriers(self) -> Self:
def num_images(self) -> int:
"""Return the number of VASP calculations / number of images performed."""
return len(self.image_directories)

@property
def energies(self) -> list[float]:
"""Return the endpoint (optional) and image energies."""
if self.endpoint_energies is not None:
return [self.endpoint_energies[0], *self.image_energies, self.endpoint_energies[1]]
return [
self.endpoint_energies[0],
*self.image_energies,
self.endpoint_energies[1],
]
return self.image_energies

@property
def structures(self) -> list[Structure]:
"""Return the endpoint and image structures."""
return [self.endpoint_structures[0], *self.image_structures, self.endpoint_structures[1]]

return [
self.endpoint_structures[0],
*self.image_structures,
self.endpoint_structures[1],
]

@classmethod
def from_directory(
cls,
dir_name: Union[Path, str],
volumetric_files: Tuple[str, ...] = _VOLUMETRIC_FILES,
store_calculations : bool = True,
store_calculations: bool = True,
**neb_task_doc_kwargs,
) -> Self:
"""
Expand All @@ -172,7 +178,7 @@ def from_directory(

neb_directories = sorted(dir_name.glob("[0-9][0-9]"))

if (ep_calcs := neb_task_doc_kwargs.pop("endpoint_calculations", None) ) is None:
if (ep_calcs := neb_task_doc_kwargs.pop("endpoint_calculations", None)) is None:
endpoint_directories = [neb_directories[0], neb_directories[-1]]
endpoint_structures = [
Structure.from_file(zpath(f"{endpoint_dir}/POSCAR"))
Expand All @@ -181,12 +187,8 @@ def from_directory(
endpoint_energies = None
else:
endpoint_directories = neb_task_doc_kwargs.pop("endpoint_directories")
endpoint_structures = [
ep_calc.output.structure for ep_calc in ep_calcs
]
endpoint_energies = [
ep_calc.output.energy for ep_calc in ep_calcs
]
endpoint_structures = [ep_calc.output.structure for ep_calc in ep_calcs]
endpoint_energies = [ep_calc.output.energy for ep_calc in ep_calcs]

image_directories = neb_directories[1:-1]

Expand Down Expand Up @@ -216,8 +218,7 @@ def from_directory(
task_state = (
TaskState.SUCCESS
if all(
calc.has_vasp_completed == TaskState.SUCCESS
for calc in calcs_to_check
calc.has_vasp_completed == TaskState.SUCCESS for calc in calcs_to_check
)
else TaskState.FAILED
)
Expand Down Expand Up @@ -247,11 +248,11 @@ def from_directory(

return cls(
endpoint_structures=endpoint_structures,
endpoint_energies = endpoint_energies,
endpoint_directories = [str(ep_dir) for ep_dir in endpoint_directories],
endpoint_calculations = ep_calcs if store_calculations else None,
endpoint_energies=endpoint_energies,
endpoint_directories=[str(ep_dir) for ep_dir in endpoint_directories],
endpoint_calculations=ep_calcs if store_calculations else None,
image_calculations=image_calculations if store_calculations else None,
image_structures = image_structures,
image_structures=image_structures,
dir_name=str(dir_name),
image_directories=[str(img_dir) for img_dir in image_directories],
orig_inputs=inputs["orig_inputs"],
Expand All @@ -271,7 +272,7 @@ def from_directories(
endpoint_directories: list[str | Path],
neb_directory: str | Path,
volumetric_files: Tuple[str, ...] = _VOLUMETRIC_FILES,
**neb_task_doc_kwargs
**neb_task_doc_kwargs,
) -> Self:
"""
Return an NebTaskDoc from endpoint and NEB calculation directories.
Expand All @@ -282,12 +283,26 @@ def from_directories(
endpoint_calculations = [None for _ in range(2)]
endpoint_objects = [None for _ in range(2)]
for idx, endpoint_dir in enumerate(endpoint_directories):
vasp_files = _find_vasp_files(endpoint_dir, volumetric_files=volumetric_files)
ep_key = "standard" if vasp_files.get("standard") else "relax" + str(max(
int(k.split("relax")[-1]) for k in vasp_files if k.startswith("relax")
))
vasp_files = _find_vasp_files(
endpoint_dir, volumetric_files=volumetric_files
)
ep_key = (
"standard"
if vasp_files.get("standard")
else "relax"
+ str(
max(
int(k.split("relax")[-1])
for k in vasp_files
if k.startswith("relax")
)
)
)

endpoint_calculations[idx], endpoint_objects[idx] = Calculation.from_vasp_files(
(
endpoint_calculations[idx],
endpoint_objects[idx],
) = Calculation.from_vasp_files(
dir_name=endpoint_dir,
task_name=f"NEB endpoint {idx + 1}",
vasprun_file=vasp_files[ep_key]["vasprun_file"],
Expand All @@ -299,16 +314,17 @@ def from_directories(
"parse_potcar_file": False,
},
)

return cls.from_directory(
neb_directory,
volumetric_files=volumetric_files,
endpoint_calculations = endpoint_calculations,
endpoint_objects = endpoint_objects,
endpoint_directories = endpoint_directories,
**neb_task_doc_kwargs
endpoint_calculations=endpoint_calculations,
endpoint_objects=endpoint_objects,
endpoint_directories=endpoint_directories,
**neb_task_doc_kwargs,
)



def neb_barrier_spline_fit(
energies: Sequence[float],
spline_kwargs: dict | None = None,
Expand Down
29 changes: 20 additions & 9 deletions emmet-core/tests/test_neb.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def test_neb_doc(test_dir, from_dir: bool):
assert neb_doc.num_images == 3
assert len(neb_doc.image_structures) == neb_doc.num_images
assert len(neb_doc.energies) == neb_doc.num_images
assert len(neb_doc.structures) == neb_doc.num_images + 2 # always includes endpoints
assert (
len(neb_doc.structures) == neb_doc.num_images + 2
) # always includes endpoints
assert isinstance(neb_doc.orig_inputs, OrigInputs)

# test that NEB image calculations are all VASP Calculation objects
Expand Down Expand Up @@ -77,8 +79,8 @@ def test_neb_doc(test_dir, from_dir: bool):
)
assert len(neb_doc.image_energies) == neb_doc.num_images

def test_from_directories(test_dir):

def test_from_directories(test_dir):
with TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
shutil.unpack_archive(test_dir / "neb_sample_calc.zip", tmpdir, "zip")
Expand All @@ -87,23 +89,32 @@ def test_from_directories(test_dir):
tmpdir / "neb",
)

assert all(isinstance(ep_calc,Calculation) for ep_calc in neb_doc.endpoint_calculations)

assert all(
"relax_endpoint_" in ep_dir for ep_dir in neb_doc.endpoint_directories
isinstance(ep_calc, Calculation) for ep_calc in neb_doc.endpoint_calculations
)

assert all("relax_endpoint_" in ep_dir for ep_dir in neb_doc.endpoint_directories)

assert len(neb_doc.energies) == neb_doc.num_images + 2
assert len(neb_doc.structures) == neb_doc.num_images + 2
assert isinstance(neb_doc.barrier_analysis,dict)
assert isinstance(neb_doc.barrier_analysis, dict)

assert all(
neb_doc.barrier_analysis.get(k) is not None
for k in ("energies","frame_index","cubic_spline_pars","ts_frame_index","ts_energy","ts_in_frames","forward_barrier","reverse_barrier")
for k in (
"energies",
"frame_index",
"cubic_spline_pars",
"ts_frame_index",
"ts_energy",
"ts_in_frames",
"forward_barrier",
"reverse_barrier",
)
)

assert all(
getattr(neb_doc,f"{direction}_barrier") == neb_doc.barrier_analysis[f"{direction}_barrier"]
getattr(neb_doc, f"{direction}_barrier")
== neb_doc.barrier_analysis[f"{direction}_barrier"]
for direction in ("forward", "reverse")
)

0 comments on commit 4da5d94

Please sign in to comment.