Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix array comparison in core.Structure.merge_sites, also allow int property to be merged instead of float alone, mode only allow full name #4198

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
24 changes: 12 additions & 12 deletions src/pymatgen/core/sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def __init__(
iii.Dict of elements/species and occupancies, e.g.
{"Fe" : 0.5, "Mn":0.5}. This allows the setup of
disordered structures.
coords: Cartesian coordinates of site.
properties: Properties associated with the site as a dict, e.g.
coords (ArrayLike): Cartesian coordinates of site.
properties (dict): Properties associated with the site, e.g.
{"magmom": 5}. Defaults to None.
label: Label for the site. Defaults to None.
skip_checks: Whether to ignore all the usual checks and just
label (str): Label for the site. Defaults to None.
skip_checks (bool): Whether to ignore all the usual checks and just
create the site. Use this if the Site is created in a controlled
manner and speed is desired.
"""
Expand Down Expand Up @@ -310,20 +310,20 @@ def __init__(
symbols, e.g. "Li", "Fe2+", "P" or atomic numbers,
e.g. 3, 56, or actual Element or Species objects.
iii.Dict of elements/species and occupancies, e.g.
{"Fe" : 0.5, "Mn":0.5}. This allows the setup of
{"Fe": 0.5, "Mn": 0.5}. This allows the setup of
disordered structures.
coords: Coordinates of site, fractional coordinates
coords (ArrayLike): Coordinates of site, fractional coordinates
by default. See ``coords_are_cartesian`` for more details.
lattice: Lattice associated with the site.
to_unit_cell: Translates fractional coordinate to the
lattice (Lattice): Lattice associated with the site.
to_unit_cell (bool): Translates fractional coordinate to the
basic unit cell, i.e. all fractional coordinates satisfy 0
<= a < 1. Defaults to False.
coords_are_cartesian: Set to True if you are providing
coords_are_cartesian (bool): Set to True if you are providing
Cartesian coordinates. Defaults to False.
properties: Properties associated with the site as a dict, e.g.
properties (dict): Properties associated with the site, e.g.
{"magmom": 5}. Defaults to None.
label: Label for the site. Defaults to None.
skip_checks: Whether to ignore all the usual checks and just
label (str): Label for the site. Defaults to None.
skip_checks (bool): Whether to ignore all the usual checks and just
create the site. Use this if the PeriodicSite is created in a
controlled manner and speed is desired.
"""
Expand Down
74 changes: 47 additions & 27 deletions src/pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator, Sequence
from typing import Any, SupportsIndex
from typing import Any, ClassVar, SupportsIndex, TypeAlias

import pandas as pd
from ase import Atoms
Expand All @@ -61,7 +61,7 @@

from pymatgen.util.typing import CompositionLike, MillerIndex, PathLike, PbcLike, SpeciesLike

FileFormats = Literal[
FileFormats: TypeAlias = Literal[
"cif",
"poscar",
"cssr",
Expand All @@ -75,7 +75,7 @@
"aims",
"",
]
StructureSources = Literal["Materials Project", "COD"]
StructureSources: TypeAlias = Literal["Materials Project", "COD"]


class Neighbor(Site):
Expand Down Expand Up @@ -216,7 +216,7 @@ class SiteCollection(collections.abc.Sequence, ABC):
"""

# Tolerance in Angstrom for determining if sites are too close
DISTANCE_TOLERANCE = 0.5
DISTANCE_TOLERANCE: ClassVar[float] = 0.5
_properties: dict
Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NEED CONFIRM on 091c911: Is _properties misplaced? an instance variable maybe?

class SiteCollection(collections.abc.Sequence, ABC):
"""Basic SiteCollection. Essentially a sequence of Sites or PeriodicSites.
This serves as a base class for Molecule (a collection of Site, i.e., no
periodicity) and Structure (a collection of PeriodicSites, i.e.,
periodicity). Not meant to be instantiated directly.
"""
# Tolerance in Angstrom for determining if sites are too close
DISTANCE_TOLERANCE = 0.5
_properties: dict


def __contains__(self, site: object) -> bool:
Expand Down Expand Up @@ -4716,44 +4716,64 @@ def scale_lattice(self, volume: float) -> Self:

return self

def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average"] = "sum") -> Self:
"""Merges sites (adding occupancies) within tol of each other.
Removes site properties.
def merge_sites(
self,
tol: float = 0.01,
mode: Literal["sum", "delete", "average"] = "sum",
) -> Self:
"""Merges sites (by adding occupancies) within tolerance and removes
site properties in "sum/delete" modes.

Args:
tol (float): Tolerance for distance to merge sites.
mode ("sum" | "delete" | "average"): "delete" means duplicate sites are
deleted. "sum" means the occupancies are summed for the sites.
"average" means that the site is deleted but the properties are averaged
Only first letter is considered.
mode ("sum" | "delete" | "average"): Only first letter is considered at this moment.
- "delete": delete duplicate sites.
- "sum": sum the occupancies for the sites.
- "average": delete the site but average the properties if it's numerical.

Returns:
Structure: self with merged sites.
Structure: Structure with merged sites.
"""
dist_mat = self.distance_matrix
# TODO: change the code the allow full name after 2025-12-01
# TODO2: add a test for mode value, currently it only checks if first letter is "s/a"
if mode.lower() not in {"sum", "delete", "average"} and mode.lower()[0] in {"s", "d", "a"}:
Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Want to hear more opinions on this, I guess it's beneficial to allow only full name ("sum", "delete", "average") instead of checking first letter only to make use of the IDE auto-complete feature and facilitate typing

Also using the full name would be more readable: mode="sum" instead of mode="s"

warnings.warn(
"mode would only allow full name sum/delete/average after 2025-12-01", DeprecationWarning, stacklevel=2
)

if mode.lower()[0] not in {"s", "d", "a"}:
raise ValueError(f"Illegal {mode=}, should start with a/d/s.")

dist_mat: NDArray = self.distance_matrix
np.fill_diagonal(dist_mat, 0)
clusters = fcluster(linkage(squareform((dist_mat + dist_mat.T) / 2)), tol, "distance")
sites = []

sites: list[PeriodicSite] = []
for cluster in np.unique(clusters):
inds = np.where(clusters == cluster)[0]
species = self[inds[0]].species
coords = self[inds[0]].frac_coords
props = self[inds[0]].properties
for n, i in enumerate(inds[1:]):
sp = self[i].species
indexes = np.where(clusters == cluster)[0]
species: Composition = self[indexes[0]].species
coords: NDArray = self[indexes[0]].frac_coords
props: dict = self[indexes[0]].properties

for site_idx, clust_idx in enumerate(indexes[1:]):
# Sum occupancies in "sum" mode
if mode.lower()[0] == "s":
species += sp
offset = self[i].frac_coords - coords
coords += ((offset - np.round(offset)) / (n + 2)).astype(coords.dtype)
species += self[clust_idx].species

offset = self[clust_idx].frac_coords - coords
coords += ((offset - np.round(offset)) / (site_idx + 2)).astype(coords.dtype)
for key in props:
if props[key] is not None and self[i].properties[key] != props[key]:
Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO tag: Using array_equal may not be a good idea either as the property could be (sequence of) floats

TODO2: test failure

# Test that we can average the site properties that are floats
lattice = Lattice.hexagonal(3.587776, 19.622793)
species = ["Na", "V", "S", "S"]
coords = [
[0.333333, 0.666667, 0.165000],
[0, 0, 0.998333],
[0.333333, 0.666667, 0.399394],
[0.666667, 0.333333, 0.597273],
]
site_props = {"prop1": [3.0, 5.0, 7.0, 11.0]}
navs2 = Structure.from_spacegroup(160, lattice, species, coords, site_properties=site_props)
navs2.insert(0, "Na", coords[0], properties={"prop1": 100.0})
navs2.merge_sites(mode="a")
assert len(navs2) == 12
assert 51.5 in [itr.properties["prop1"] for itr in navs2]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Property can actually be anything, including value supplied by user.

Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for stepping in, I'm here adding a note for myself to work on later.

Yes in this case using == for comparison may not be ideal when the value could be float or np.array, this is very similar to #4092 where we want to compare the equality of two dict whose value could be np.array

But I assume we would not be able to use np.allclose as the value could be non-numerical, so as far as I'm aware, using array_equal would be the best approach (though handling of float would be sub-optimal)

if mode.lower()[0] == "a" and isinstance(props[key], float):
if props[key] is not None and not np.array_equal(self[clust_idx].properties[key], props[key]):
if mode.lower()[0] == "a" and isinstance(props[key], float | int):
# update a running total
props[key] = props[key] * (n + 1) / (n + 2) + self[i].properties[key] / (n + 2)
props[key] = props[key] * (site_idx + 1) / (site_idx + 2) + self[clust_idx].properties[
key
] / (site_idx + 2)
else:
props[key] = None
warnings.warn(
f"Sites with different site property {key} are merged. So property is set to none"
f"Sites with different site property {key} are merged. But property is set to None",
stacklevel=2,
)
sites.append(PeriodicSite(species, coords, self.lattice, properties=props))

Expand Down
83 changes: 62 additions & 21 deletions tests/core/test_structure.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import json
import math
import os
from fractions import Fraction
from pathlib import Path
from shutil import which
from unittest import skipIf

import numpy as np
import pytest
Expand All @@ -29,6 +29,7 @@
from pymatgen.electronic_structure.core import Magmom
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.io.cif import CifParser
from pymatgen.io.vasp.inputs import Poscar
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, PymatgenTest

Expand All @@ -40,11 +41,11 @@
ase = Atoms = Calculator = EMT = None


enum_cmd = which("enum.x") or which("multienum.x")
mcsqs_cmd = which("mcsqs")
ENUM_CMD = which("enum.x") or which("multienum.x")
MCSQS_CMD = which("mcsqs")


class TestNeighbor(PymatgenTest):
class TestNeighbor:
def test_msonable(self):
struct = PymatgenTest.get_structure("Li2O")
nn = struct.get_neighbors(struct[0], r=3)
Expand Down Expand Up @@ -102,7 +103,7 @@ def setUp(self):
)
self.V2O3 = IStructure.from_file(f"{TEST_FILES_DIR}/cif/V2O3.cif")

@skipIf(not (mcsqs_cmd and enum_cmd), reason="enumlib or mcsqs executable not present")
@pytest.mark.skipif(not (MCSQS_CMD and ENUM_CMD), reason="enumlib or mcsqs executable not present")
def test_get_orderings(self):
ordered = Structure.from_spacegroup("Im-3m", Lattice.cubic(3), ["Fe"], [[0, 0, 0]])
assert ordered.get_orderings()[0] == ordered
Expand Down Expand Up @@ -1633,23 +1634,27 @@ def test_merge_sites(self):
[0.5, 0.5, 1.501],
]
struct = Structure(Lattice.cubic(1), species, coords)
struct.merge_sites(mode="s")
struct.merge_sites(mode="sum")
assert struct[0].specie.symbol == "Ag"
assert struct[1].species == Composition({"Cl": 0.35, "F": 0.25})
assert_allclose(struct[1].frac_coords, [0.5, 0.5, 0.5005])

