From 91f12de549fae6f53ef15bec0350465c6e3e8c0b Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel) YANG" Date: Tue, 22 Oct 2024 03:05:18 +0800 Subject: [PATCH] Make `Incar` keys case insensitive, fix init `Incar` from dict val processing for str/float/int (#4122) * move ENCUT from int to float list * remove str case sensitivity in check_params * fix return type annotation * fix unit test failure * use original key for warning msg * add duplicate check in check_params * add test first, issue not fixed yet * init from dict also use setter method and fix val filter logic * relocate test_from_file_and_from_dict * Revert "init from dict also use setter method and fix val filter logic" This reverts commit adcdba750ba5662c8ca133785441fa3617171bf8. * remove seemingly unused monkeypatch * tweak type * make module level var all cap * casting to list doesn't seem necessary, remain iterator for lazy eval * add docstring to clarify parse list * remove duplicate check * reduce indentation level * remove docstring of warn that doesn't exist * fix typo in incar tag ECUT -> ENCUT * inherit from UserDict, and make more ops case insensitive * also override del and in methods * issue warning for duplicate keys * enhance warning check * tweak docstring * fix type of float/int casting * enhance test for from_dict consistency check * relocate duplicate check to setter so that both from str and dict would be checked * fix index error for vasprun * move duplicate warning to init otherwise get false pos when update * enhance unit test from type cast from dict * remove unnecessary get default * remove unnecessary type cast in check_params * tweak Incar docstring --------- Co-authored-by: Shyue Ping Ong --- src/pymatgen/io/vasp/inputs.py | 108 ++++++++++++++-------- src/pymatgen/io/vasp/outputs.py | 17 ++-- src/pymatgen/util/io_utils.py | 36 ++++---- tests/io/vasp/test_inputs.py | 158 ++++++++++++++++++++++++-------- 4 files changed, 220 insertions(+), 99 deletions(-) diff --git 
a/src/pymatgen/io/vasp/inputs.py b/src/pymatgen/io/vasp/inputs.py index 80ff44b8a8f..40c6f628063 100644 --- a/src/pymatgen/io/vasp/inputs.py +++ b/src/pymatgen/io/vasp/inputs.py @@ -14,6 +14,7 @@ import re import subprocess import warnings +from collections import Counter, UserDict from enum import Enum, unique from glob import glob from hashlib import sha256 @@ -709,43 +710,67 @@ class BadPoscarWarning(UserWarning): """Warning class for bad POSCAR entries.""" -class Incar(dict, MSONable): +class Incar(UserDict, MSONable): """ - Read and write INCAR files. - Essentially a dictionary with some helper functions. + A case-insensitive dictionary to read/write INCAR files with additional helper functions. + + - Keys are stored in uppercase to allow case-insensitive access (set, get, del, update, setdefault). + - String values are capitalized by default, except for keys specified + in the `lower_str_keys` of the `proc_val` method. """ def __init__(self, params: dict[str, Any] | None = None) -> None: """ - Create an Incar object. + Clean up params and create an Incar object. Args: - params (dict): Input parameters as a dictionary. + params (dict): INCAR parameters as a dictionary. + + Warnings: + BadIncarWarning: If there are duplicate in keys (case insensitive). 
""" - super().__init__() - if params is not None: - # If INCAR contains vector-like MAGMOMS given as a list - # of floats, convert to a list of lists - if (params.get("MAGMOM") and isinstance(params["MAGMOM"][0], int | float)) and ( - params.get("LSORBIT") or params.get("LNONCOLLINEAR") - ): - val = [] - for idx in range(len(params["MAGMOM"]) // 3): - val.append(params["MAGMOM"][idx * 3 : (idx + 1) * 3]) - params["MAGMOM"] = val - - self.update(params) + params = params or {} + + # Check for case-insensitive duplicate keys + key_counter = Counter(key.strip().upper() for key in params) + if duplicates := [key for key, count in key_counter.items() if count > 1]: + warnings.warn(f"Duplicate keys found (case-insensitive): {duplicates}", BadIncarWarning, stacklevel=2) + + # If INCAR contains vector-like MAGMOMS given as a list + # of floats, convert to a list of lists + if (params.get("MAGMOM") and isinstance(params["MAGMOM"][0], int | float)) and ( + params.get("LSORBIT") or params.get("LNONCOLLINEAR") + ): + val: list[list] = [] + for idx in range(len(params["MAGMOM"]) // 3): + val.append(params["MAGMOM"][idx * 3 : (idx + 1) * 3]) + params["MAGMOM"] = val + + super().__init__(params) def __setitem__(self, key: str, val: Any) -> None: """ - Add parameter-val pair to Incar. Warn if parameter is not in list of - valid INCAR tags. Also clean the parameter and val by stripping - leading and trailing white spaces. + Add parameter-val pair to Incar. + - Clean the parameter and val by stripping leading + and trailing white spaces. + - Cast keys to upper case. 
""" - super().__setitem__( - key.strip().upper(), - type(self).proc_val(key.strip(), val.strip()) if isinstance(val, str) else val, - ) + key = key.strip().upper() + # Cast float/int to str such that proc_val would clean up their types + val = self.proc_val(key, str(val)) if isinstance(val, str | float | int) else val + super().__setitem__(key, val) + + def __getitem__(self, key: str) -> Any: + """ + Get value using a case-insensitive key. + """ + return super().__getitem__(key.strip().upper()) + + def __delitem__(self, key: str) -> None: + super().__delitem__(key.strip().upper()) + + def __contains__(self, key: str) -> bool: + return super().__contains__(key.upper().strip()) def __str__(self) -> str: return self.get_str(sort_keys=True, pretty=False) @@ -762,6 +787,12 @@ def __add__(self, other: Self) -> Self: params[key] = val return type(self)(params) + def get(self, key: str, default: Any = None) -> Any: + """ + Get a value for a case-insensitive key, return default if not found. + """ + return super().get(key.strip().upper(), default) + def as_dict(self) -> dict: """MSONable dict.""" dct = dict(self) @@ -854,24 +885,23 @@ def from_str(cls, string: str) -> Self: Returns: Incar object """ - lines: list[str] = list(clean_lines(string.splitlines())) params: dict[str, Any] = {} - for line in lines: + for line in clean_lines(string.splitlines()): for sline in line.split(";"): if match := re.match(r"(\w+)\s*=\s*(.*)", sline.strip()): key: str = match[1].strip() - val: Any = match[2].strip() + val: str = match[2].strip() params[key] = cls.proc_val(key, val) return cls(params) @staticmethod - def proc_val(key: str, val: Any) -> list | bool | float | int | str: + def proc_val(key: str, val: str) -> list | bool | float | int | str: """Helper method to convert INCAR parameters to proper types like ints, floats, lists, etc. Args: - key (str): INCAR parameter key - val (Any): Value of INCAR parameter. + key (str): INCAR parameter key. + val (str): Value of INCAR parameter. 
""" list_keys = ( "LDAUU", @@ -906,6 +936,7 @@ def proc_val(key: str, val: Any) -> list | bool | float | int | str: "AGGAC", "PARAM1", "PARAM2", + "ENCUT", ) int_keys = ( "NSW", @@ -921,7 +952,6 @@ def proc_val(key: str, val: Any) -> list | bool | float | int | str: "NPAR", "LDAUPRINT", "LMAXMIX", - "ENCUT", "NSIM", "NKRED", "NUPDOWN", @@ -931,7 +961,7 @@ def proc_val(key: str, val: Any) -> list | bool | float | int | str: ) lower_str_keys = ("ML_MODE",) - def smart_int_or_float(num_str: str) -> str | float: + def smart_int_or_float(num_str: str) -> float: """Determine whether a string represents an integer or a float.""" if "." in num_str or "e" in num_str.lower(): return float(num_str) @@ -1032,7 +1062,7 @@ def check_params(self) -> None: warnings.warn(f"Cannot find {tag} in the list of INCAR tags", BadIncarWarning, stacklevel=2) continue - # Check value and its type + # Check value type param_type: str = incar_params[tag].get("type") allowed_values: list[Any] = incar_params[tag].get("values") @@ -1041,8 +1071,13 @@ def check_params(self) -> None: # Only check value when it's not None, # meaning there is recording for corresponding value - if allowed_values is not None and val not in allowed_values: - warnings.warn(f"{tag}: Cannot find {val} in the list of values", BadIncarWarning, stacklevel=2) + if allowed_values is not None: + # Note: param_type could be a Union type, e.g. 
"str | bool" + if "str" in param_type: + allowed_values = [item.capitalize() if isinstance(item, str) else item for item in allowed_values] + + if val not in allowed_values: + warnings.warn(f"{tag}: Cannot find {val} in the list of values", BadIncarWarning, stacklevel=2) class BadIncarWarning(UserWarning): @@ -1712,6 +1747,7 @@ def _parse_int(string: str) -> int: def _parse_list(string: str) -> list[float]: + """Parse a list of floats from a string.""" return [float(y) for y in re.split(r"\s+", string.strip()) if not y.isalpha()] diff --git a/src/pymatgen/io/vasp/outputs.py b/src/pymatgen/io/vasp/outputs.py index 3b208a762c5..46c80210c5e 100644 --- a/src/pymatgen/io/vasp/outputs.py +++ b/src/pymatgen/io/vasp/outputs.py @@ -142,9 +142,9 @@ def _parse_from_incar(filename: PathLike, key: str) -> Any: dirname = os.path.dirname(filename) for fn in os.listdir(dirname): if re.search("INCAR", fn): - warnings.warn(f"INCAR found. Using {key} from INCAR.") + warnings.warn(f"INCAR found. Using {key} from INCAR.", stacklevel=2) incar = Incar.from_file(os.path.join(dirname, fn)) - return incar.get(key, None) + return incar.get(key) return None @@ -347,7 +347,7 @@ def __init__( self.update_potcar_spec(parse_potcar_file) self.update_charge_from_potcar(parse_potcar_file) - if self.incar.get("ALGO") not in {"CHI", "BSE"} and not self.converged and self.parameters.get("IBRION") != 0: + if self.incar.get("ALGO") not in {"Chi", "Bse"} and not self.converged and self.parameters.get("IBRION") != 0: msg = f"{filename} is an unconverged VASP run.\n" msg += f"Electronic convergence reached: {self.converged_electronic}.\n" msg += f"Ionic convergence reached: {self.converged_ionic}." 
@@ -366,10 +366,10 @@ def _parse( self.projected_magnetisation: NDArray | None = None self.dielectric_data: dict[str, tuple] = {} self.other_dielectric: dict[str, tuple] = {} - self.incar: dict[str, Any] = {} + self.incar: Incar = {} self.kpoints_opt_props: KpointOptProps | None = None - ionic_steps: list = [] + ionic_steps: list[dict] = [] md_data: list[dict] = [] parsed_header: bool = False @@ -1357,9 +1357,9 @@ def as_dict(self) -> dict: dct["output"] = vout return jsanitize(dct, strict=True) - def _parse_params(self, elem: XML_Element) -> dict: - """Parse INCAR parameters.""" - params: dict = {} + def _parse_params(self, elem: XML_Element) -> Incar[str, Any]: + """Parse INCAR parameters and more.""" + params: dict[str, Any] = {} for c in elem: # VASP 6.4.3 can add trailing whitespace # for example, PE @@ -1371,6 +1371,7 @@ def _parse_params(self, elem: XML_Element) -> dict: # which overrides the values in the root params. p = {k: v for k, v in p.items() if k not in params} params |= p + else: ptype = c.attrib.get("type", "") val = c.text.strip() if c.text else "" diff --git a/src/pymatgen/util/io_utils.py b/src/pymatgen/util/io_utils.py index b40a6edb50d..f8c7d268f43 100644 --- a/src/pymatgen/util/io_utils.py +++ b/src/pymatgen/util/io_utils.py @@ -9,7 +9,7 @@ from monty.io import zopen if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Iterator __author__ = "Shyue Ping Ong, Rickard Armiento, Anubhav Jain, G Matteo, Ioannis Petousis" __copyright__ = "Copyright 2011, The Materials Project" @@ -21,31 +21,31 @@ def clean_lines( - string_list, - remove_empty_lines=True, - rstrip_only=False, -) -> Generator[str, None, None]: + string_list: list[str], + remove_empty_lines: bool = True, + rstrip_only: bool = False, +) -> Iterator[str]: """Strips whitespace, carriage returns and empty lines from a list of strings. 
Args: - string_list: List of strings - remove_empty_lines: Set to True to skip lines which are empty after + string_list (list[str]): List of strings. + remove_empty_lines (bool): Set to True to skip lines which are empty after stripping. - rstrip_only: Set to True to strip trailing whitespaces only (i.e., + rstrip_only (bool): Set to True to strip trailing whitespaces only (i.e., to retain leading whitespaces). Defaults to False. Yields: - list: clean strings with no whitespaces. If rstrip_only == True, - clean strings with no trailing whitespaces. + str: clean strings with no whitespaces. """ - for s in string_list: - clean_s = s - if "#" in s: - ind = s.index("#") - clean_s = s[:ind] - clean_s = clean_s.rstrip() if rstrip_only else clean_s.strip() - if (not remove_empty_lines) or clean_s != "": - yield clean_s + for string in string_list: + clean_string = string + if "#" in string: + clean_string = string[: string.index("#")] + + clean_string = clean_string.rstrip() if rstrip_only else clean_string.strip() + + if (not remove_empty_lines) or clean_string != "": + yield clean_string def micro_pyawk(filename, search, results=None, debug=None, postdebug=None): diff --git a/tests/io/vasp/test_inputs.py b/tests/io/vasp/test_inputs.py index b510068ef2d..130fc90ea51 100644 --- a/tests/io/vasp/test_inputs.py +++ b/tests/io/vasp/test_inputs.py @@ -14,6 +14,7 @@ import scipy.constants as const from monty.io import zopen from monty.serialization import loadfn +from monty.tempfile import ScratchDir from numpy.testing import assert_allclose from pytest import approx @@ -39,25 +40,25 @@ from pymatgen.util.testing import FAKE_POTCAR_DIR, TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest # make sure _gen_potcar_summary_stats runs and works with all tests in this file -_summ_stats = _gen_potcar_summary_stats(append=False, vasp_psp_dir=str(FAKE_POTCAR_DIR), summary_stats_filename=None) +_SUMM_STATS = _gen_potcar_summary_stats(append=False, vasp_psp_dir=str(FAKE_POTCAR_DIR), 
summary_stats_filename=None) @pytest.fixture(autouse=True) def _mock_complete_potcar_summary_stats(monkeypatch: pytest.MonkeyPatch) -> None: # Override POTCAR library to use fake scrambled POTCARs monkeypatch.setitem(SETTINGS, "PMG_VASP_PSP_DIR", str(FAKE_POTCAR_DIR)) - monkeypatch.setattr(PotcarSingle, "_potcar_summary_stats", _summ_stats) + monkeypatch.setattr(PotcarSingle, "_potcar_summary_stats", _SUMM_STATS) # The fake POTCAR library is pretty big even with just a few sub-libraries # just copying over entries to work with PotcarSingle.is_valid for func in PotcarSingle.functional_dir: - if func in _summ_stats: + if func in _SUMM_STATS: continue if "pbe" in func.lower() or "pw91" in func.lower(): # Generate POTCAR hashes on the fly - _summ_stats[func] = _summ_stats["PBE_54_W_HASH"].copy() + _SUMM_STATS[func] = _SUMM_STATS["PBE_54_W_HASH"].copy() elif "lda" in func.lower() or "perdew_zunger81" in func.lower(): - _summ_stats[func] = _summ_stats["LDA_64"].copy() + _SUMM_STATS[func] = _SUMM_STATS["LDA_64"].copy() class TestPoscar(PymatgenTest): @@ -546,6 +547,64 @@ def test_init(self): assert float(incar["EDIFF"]) == 1e-4, "Wrong EDIFF" assert isinstance(incar["LORBIT"], int) + def test_check_for_duplicate(self): + incar_str: str = """encut = 400 + ENCUT = 500 + """ + with pytest.warns(BadIncarWarning, match=re.escape("Duplicate keys found (case-insensitive): ['ENCUT']")): + Incar.from_str(incar_str) + + incar_dict = {"ALGO": "Fast", "algo": "fast"} + with pytest.warns(BadIncarWarning, match=re.escape("Duplicate keys found (case-insensitive): ['ALGO']")): + Incar.from_dict(incar_dict) + + def test_key_case_insensitive(self): + """Verify that keys are case-insensitive by internally converting + all keys to upper case. This includes operations such as: + - set/get: Keys can be set and retrieved with any case. + - update: Keys in updates are case-insensitive. + - setdefault: Defaults are set and retrieved case-insensitively. 
+ """ + test_tag: str = "ENCUT" + + incar_str: str = f"""ALGO = Fast + {test_tag} = 480 + EDIFF = 1e-07 + """ + + # Test setter and getter + incar: Incar = Incar.from_str(incar_str) + incar[test_tag.lower()] = 490 + assert incar[test_tag.lower()] == 490 + assert incar[test_tag.upper()] == 490 + assert incar.get(test_tag.lower()) == 490 + assert incar.get(test_tag.upper()) == 490 + + incar[test_tag.upper()] = 500 + assert incar[test_tag.lower()] == 500 + + # Test delete + del incar["algo"] + assert "ALGO" not in incar + + # Test membership check + assert test_tag.upper() in incar + assert test_tag.lower() in incar + + # Test update + incar.update({test_tag.lower(): 510}) + assert incar[test_tag] == 510 + + incar.update({test_tag.upper(): 520}) + assert incar[test_tag] == 520 + + # Test setdefault + incar.setdefault("ismear", 0) + assert incar["ISMEAR"] == 0 + + incar.setdefault("NPAR", 4) + assert incar["npar"] == 4 + def test_copy(self): incar2 = self.incar.copy() assert isinstance(incar2, Incar), f"Expected Incar, got {type(incar2).__name__}" @@ -638,6 +697,39 @@ def test_as_dict_and_from_dict(self): incar3 = Incar.from_dict(dct) assert incar3["MAGMOM"] == [Magmom([1, 2, 3])] + def test_from_file_and_from_dict(self): + """ + Init from file (from str) should yield the same results as from dict. + + Previously init Incar from dict would bypass the proc_val method for + float/int, and might yield values in wrong type. 
+ """ + # Init from dict + incar_dict = {"ENCUT": 500, "GGA": "PS", "NELM": 60.0} + incar_from_dict = Incar(incar_dict) + + # Init from file (from string) + incar_str = """\ + ENCUT = 500 + GGA = PS + NELM = 60.0 + """ + + with ScratchDir("."): + with open("INCAR", "w", encoding="utf-8") as f: + f.write(incar_str) + + incar_from_file = Incar.from_file("INCAR") + + # Make sure int/float is cast to correct type when init from dict + assert incar_from_dict["GGA"] == "Ps" + assert isinstance(incar_from_dict["ENCUT"], float) + assert isinstance(incar_from_dict["NELM"], int) + + assert incar_from_dict == incar_from_file + for key in incar_from_dict: + assert type(incar_from_dict[key]) is type(incar_from_file[key]) + def test_write(self): tmp_file = f"{self.tmp_path}/INCAR.testing" self.incar.write_file(tmp_file) @@ -648,7 +740,7 @@ def test_get_str(self): incar_str = self.incar.get_str(pretty=True, sort_keys=True) expected = """ALGO = Damped EDIFF = 0.0001 -ENCUT = 500 +ENCUT = 500.0 ENCUTFOCK = 0.0 HFSCREEN = 0.207 IBRION = 2 @@ -736,7 +828,7 @@ def test_quad_efg(self): def test_types(self): incar_str = """ALGO = Fast -ECUT = 510 +ENCUT = 510 EDIFF = 1e-07 EINT = -0.85 0.85 IBRION = -1 @@ -759,16 +851,6 @@ def test_types(self): assert incar["HFSCREEN"] == 0.2 assert incar["ALGO"] == "All" - def test_upper_keys(self): - incar_str = """ALGO = Fast - ECUT = 510 - EDIFF = 1e-07 - """ - incar = Incar.from_str(incar_str) - incar["ecut"] = 480 - assert "ecut" not in incar - assert incar["ECUT"] == 480 - def test_proc_types(self): assert Incar.proc_val("HELLO", "-0.85 0.85") == "-0.85 0.85" assert Incar.proc_val("ML_MODE", "train") == "train" @@ -777,31 +859,33 @@ def test_proc_types(self): def test_check_params(self): # Triggers warnings when running into invalid parameters + incar = Incar( + { + "ADDGRID": True, + "ALGO": "Normal", + "AMIN": 0.01, + "ICHARG": 1, + "MAGMOM": [1, 2, 4, 5], + "ENCUT": 500, # make sure float key is casted + "GGA": "PS", # test string case 
insensitivity + "LREAL": True, # special case: Union type + "NBAND": 250, # typo in tag + "METAGGA": "SCAM", # typo in value + "EDIFF": 5 + 1j, # value should be a float + "ISIF": 9, # value out of range + "LASPH": 5, # value should be bool + "PHON_TLIST": "is_a_str", # value should be a list + } + ) with pytest.warns(BadIncarWarning) as record: - incar = Incar( - { - "ADDGRID": True, - "ALGO": "Normal", - "AMIN": 0.01, - "ICHARG": 1, - "MAGMOM": [1, 2, 4, 5], - "LREAL": True, # special case: Union type - "NBAND": 250, # typo in tag - "METAGGA": "SCAM", # typo in value - "EDIFF": 5 + 1j, # value should be a float - "ISIF": 9, # value out of range - "LASPH": 5, # value should be bool - "PHON_TLIST": "is_a_str", # value should be a list - } - ) incar.check_params() assert record[0].message.args[0] == "Cannot find NBAND in the list of INCAR tags" - assert record[1].message.args[0] == "METAGGA: Cannot find SCAM in the list of values" + assert record[1].message.args[0] == "METAGGA: Cannot find Scam in the list of values" assert record[2].message.args[0] == "EDIFF: (5+1j) is not a float" assert record[3].message.args[0] == "ISIF: Cannot find 9 in the list of values" assert record[4].message.args[0] == "LASPH: 5 is not a bool" - assert record[5].message.args[0] == "PHON_TLIST: is_a_str is not a list" + assert record[5].message.args[0] == "PHON_TLIST: Is_a_str is not a list" class TestKpointsSupportedModes: @@ -1475,8 +1559,8 @@ def test_potcar_summary_stats() -> None: assert actual == expected, f"{key=}, {expected=}, {actual=}" -def test_gen_potcar_summary_stats(monkeypatch: pytest.MonkeyPatch) -> None: - assert set(_summ_stats) == set(PotcarSingle.functional_dir) +def test_gen_potcar_summary_stats() -> None: + assert set(_SUMM_STATS) == set(PotcarSingle.functional_dir) expected_funcs = [x for x in os.listdir(str(FAKE_POTCAR_DIR)) if x in PotcarSingle.functional_dir]