Skip to content

Commit

Permalink
Add types for `core.periodic_table/bonds/composition/ion/lattice/libxcfunc`, new type `MillerIndex` and fix Lattice hash (#3814)
Browse files Browse the repository at this point in the history

* tweak type and docstring

* move dunder methods to the top

* add more types and tweaks

* relocate more dunder methods to top

* more types and format tweaks

* fix type error

* add types for composition

* help fix #3792 (comment)

* reverse compare order for readability

* Revert "reverse compare order for readability"

This reverts commit 05ea23a.

* Revert "help fix #3792 (comment)"

This reverts commit cae7aed.

* add types for `core.bonds`

* finish `core.ion`

* add some types

* revert changes on core.interface

* add types for `libxfunc`

* remove unnecessary `libxc_version`

* recover header

* finish `Lattice`

* fix unit test

* Revert "fix unit test"

This reverts commit 4d17bcd.

* merge two pbc checks

* improve property pbc setter

* fix type for pbc

* fix circular import

* revert script (unsure if needed)

* add new type `MillerIndex`

* change types to `MillerIndex`

* add a proper hash method

* fix typo in composition check of OH

* revert non-interface changes

* Revert "revert non-interface changes"

This reverts commit 7930238.

* refactor

* tweak doc strings

---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
DanielYang59 and janosh authored May 12, 2024
1 parent 2e1c301 commit 578d29c
Show file tree
Hide file tree
Showing 42 changed files with 624 additions and 600 deletions.
4 changes: 2 additions & 2 deletions dev_scripts/regen_libxcfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,11 @@ def main():
del lines[start + 1 : stop]

# [2] write new py module
with open(xc_funcpy_path, mode="w") as file:
with open(xc_funcpy_path, mode="w", encoding="utf-8") as file:
file.writelines(lines)

print("Files have been regenerated")
print("Remember to update libxc_version in libxcfuncs.py!")
print("Remember to update __version__ in libxcfuncs.py!")

return 0

Expand Down
3 changes: 1 addition & 2 deletions pymatgen/analysis/elasticity/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,8 +914,7 @@ def find_eq_stress(strains, stresses, tol: float = 1e-10):


def get_strain_state_dict(strains, stresses, eq_stress=None, tol: float = 1e-10, add_eq=True, sort=True):
"""
Creates a dictionary of voigt notation stress-strain sets
"""Create a dictionary of voigt notation stress-strain sets
keyed by "strain state", i.e. a tuple corresponding to
the non-zero entries in ratios to the lowest nonzero value,
e.g. [0, 0.1, 0, 0.2, 0, 0] -> (0,1,0,2,0,0)
Expand Down
10 changes: 3 additions & 7 deletions pymatgen/analysis/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,17 +541,15 @@ def __init__(self, matrix, m_list, num_to_return=1, algo=ALGO_FAST):
self._minimized_sum = self._output_lists[0][0]

def minimize_matrix(self):
"""
This method finds and returns the permutations that produce the lowest
"""Get the permutations that produce the lowest
Ewald sum. Calls a recursive function to iterate through permutations.
"""
if self._algo in (EwaldMinimizer.ALGO_FAST, EwaldMinimizer.ALGO_BEST_FIRST):
return self._recurse(self._matrix, self._m_list, set(range(len(self._matrix))))
return None

def add_m_list(self, matrix_sum, m_list):
"""
This adds an m_list to the output_lists and updates the current
"""Add an m_list to the output_lists and update the current
minimum if the list is full.
"""
if self._output_lists is None:
Expand Down Expand Up @@ -629,9 +627,7 @@ def get_next_index(cls, matrix, manipulation, indices_left):
return indices[sums.argmax(axis=0)] if f < 1 else indices[sums.argmin(axis=0)]

def _recurse(self, matrix, m_list, indices, output_m_list=None):
"""
This method recursively finds the minimal permutations using a binary
tree search strategy.
"""Find the minimal permutations using a binary tree search strategy.
Args:
matrix: The current matrix (with some permutations already
Expand Down
3 changes: 1 addition & 2 deletions pymatgen/analysis/ferroelectricity/polarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@


def zval_dict_from_potcar(potcar) -> dict[str, float]:
"""
Creates zval_dictionary for calculating the ionic polarization from
"""Create zval_dictionary for calculating the ionic polarization from
Potcar object.
potcar: Potcar object
Expand Down
4 changes: 1 addition & 3 deletions pymatgen/analysis/magnetism/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,9 +604,7 @@ def __init__(
truncate_by_symmetry: bool = True,
transformation_kwargs: dict | None = None,
):
"""
This class will try generated different collinear
magnetic orderings for a given input structure.
"""Generate different collinear magnetic orderings for a given input structure.
If the input structure has magnetic moments defined, it
is possible to use these as a hint as to which elements are
Expand Down
6 changes: 2 additions & 4 deletions pymatgen/analysis/magnetism/heisenberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,10 +669,8 @@ class HeisenbergScreener:
"""Clean and screen magnetic orderings."""

def __init__(self, structures, energies, screen=False):
"""
This class pre-processes magnetic orderings and energies for
HeisenbergMapper. It prioritizes low-energy orderings with large and
localized magnetic moments.
"""Pre-processes magnetic orderings and energies for HeisenbergMapper.
It prioritizes low-energy orderings with large and localized magnetic moments.
Args:
structures (list): Structure objects with magnetic moments.
Expand Down
4 changes: 1 addition & 3 deletions pymatgen/analysis/structure_prediction/substitutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ class Substitutor(MSONable):
charge_balanced_tol: float = 1e-9

def __init__(self, threshold=1e-3, symprec: float = 0.1, **kwargs):
"""
This substitutor uses the substitution probability class to
find good substitutions for a given chemistry or structure.
"""Use the substitution probability class to find good substitutions for a given chemistry or structure.
Args:
threshold:
Expand Down
7 changes: 3 additions & 4 deletions pymatgen/analysis/wulff.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,7 @@ def anisotropy(self) -> float:

@property
def shape_factor(self) -> float:
"""
This is useful for determining the critical nucleus size.
"""Shape factor, useful for determining the critical nucleus size.
A large shape factor indicates great anisotropy.
See Ballufi, R. W., Allen, S. M. & Carter, W. C. Kinetics
of Materials. (John Wiley & Sons, 2005), p.461.
Expand All @@ -685,10 +684,10 @@ def shape_factor(self) -> float:
@property
def effective_radius(self) -> float:
"""
Radius of the WulffShape when the WulffShape is approximated as a sphere.
Radius of the WulffShape (in Angstroms) when the WulffShape is approximated as a sphere.
Returns:
float: radius.
float: radius R_eff
"""
return ((3 / 4) * (self.volume / np.pi)) ** (1 / 3)

Expand Down
68 changes: 47 additions & 21 deletions pymatgen/core/bonds.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""This class implements definitions for various kinds of bonds. Typically used in
"""This module implements definitions for various kinds of bonds. Typically used in
Molecule analysis.
"""

Expand All @@ -17,10 +17,13 @@
from pymatgen.util.typing import SpeciesLike


def _load_bond_length_data():
"""Loads bond length data from json file."""
with open(os.path.join(os.path.dirname(__file__), "bond_lengths.json")) as file:
data = defaultdict(dict)
def _load_bond_length_data() -> dict[tuple[str, ...], dict[float, float]]:
"""Load bond length data from bond_lengths.json file."""
with open(
os.path.join(os.path.dirname(__file__), "bond_lengths.json"),
encoding="utf-8",
) as file:
data: dict[tuple, dict] = defaultdict(dict)
for row in json.load(file):
els = sorted(row["elements"])
data[tuple(els)][row["bond_order"]] = row["length"]
Expand All @@ -43,12 +46,19 @@ def __init__(self, site1: Site, site2: Site) -> None:
self.site1 = site1
self.site2 = site2

def __repr__(self) -> str:
return f"Covalent bond between {self.site1} and {self.site2}"

@property
def length(self) -> float:
"""Length of the bond."""
return self.site1.distance(self.site2)

def get_bond_order(self, tol: float = 0.2, default_bl: float | None = None) -> float:
def get_bond_order(
self,
tol: float = 0.2,
default_bl: float | None = None,
) -> float:
"""The bond order according to the distance between the two sites.
Args:
Expand All @@ -71,8 +81,14 @@ def get_bond_order(self, tol: float = 0.2, default_bl: float | None = None) -> f
return get_bond_order(sp1, sp2, dist, tol, default_bl)

@staticmethod
def is_bonded(site1, site2, tol: float = 0.2, bond_order: float | None = None, default_bl: float | None = None):
"""Test if two sites are bonded, up to a certain limit.
def is_bonded(
site1: Site,
site2: Site,
tol: float = 0.2,
bond_order: float | None = None,
default_bl: float | None = None,
) -> bool:
"""Check if two sites are bonded, up to a certain limit.
Args:
site1 (Site): First site
Expand All @@ -87,7 +103,7 @@ def is_bonded(site1, site2, tol: float = 0.2, bond_order: float | None = None, d
bond length. If None, a ValueError will be thrown.
Returns:
Boolean indicating whether two sites are bonded.
bool: whether two sites are bonded.
"""
sp1 = next(iter(site1.species))
sp2 = next(iter(site2.species))
Expand All @@ -102,11 +118,12 @@ def is_bonded(site1, site2, tol: float = 0.2, bond_order: float | None = None, d
return dist < (1 + tol) * default_bl
raise ValueError(f"No bond data for elements {syms[0]} - {syms[1]}")

def __repr__(self) -> str:
return f"Covalent bond between {self.site1} and {self.site2}"


def obtain_all_bond_lengths(sp1, sp2, default_bl: float | None = None):
def obtain_all_bond_lengths(
sp1: SpeciesLike,
sp2: SpeciesLike,
default_bl: float | None = None,
) -> dict[float, float]:
"""Obtain bond lengths for all bond orders from bond length database.
Args:
Expand All @@ -127,17 +144,23 @@ def obtain_all_bond_lengths(sp1, sp2, default_bl: float | None = None):
if syms in bond_lengths:
return bond_lengths[syms].copy()
if default_bl is not None:
return {1: default_bl}
return {1.0: default_bl}
raise ValueError(f"No bond data for elements {syms[0]} - {syms[1]}")


def get_bond_order(sp1, sp2, dist: float, tol: float = 0.2, default_bl: float | None = None):
def get_bond_order(
sp1: SpeciesLike,
sp2: SpeciesLike,
dist: float,
tol: float = 0.2,
default_bl: float | None = None,
) -> float:
"""Calculate the bond order given the distance of 2 species.
Args:
sp1 (Species): First specie.
sp2 (Species): Second specie.
dist: Their distance in angstrom
dist (float): Distance in angstrom
tol (float): Relative tolerance to test. Basically, the code
checks if the distance between the sites is larger than
(1 + tol) * the longest bond distance or smaller than
Expand All @@ -148,8 +171,7 @@ def get_bond_order(sp1, sp2, dist: float, tol: float = 0.2, default_bl: float |
bond length (bond order = 1). If None, a ValueError will be thrown.
Returns:
Float value of bond order. For example, for C-C bond in benzene,
return 1.7.
float: Bond order. For example, 1.7 for C-C bond in benzene.
"""
all_lens = obtain_all_bond_lengths(sp1, sp2, default_bl)
# Transform bond lengths dict to list assuming bond data is successive
Expand All @@ -172,7 +194,11 @@ def get_bond_order(sp1, sp2, dist: float, tol: float = 0.2, default_bl: float |
return trial_bond_order - 1


def get_bond_length(sp1: SpeciesLike, sp2: SpeciesLike, bond_order: float = 1) -> float:
def get_bond_length(
sp1: SpeciesLike,
sp2: SpeciesLike,
bond_order: float = 1,
) -> float:
"""Get the bond length between two species.
Args:
Expand All @@ -184,8 +210,8 @@ def get_bond_length(sp1: SpeciesLike, sp2: SpeciesLike, bond_order: float = 1) -
C-C bond length, this should be set to 2. Defaults to 1.
Returns:
Bond length in Angstrom. If no data is available, the sum of the atomic
radius is used.
float: Bond length in Angstrom. If no data is available,
the sum of the atomic radii is used.
"""
sp1 = Element(sp1) if isinstance(sp1, str) else sp1
sp2 = Element(sp2) if isinstance(sp2, str) else sp2
Expand Down
40 changes: 21 additions & 19 deletions pymatgen/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,14 +412,14 @@ def get_integer_formula_and_factor(
Li0.5O0.25 returns (Li2O, 0.25). O0.25 returns (O2, 0.125)
"""
el_amt = self.get_el_amt_dict()
gcd = gcd_float(list(el_amt.values()), 1 / max_denominator)
_gcd = gcd_float(list(el_amt.values()), 1 / max_denominator)

dct = {key: round(val / gcd) for key, val in el_amt.items()}
dct = {key: round(val / _gcd) for key, val in el_amt.items()}
formula, factor = reduce_formula(dct, iupac_ordering=iupac_ordering)
if formula in Composition.special_formulas:
formula = Composition.special_formulas[formula]
factor /= 2
return formula, factor * gcd
return formula, factor * _gcd

@property
def reduced_formula(self) -> str:
Expand Down Expand Up @@ -509,18 +509,19 @@ def contains_element_type(self, category: str) -> bool:
category (str): one of "noble_gas", "transition_metal",
"post_transition_metal", "rare_earth_metal", "metal", "metalloid",
"alkali", "alkaline", "halogen", "chalcogen", "lanthanoid",
"actinoid", "radioactive", "quadrupolar", "s-block", "p-block", "d-block", "f-block"
"actinoid", "radioactive", "quadrupolar", "s-block", "p-block", "d-block", "f-block".
Returns:
bool: True if any elements in Composition match category, otherwise False
bool: Whether any elements in Composition match category.
"""
allowed_categories = [category.value for category in ElementType]
allowed_categories = [element.value for element in ElementType]

if category not in allowed_categories:
raise ValueError(f"Invalid {category=}, pick from {allowed_categories}")

if "block" in category:
return any(category[0] in el.block for el in self.elements)
return category[0] in [el.block for el in self.elements]

return any(getattr(el, f"is_{category}") for el in self.elements)

def _parse_formula(self, formula: str, strict: bool = True) -> dict[str, float]:
Expand Down Expand Up @@ -1083,15 +1084,16 @@ def _comps_from_fuzzy_formula(
"""
m_dict = m_dict or {}

def _parse_chomp_and_rank(m, f, m_dict, m_points):
def _parse_chomp_and_rank(match, formula: str, m_dict: dict[str, float], m_points: int) -> tuple:
"""A helper method for formula parsing that helps in interpreting and
ranking indeterminate formulas
ranking indeterminate formulas.
Author: Anubhav Jain.
Args:
m: A regex match, with the first group being the element and
match: A regex match, with the first group being the element and
the second group being the amount
f: The formula part containing the match
formula: The formula part containing the match
m_dict: A symbol:amt dictionary from the previously parsed
formula
m_points: Number of points gained from the previously parsed
Expand All @@ -1113,10 +1115,10 @@ def _parse_chomp_and_rank(m, f, m_dict, m_points):
points_second_lowercase = 100

# get element and amount from regex match
el = m.group(1)
el = match[1]
if len(el) > 2 or len(el) < 1:
raise ValueError("Invalid element symbol entered!")
amt = float(m.group(2)) if m.group(2).strip() != "" else 1
amt = float(match.group(2)) if match.group(2).strip() != "" else 1

# convert the element string to proper [uppercase,lowercase] format
# and award points if it is already in that format
Expand All @@ -1136,7 +1138,7 @@ def _parse_chomp_and_rank(m, f, m_dict, m_points):
m_dict[el] += amt * factor
else:
m_dict[el] = amt * factor
return f.replace(m.group(), "", 1), m_dict, m_points + points
return formula.replace(match.group(), "", 1), m_dict, m_points + points

# else return None
return None, None, None
Expand Down Expand Up @@ -1197,8 +1199,8 @@ def _parse_chomp_and_rank(m, f, m_dict, m_points):
yield match


def reduce_formula(sym_amt, iupac_ordering: bool = False) -> tuple[str, float]:
"""Helper method to reduce a sym_amt dict to a reduced formula and factor.
def reduce_formula(sym_amt: dict[str, float] | dict[str, int], iupac_ordering: bool = False) -> tuple[str, float]:
"""Helper function to reduce a sym_amt dict to a reduced formula and factor.
Args:
sym_amt (dict): {symbol: amount}.
Expand Down Expand Up @@ -1287,6 +1289,9 @@ def __add__(self, other: object) -> ChemicalPotential:
return ChemicalPotential({e: self.get(e, 0) + other.get(e, 0) for e in els})
return NotImplemented

def __repr__(self) -> str:
return f"ChemPots: {super()!r}"

def get_energy(self, composition: Composition, strict: bool = True) -> float:
"""Calculate the energy of a composition.
Expand All @@ -1298,9 +1303,6 @@ def get_energy(self, composition: Composition, strict: bool = True) -> float:
raise ValueError(f"Potentials not specified for {missing}")
return sum(self.get(key, 0) * val for key, val in composition.items())

def __repr__(self) -> str:
return f"ChemPots: {super()!r}"


class CompositionError(Exception):
"""Exception class for composition errors."""
Loading

0 comments on commit 578d29c

Please sign in to comment.