# Test for TaS2 with spacegroup 166 in 160 setting.
# Test illegal mode
with pytest.raises(ValueError, match="Illegal mode='illegal', should start with a/d/s"):
struct.merge_sites(mode="illegal")

# Test for TaS2 with spacegroup 166 in 160 setting
lattice = Lattice.hexagonal(3.374351, 20.308941)
species = ["Ta", "S", "S"]
coords = [
[0, 0, 0.944333],
[0.333333, 0.666667, 0.353424],
[0.666667, 0.333333, 0.535243],
]
tas2 = Structure.from_spacegroup(160, lattice, species, coords)
assert len(tas2) == 13
tas2.merge_sites(mode="d")
assert len(tas2) == 9
struct_tas2 = Structure.from_spacegroup(160, lattice, species, coords)
assert len(struct_tas2) == 13
struct_tas2.merge_sites(mode="delete")
assert len(struct_tas2) == 9

lattice = Lattice.hexagonal(3.587776, 19.622793)
species = ["Na", "V", "S", "S"]
Expand All @@ -1659,12 +1664,12 @@ def test_merge_sites(self):
[0.333333, 0.666667, 0.399394],
[0.666667, 0.333333, 0.597273],
]
navs2 = Structure.from_spacegroup(160, lattice, species, coords)
assert len(navs2) == 18
navs2.merge_sites(mode="d")
assert len(navs2) == 12
struct_navs2 = Structure.from_spacegroup(160, lattice, species, coords)
assert len(struct_navs2) == 18
struct_navs2.merge_sites(mode="delete")
assert len(struct_navs2) == 12

