Skip to content

Commit

Permalink
Use isclose over == for overlap position check in `SlabGenerator.…
Browse files Browse the repository at this point in the history
…get_slabs` (#3825)

* use isclose over == for position check

* fix some types

* update unit test for `test_get_slabs`

* revert change to unit test

* add `ztol` argument

* rectify docstring
  • Loading branch information
DanielYang59 authored May 30, 2024
1 parent 8c853cf commit 0c9ff40
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions pymatgen/core/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def __init__(
the c direction is parallel to the third lattice vector
"""

def reduce_vector(vector: tuple[int, int, int]) -> tuple[int, int, int]:
def reduce_vector(vector: MillerIndex) -> MillerIndex:
"""Helper function to reduce vectors."""
divisor = abs(reduce(gcd, vector)) # type: ignore[arg-type]
return cast(tuple[int, int, int], tuple(int(idx / divisor) for idx in vector))
Expand Down Expand Up @@ -1201,6 +1201,7 @@ def get_slabs(
max_broken_bonds: int = 0,
symmetrize: bool = False,
repair: bool = False,
ztol: float = 0,
) -> list[Slab]:
"""Generate slabs with shift values calculated from the internal
calculate_possible_shifts method. If the user decide to avoid breaking
Expand All @@ -1222,6 +1223,8 @@ def get_slabs(
repair (bool): Whether to repair terminations with broken bonds (True)
or just omit them (False). Default to False as repairing terminations
can lead to many more possible slabs.
ztol (float): Fractional tolerance for determine overlapping z-ranges,
smaller ztol might result in more possible Slabs.
Returns:
list[Slab]: All possible Slabs of a particular surface,
Expand Down Expand Up @@ -1282,7 +1285,10 @@ def gen_possible_shifts(ftol: float) -> list[float]:

return sorted(shifts)

def get_z_ranges(bonds: dict[tuple[Species | Element, Species | Element], float]) -> list[tuple[float, float]]:
def get_z_ranges(
bonds: dict[tuple[Species | Element, Species | Element], float],
ztol: float,
) -> list[tuple[float, float]]:
"""Collect occupied z ranges where each z_range is a (lower_z, upper_z) tuple.
This method examines all sites in the oriented unit cell (OUC)
Expand All @@ -1292,7 +1298,7 @@ def get_z_ranges(bonds: dict[tuple[Species | Element, Species | Element], float]
Args:
bonds (dict): A {(species1, species2): max_bond_dist} dict.
tol (float): Fractional tolerance for determine overlapping positions.
ztol (float): Fractional tolerance for determine overlapping z-ranges.
"""
# Sanitize species in dict keys
bonds = {(get_el_sp(s1), get_el_sp(s2)): dist for (s1, s2), dist in bonds.items()}
Expand All @@ -1315,15 +1321,13 @@ def get_z_ranges(bonds: dict[tuple[Species | Element, Species | Element], float]
z_ranges.extend([(0, z_range[1]), (z_range[0] + 1, 1)])

# Neglect overlapping positions
elif z_range[0] != z_range[1]:
# TODO (@DanielYang59): use the following for equality check
# elif not isclose(z_range[0], z_range[1], abs_tol=tol):
elif not isclose(z_range[0], z_range[1], abs_tol=ztol):
z_ranges.append(z_range)

return z_ranges

# Get occupied z_ranges
z_ranges = [] if bonds is None else get_z_ranges(bonds)
z_ranges = [] if bonds is None else get_z_ranges(bonds, ztol)

slabs = []
for shift in gen_possible_shifts(ftol=ftol):
Expand All @@ -1350,7 +1354,7 @@ def get_z_ranges(bonds: dict[tuple[Species | Element, Species | Element], float]
# Filter out surfaces that might be the same
matcher = StructureMatcher(ltol=tol, stol=tol, primitive_cell=False, scale=False)

final_slabs = []
final_slabs: list[Slab] = []
for group in matcher.group_structures(slabs):
# For each unique slab, symmetrize the
# surfaces by removing sites from the bottom
Expand All @@ -1365,7 +1369,7 @@ def get_z_ranges(bonds: dict[tuple[Species | Element, Species | Element], float]
matcher_sym = StructureMatcher(ltol=tol, stol=tol, primitive_cell=False, scale=False)
final_slabs = [group[0] for group in matcher_sym.group_structures(final_slabs)]

return sorted(final_slabs, key=lambda slab: slab.energy) # type: ignore[return-value, arg-type]
return cast(list[Slab], sorted(final_slabs, key=lambda slab: slab.energy))

def repair_broken_bonds(
self,
Expand Down Expand Up @@ -2127,7 +2131,7 @@ def math_lcm(a: int, b: int) -> int:
if len([i for i in transf_hkl if i < 0]) > 1:
transf_hkl *= -1

return tuple(transf_hkl) # type: ignore[return-value]
return tuple(transf_hkl)


def miller_index_from_sites(
Expand Down

0 comments on commit 0c9ff40

Please sign in to comment.