diff --git a/emmet-core/emmet/core/mpid.py b/emmet-core/emmet/core/mpid.py index dadc5a72ff..e037f80642 100644 --- a/emmet-core/emmet/core/mpid.py +++ b/emmet-core/emmet/core/mpid.py @@ -79,5 +79,7 @@ def validate(cls, v): return v elif isinstance(v, str) and mpid_regex.fullmatch(v): return MPID(v) + elif isinstance(v, int): + return MPID(v) raise ValueError("Invalid MPID Format") diff --git a/emmet-core/emmet/core/oxidation_states.py b/emmet-core/emmet/core/oxidation_states.py index 06564e712e..54f4f28574 100644 --- a/emmet-core/emmet/core/oxidation_states.py +++ b/emmet-core/emmet/core/oxidation_states.py @@ -3,23 +3,32 @@ from typing import Dict, List import numpy as np -from pydantic import BaseModel +from pydantic import Field from pymatgen.analysis.bond_valence import BVAnalyzer from pymatgen.core import Structure from pymatgen.core.periodic_table import Specie from typing_extensions import Literal +from emmet.core.material_property import PropertyDoc +from emmet.core.mpid import MPID -class OxidationStateDoc(BaseModel): - possible_species: List[str] - possible_valences: List[float] - average_oxidation_states: Dict[str, float] - method: Literal["BVAnalyzer", "oxi_state_guesses"] - structure: Structure +class OxidationStateDoc(PropertyDoc): + """Oxidation states computed from the structure""" + + possible_species: List[str] = Field( + description="Possible charged species in this material" + ) + possible_valences: List[float] = Field( + description="List of valences for each site in this material" + ) + average_oxidation_states: Dict[str, float] = Field( + description="Average oxidation states for each unique species" + ) + method: str = Field(description="Method used to compute oxidation states") @classmethod - def from_structure(cls, structure: Structure): + def from_structure(cls, structure: Structure, material_id: MPID, **kwargs): # type: ignore[override] structure.remove_oxidation_states() try: bva = BVAnalyzer() @@ -49,12 +58,9 @@ def from_structure(cls, structure: Structure): "possible_species": list(possible_species), "possible_valences": valences, "average_oxidation_states": oxi_state_dict, - "method": "BVAnalyzer", - "structure": structure, + "method": "Bond Valence Analysis", } - return cls(**d) - except Exception as e: logging.error("BVAnalyzer failed with: {}".format(e)) @@ -76,12 +82,13 @@ def from_structure(cls, structure: Structure): "possible_species": list(possible_species), "possible_valences": valences, "average_oxidation_states": first_oxi_state_guess, - "method": "oxi_state_guesses", - "structure": structure, + "method": "Oxidation State Guess", } - return cls(**d) - except Exception as e: logging.error("Oxidation state guess failed with: {}".format(e)) raise e + + return super().from_structure( + structure=structure, material_id=material_id, **d, **kwargs + ) diff --git a/tests/emmet-core/test_mpid.py b/tests/emmet-core/test_mpid.py index b43ccd1402..acd6c2f5ee 100644 --- a/tests/emmet-core/test_mpid.py +++ b/tests/emmet-core/test_mpid.py @@ -16,6 +16,8 @@ def test_mpid(): == 3 ) + MPID(3) + def test_to_str(): assert str(MPID("mp-149")) == "mp-149" diff --git a/tests/emmet-core/test_oxidation_states.py b/tests/emmet-core/test_oxidation_states.py index 1db2f2a13d..f3399daf94 100644 --- a/tests/emmet-core/test_oxidation_states.py +++ b/tests/emmet-core/test_oxidation_states.py @@ -31,5 +31,5 @@ def test_oxidation_state(structure: Structure): """Very simple test to make sure this actually works""" print(f"Should work : {structure.composition}") - doc = OxidationStateDoc.from_structure(structure) + doc = OxidationStateDoc.from_structure(structure, material_id=33) assert doc is not None