# Test that we can average the site properties that are floats
# Test that we can average the site properties that are numerical (float/int)
lattice = Lattice.hexagonal(3.587776, 19.622793)
species = ["Na", "V", "S", "S"]
coords = [
Expand All @@ -1674,11 +1679,47 @@ def test_merge_sites(self):
[0.666667, 0.333333, 0.597273],
]
site_props = {"prop1": [3.0, 5.0, 7.0, 11.0]}
navs2 = Structure.from_spacegroup(160, lattice, species, coords, site_properties=site_props)
navs2.insert(0, "Na", coords[0], properties={"prop1": 100.0})
navs2.merge_sites(mode="a")
assert len(navs2) == 12
assert 51.5 in [itr.properties["prop1"] for itr in navs2]
struct_navs2 = Structure.from_spacegroup(160, lattice, species, coords, site_properties=site_props)
struct_navs2.insert(0, "Na", coords[0], properties={"prop1": 100}) # int property
struct_navs2.merge_sites(mode="average")
assert len(struct_navs2) == 12
assert any(math.isclose(site.properties["prop1"], 51.5) for site in struct_navs2)

# Test non-numerical property warning
struct_navs2.insert(0, "Na", coords[0], properties={"prop1": "hi"})
with pytest.warns(UserWarning, match="But property is set to None"):
struct_navs2.merge_sites(mode="average")

