From 0c9ff40135607040a021083048cb45ab5aed7ad7 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Fri, 31 May 2024 03:33:50 +0800 Subject: [PATCH] Use `isclose` over `==` for overlap position check in `SlabGenerator.get_slabs` (#3825) * use isclose over == for position check * fix some types * update unit test for `test_get_slabs` * revert change to unit test * add `ztol` argument * rectify docstring --- pymatgen/core/surface.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/pymatgen/core/surface.py b/pymatgen/core/surface.py index 1d5c08652d3..6f1f293e481 100644 --- a/pymatgen/core/surface.py +++ b/pymatgen/core/surface.py @@ -946,7 +946,7 @@ def __init__( the c direction is parallel to the third lattice vector """ - def reduce_vector(vector: tuple[int, int, int]) -> tuple[int, int, int]: + def reduce_vector(vector: MillerIndex) -> MillerIndex: """Helper function to reduce vectors.""" divisor = abs(reduce(gcd, vector)) # type: ignore[arg-type] return cast(tuple[int, int, int], tuple(int(idx / divisor) for idx in vector)) @@ -1201,6 +1201,7 @@ def get_slabs( max_broken_bonds: int = 0, symmetrize: bool = False, repair: bool = False, + ztol: float = 0, ) -> list[Slab]: """Generate slabs with shift values calculated from the internal calculate_possible_shifts method. If the user decide to avoid breaking @@ -1222,6 +1223,8 @@ def get_slabs( repair (bool): Whether to repair terminations with broken bonds (True) or just omit them (False). Default to False as repairing terminations can lead to many more possible slabs. + ztol (float): Fractional tolerance for determine overlapping z-ranges, + smaller ztol might result in more possible Slabs. Returns: list[Slab]: All possible Slabs of a particular surface, @@ -1282,7 +1285,10 @@ def gen_possible_shifts(ftol: float) -> list[float]: return sorted(shifts) - def get_z_ranges(bonds: dict[tuple[Species | Element, Species | Element], float]) -> list[tuple[float, float]]: + def get_z_ranges( + bonds: dict[tuple[Species | Element, Species | Element], float], + ztol: float, + ) -> list[tuple[float, float]]: """Collect occupied z ranges where each z_range is a (lower_z, upper_z) tuple. This method examines all sites in the oriented unit cell (OUC) @@ -1292,7 +1298,7 @@ def get_z_ranges(bonds: dict[tuple[Species | Element, Species | Element], float] Args: bonds (dict): A {(species1, species2): max_bond_dist} dict. - tol (float): Fractional tolerance for determine overlapping positions. + ztol (float): Fractional tolerance for determine overlapping z-ranges. """ # Sanitize species in dict keys bonds = {(get_el_sp(s1), get_el_sp(s2)): dist for (s1, s2), dist in bonds.items()} @@ -1315,15 +1321,13 @@ def get_z_ranges(bonds: dict[tuple[Species | Element, Species | Element], float] z_ranges.extend([(0, z_range[1]), (z_range[0] + 1, 1)]) # Neglect overlapping positions - elif z_range[0] != z_range[1]: - # TODO (@DanielYang59): use the following for equality check - # elif not isclose(z_range[0], z_range[1], abs_tol=tol): + elif not isclose(z_range[0], z_range[1], abs_tol=ztol): z_ranges.append(z_range) return z_ranges # Get occupied z_ranges - z_ranges = [] if bonds is None else get_z_ranges(bonds) + z_ranges = [] if bonds is None else get_z_ranges(bonds, ztol) slabs = [] for shift in gen_possible_shifts(ftol=ftol): @@ -1350,7 +1354,7 @@ def get_z_ranges(bonds: dict[tuple[Species | Element, Species | Element], float] # Filter out surfaces that might be the same matcher = StructureMatcher(ltol=tol, stol=tol, primitive_cell=False, scale=False) - final_slabs = [] + final_slabs: list[Slab] = [] for group in matcher.group_structures(slabs): # For each unique slab, symmetrize the # surfaces by removing sites from the bottom @@ -1365,7 +1369,7 @@ def get_z_ranges(bonds: dict[tuple[Species | Element, Species | Element], float] matcher_sym = StructureMatcher(ltol=tol, stol=tol, primitive_cell=False, scale=False) final_slabs = [group[0] for group in matcher_sym.group_structures(final_slabs)] - return sorted(final_slabs, key=lambda slab: slab.energy) # type: ignore[return-value, arg-type] + return cast(list[Slab], sorted(final_slabs, key=lambda slab: slab.energy)) def repair_broken_bonds( self, @@ -2127,7 +2131,7 @@ def math_lcm(a: int, b: int) -> int: if len([i for i in transf_hkl if i < 0]) > 1: transf_hkl *= -1 - return tuple(transf_hkl) # type: ignore[return-value] + return tuple(transf_hkl) def miller_index_from_sites(