Skip to content

Commit

Permalink
Make Incar keys case insensitive, fix init Incar from dict val pr…
Browse files Browse the repository at this point in the history
…ocessing for str/float/int (#4122)

* move ENCUT from int to float list

* remove str case sensitivity in check_params

* fix return type annotation

* fix unit test failure

* use original key for warning msg

* add duplicate check in check_params

* add test first, issue not fixed yet

* init from dict also use setter method and fix val filter logic

* relocate test_from_file_and_from_dict

* Revert "init from dict also use setter method and fix val filter logic"

This reverts commit adcdba7.

* remove seemingly unused monkeypatch

* tweak type

* make module level var all cap

* casting to list doesn't seem necessary, remain iterator for lazy eval

* add docstring to clarify parse list

* remove duplicate check

* reduce indentation level

* remove docstring of warn that doesn't exist

* fix typo in incar tag ECUT -> ENCUT

* inherit from UserDict, and make more ops case insensitive

* also override del and in methdos

* issue warning for duplicate keys

* enhance warning check

* tweak docstring

* fix type of float/int casting

* enhance test for from_dict consistency check

* relocate duplicate check to setter so that both from str and dict would be checked

* fix index error for vasprun

* move duplicate warning to init otherwise get false pos when update

* enhance unit test from type cast from dict

* remove unnecessary get default

* remove unnecessary type cast in check_params

* tweak Incar docstring

---------

Co-authored-by: Shyue Ping Ong <[email protected]>
  • Loading branch information
DanielYang59 and shyuep authored Oct 21, 2024
1 parent 7b02bf3 commit 91f12de
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 99 deletions.
108 changes: 72 additions & 36 deletions src/pymatgen/io/vasp/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import re
import subprocess
import warnings
from collections import Counter, UserDict
from enum import Enum, unique
from glob import glob
from hashlib import sha256
Expand Down Expand Up @@ -709,43 +710,67 @@ class BadPoscarWarning(UserWarning):
"""Warning class for bad POSCAR entries."""


class Incar(dict, MSONable):
class Incar(UserDict, MSONable):
"""
Read and write INCAR files.
Essentially a dictionary with some helper functions.
A case-insensitive dictionary to read/write INCAR files with additional helper functions.
- Keys are stored in uppercase to allow case-insensitive access (set, get, del, update, setdefault).
- String values are capitalized by default, except for keys specified
in the `lower_str_keys` of the `proc_val` method.
"""

def __init__(self, params: dict[str, Any] | None = None) -> None:
"""
Create an Incar object.
Clean up params and create an Incar object.
Args:
params (dict): Input parameters as a dictionary.
params (dict): INCAR parameters as a dictionary.
Warnings:
BadIncarWarning: If there are duplicate in keys (case insensitive).
"""
super().__init__()
if params is not None:
# If INCAR contains vector-like MAGMOMS given as a list
# of floats, convert to a list of lists
if (params.get("MAGMOM") and isinstance(params["MAGMOM"][0], int | float)) and (
params.get("LSORBIT") or params.get("LNONCOLLINEAR")
):
val = []
for idx in range(len(params["MAGMOM"]) // 3):
val.append(params["MAGMOM"][idx * 3 : (idx + 1) * 3])
params["MAGMOM"] = val

self.update(params)
params = params or {}

# Check for case-insensitive duplicate keys
key_counter = Counter(key.strip().upper() for key in params)
if duplicates := [key for key, count in key_counter.items() if count > 1]:
warnings.warn(f"Duplicate keys found (case-insensitive): {duplicates}", BadIncarWarning, stacklevel=2)

# If INCAR contains vector-like MAGMOMS given as a list
# of floats, convert to a list of lists
if (params.get("MAGMOM") and isinstance(params["MAGMOM"][0], int | float)) and (
params.get("LSORBIT") or params.get("LNONCOLLINEAR")
):
val: list[list] = []
for idx in range(len(params["MAGMOM"]) // 3):
val.append(params["MAGMOM"][idx * 3 : (idx + 1) * 3])
params["MAGMOM"] = val

super().__init__(params)

def __setitem__(self, key: str, val: Any) -> None:
"""
Add parameter-val pair to Incar. Warn if parameter is not in list of
valid INCAR tags. Also clean the parameter and val by stripping
leading and trailing white spaces.
Add parameter-val pair to Incar.
- Clean the parameter and val by stripping leading
and trailing white spaces.
- Cast keys to upper case.
"""
super().__setitem__(
key.strip().upper(),
type(self).proc_val(key.strip(), val.strip()) if isinstance(val, str) else val,
)
key = key.strip().upper()
# Cast float/int to str such that proc_val would clean up their types
val = self.proc_val(key, str(val)) if isinstance(val, str | float | int) else val
super().__setitem__(key, val)

def __getitem__(self, key: str) -> Any:
"""
Get value using a case-insensitive key.
"""
return super().__getitem__(key.strip().upper())

def __delitem__(self, key: str) -> None:
super().__delitem__(key.strip().upper())

def __contains__(self, key: str) -> bool:
return super().__contains__(key.upper().strip())

def __str__(self) -> str:
return self.get_str(sort_keys=True, pretty=False)
Expand All @@ -762,6 +787,12 @@ def __add__(self, other: Self) -> Self:
params[key] = val
return type(self)(params)

def get(self, key: str, default: Any = None) -> Any:
"""
Get a value for a case-insensitive key, return default if not found.
"""
return super().get(key.strip().upper(), default)

def as_dict(self) -> dict:
"""MSONable dict."""
dct = dict(self)
Expand Down Expand Up @@ -854,24 +885,23 @@ def from_str(cls, string: str) -> Self:
Returns:
Incar object
"""
lines: list[str] = list(clean_lines(string.splitlines()))
params: dict[str, Any] = {}
for line in lines:
for line in clean_lines(string.splitlines()):
for sline in line.split(";"):
if match := re.match(r"(\w+)\s*=\s*(.*)", sline.strip()):
key: str = match[1].strip()
val: Any = match[2].strip()
val: str = match[2].strip()
params[key] = cls.proc_val(key, val)
return cls(params)

@staticmethod
def proc_val(key: str, val: Any) -> list | bool | float | int | str:
def proc_val(key: str, val: str) -> list | bool | float | int | str:
"""Helper method to convert INCAR parameters to proper types
like ints, floats, lists, etc.
Args:
key (str): INCAR parameter key
val (Any): Value of INCAR parameter.
key (str): INCAR parameter key.
val (str): Value of INCAR parameter.
"""
list_keys = (
"LDAUU",
Expand Down Expand Up @@ -906,6 +936,7 @@ def proc_val(key: str, val: Any) -> list | bool | float | int | str:
"AGGAC",
"PARAM1",
"PARAM2",
"ENCUT",
)
int_keys = (
"NSW",
Expand All @@ -921,7 +952,6 @@ def proc_val(key: str, val: Any) -> list | bool | float | int | str:
"NPAR",
"LDAUPRINT",
"LMAXMIX",
"ENCUT",
"NSIM",
"NKRED",
"NUPDOWN",
Expand All @@ -931,7 +961,7 @@ def proc_val(key: str, val: Any) -> list | bool | float | int | str:
)
lower_str_keys = ("ML_MODE",)

def smart_int_or_float(num_str: str) -> str | float:
def smart_int_or_float(num_str: str) -> float:
"""Determine whether a string represents an integer or a float."""
if "." in num_str or "e" in num_str.lower():
return float(num_str)
Expand Down Expand Up @@ -1032,7 +1062,7 @@ def check_params(self) -> None:
warnings.warn(f"Cannot find {tag} in the list of INCAR tags", BadIncarWarning, stacklevel=2)
continue

# Check value and its type
# Check value type
param_type: str = incar_params[tag].get("type")
allowed_values: list[Any] = incar_params[tag].get("values")

Expand All @@ -1041,8 +1071,13 @@ def check_params(self) -> None:

# Only check value when it's not None,
# meaning there is recording for corresponding value
if allowed_values is not None and val not in allowed_values:
warnings.warn(f"{tag}: Cannot find {val} in the list of values", BadIncarWarning, stacklevel=2)
if allowed_values is not None:
# Note: param_type could be a Union type, e.g. "str | bool"
if "str" in param_type:
allowed_values = [item.capitalize() if isinstance(item, str) else item for item in allowed_values]

if val not in allowed_values:
warnings.warn(f"{tag}: Cannot find {val} in the list of values", BadIncarWarning, stacklevel=2)


class BadIncarWarning(UserWarning):
Expand Down Expand Up @@ -1712,6 +1747,7 @@ def _parse_int(string: str) -> int:


def _parse_list(string: str) -> list[float]:
"""Parse a list of floats from a string."""
return [float(y) for y in re.split(r"\s+", string.strip()) if not y.isalpha()]


Expand Down
17 changes: 9 additions & 8 deletions src/pymatgen/io/vasp/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ def _parse_from_incar(filename: PathLike, key: str) -> Any:
dirname = os.path.dirname(filename)
for fn in os.listdir(dirname):
if re.search("INCAR", fn):
warnings.warn(f"INCAR found. Using {key} from INCAR.")
warnings.warn(f"INCAR found. Using {key} from INCAR.", stacklevel=2)
incar = Incar.from_file(os.path.join(dirname, fn))
return incar.get(key, None)
return incar.get(key)
return None


Expand Down Expand Up @@ -347,7 +347,7 @@ def __init__(
self.update_potcar_spec(parse_potcar_file)
self.update_charge_from_potcar(parse_potcar_file)

if self.incar.get("ALGO") not in {"CHI", "BSE"} and not self.converged and self.parameters.get("IBRION") != 0:
if self.incar.get("ALGO") not in {"Chi", "Bse"} and not self.converged and self.parameters.get("IBRION") != 0:
msg = f"{filename} is an unconverged VASP run.\n"
msg += f"Electronic convergence reached: {self.converged_electronic}.\n"
msg += f"Ionic convergence reached: {self.converged_ionic}."
Expand All @@ -366,10 +366,10 @@ def _parse(
self.projected_magnetisation: NDArray | None = None
self.dielectric_data: dict[str, tuple] = {}
self.other_dielectric: dict[str, tuple] = {}
self.incar: dict[str, Any] = {}
self.incar: Incar = {}
self.kpoints_opt_props: KpointOptProps | None = None

ionic_steps: list = []
ionic_steps: list[dict] = []

md_data: list[dict] = []
parsed_header: bool = False
Expand Down Expand Up @@ -1357,9 +1357,9 @@ def as_dict(self) -> dict:
dct["output"] = vout
return jsanitize(dct, strict=True)

def _parse_params(self, elem: XML_Element) -> dict:
"""Parse INCAR parameters."""
params: dict = {}
def _parse_params(self, elem: XML_Element) -> Incar[str, Any]:
"""Parse INCAR parameters and more."""
params: dict[str, Any] = {}
for c in elem:
# VASP 6.4.3 can add trailing whitespace
# for example, <i type="string" name="GGA ">PE</i>
Expand All @@ -1371,6 +1371,7 @@ def _parse_params(self, elem: XML_Element) -> dict:
# which overrides the values in the root params.
p = {k: v for k, v in p.items() if k not in params}
params |= p

else:
ptype = c.attrib.get("type", "")
val = c.text.strip() if c.text else ""
Expand Down
36 changes: 18 additions & 18 deletions src/pymatgen/util/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from monty.io import zopen

if TYPE_CHECKING:
from collections.abc import Generator
from collections.abc import Iterator

__author__ = "Shyue Ping Ong, Rickard Armiento, Anubhav Jain, G Matteo, Ioannis Petousis"
__copyright__ = "Copyright 2011, The Materials Project"
Expand All @@ -21,31 +21,31 @@


def clean_lines(
string_list,
remove_empty_lines=True,
rstrip_only=False,
) -> Generator[str, None, None]:
string_list: list[str],
remove_empty_lines: bool = True,
rstrip_only: bool = False,
) -> Iterator[str]:
"""Strips whitespace, carriage returns and empty lines from a list of strings.
Args:
string_list: List of strings
remove_empty_lines: Set to True to skip lines which are empty after
string_list (list[str]): List of strings.
remove_empty_lines (bool): Set to True to skip lines which are empty after
stripping.
rstrip_only: Set to True to strip trailing whitespaces only (i.e.,
rstrip_only (bool): Set to True to strip trailing whitespaces only (i.e.,
to retain leading whitespaces). Defaults to False.
Yields:
list: clean strings with no whitespaces. If rstrip_only == True,
clean strings with no trailing whitespaces.
str: clean strings with no whitespaces.
"""
for s in string_list:
clean_s = s
if "#" in s:
ind = s.index("#")
clean_s = s[:ind]
clean_s = clean_s.rstrip() if rstrip_only else clean_s.strip()
if (not remove_empty_lines) or clean_s != "":
yield clean_s
for string in string_list:
clean_string = string
if "#" in string:
clean_string = string[: string.index("#")]

clean_string = clean_string.rstrip() if rstrip_only else clean_string.strip()

if (not remove_empty_lines) or clean_string != "":
yield clean_string


def micro_pyawk(filename, search, results=None, debug=None, postdebug=None):
Expand Down
Loading

0 comments on commit 91f12de

Please sign in to comment.