# Test property handling for np.array (selective dynamics)
poscar_str_0 = """Test POSCAR
1.0
3.840198 0.000000 0.000000
1.920099 3.325710 0.000000
0.000000 -2.217138 3.135509
1 1
Selective dynamics
direct
0.000000 0.000000 0.000000 T T T Si
0.750000 0.500000 0.750000 F F F O
"""
poscar_str_1 = """offset a bit
1.0
3.840198 0.000000 0.000000
1.920099 3.325710 0.000000
0.000000 -2.217138 3.135509
1 1
Selective dynamics
direct
0.100000 0.000000 0.000000 T T T Si
0.750000 0.500000 0.750000 F F F O
"""

struct_0 = Poscar.from_str(poscar_str_0).structure
struct_1 = Poscar.from_str(poscar_str_1).structure

for site in struct_0:
struct_1.append(site.species, site.frac_coords, properties=site.properties)
struct_1.merge_sites(mode="average")

def test_properties(self):
assert self.struct.num_sites == len(self.struct)
Expand Down
3 changes: 0 additions & 3 deletions tests/io/vasp/test_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import hashlib
import os
import unittest
from glob import glob
from zipfile import ZipFile

Expand Down Expand Up @@ -1607,7 +1606,6 @@ def test_user_incar_settings(self):
assert not vis.incar["LASPH"], "LASPH user setting not applied"
assert vis.incar["VDW_SR"] == 1.5, "VDW_SR user setting not applied"

@unittest.skipIf(not os.path.exists(TEST_DIR), "Test files are not present.")
def test_from_prev_calc(self):
prev_run = os.path.join(TEST_DIR, "fixtures", "relaxation")

Expand All @@ -1624,7 +1622,6 @@ def test_from_prev_calc(self):
assert "VDW_A2" in vis_bj.incar
assert "VDW_S8" in vis_bj.incar

@unittest.skipIf(not os.path.exists(TEST_DIR), "Test files are not present.")
def test_override_from_prev_calc(self):
prev_run = os.path.join(TEST_DIR, "fixtures", "relaxation")

Expand Down
24 changes: 18 additions & 6 deletions tests/util/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,6 @@ def test_nested_arrays(self):

def test_diff_dtype(self):
"""Make sure it also works for other data types as value."""

@dataclass
class CustomClass:
name: str
value: int

# Test with bool values
dict1 = {"a": True}
dict2 = {"a": True}
Expand All @@ -69,13 +63,31 @@ class CustomClass:
assert not is_np_dict_equal(dict4, dict6)

# Test with a custom data class
@dataclass
class CustomClass:
name: str
value: int

dict7 = {"a": CustomClass(name="test", value=1)}
dict8 = {"a": CustomClass(name="test", value=1)}
assert is_np_dict_equal(dict7, dict8)

dict9 = {"a": CustomClass(name="test", value=2)}
assert not is_np_dict_equal(dict7, dict9)

# Test __eq__ method being used
@dataclass
class NewCustomClass:
name: str
value: int

def __eq__(self, other):
return True

dict7_1 = {"a": NewCustomClass(name="test", value=1)}
dict8_1 = {"a": NewCustomClass(name="hello", value=2)}
assert is_np_dict_equal(dict7_1, dict8_1)

# Test with None
dict10 = {"a": None}
dict11 = {"a": None}
Expand Down
Loading