diff --git a/src/pymatgen/core/composition.py b/src/pymatgen/core/composition.py index d269c247ff9..03fac2a7412 100644 --- a/src/pymatgen/core/composition.py +++ b/src/pymatgen/core/composition.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from collections.abc import Generator, Iterator - from typing import Any, ClassVar + from typing import Any, ClassVar, Literal from typing_extensions import Self @@ -774,17 +774,25 @@ def to_reduced_dict(self) -> dict[str, float]: def to_weight_dict(self) -> dict[str, float]: """ Returns: - dict[str, float] with weight fraction of each component {"Ti": 0.90, "V": 0.06, "Al": 0.04}. + dict[str, float]: weight fractions of each component, e.g. {"Ti": 0.90, "V": 0.06, "Al": 0.04}. """ return {str(el): self.get_wt_fraction(el) for el in self.elements} @property - def to_data_dict(self) -> dict[str, Any]: + def to_data_dict( + self, + ) -> dict[ + Literal["reduced_cell_composition", "unit_cell_composition", "reduced_cell_formula", "elements", "nelements"], + Any, + ]: """ Returns: - A dict with many keys and values relating to Composition/Formula, - including reduced_cell_composition, unit_cell_composition, - reduced_cell_formula, elements and nelements. + dict with the following keys: + - reduced_cell_composition + - unit_cell_composition + - reduced_cell_formula + - elements + - nelements. """ return { "reduced_cell_composition": self.reduced_composition, diff --git a/src/pymatgen/core/periodic_table.py b/src/pymatgen/core/periodic_table.py index a981c74578f..63ef361aaf9 100644 --- a/src/pymatgen/core/periodic_table.py +++ b/src/pymatgen/core/periodic_table.py @@ -876,7 +876,6 @@ def print_periodic_table(filter_function: Callable | None = None) -> None: print(" ".join(row_str)) -@functools.total_ordering class Element(ElementBase): """Enum representing an element in the periodic table.""" @@ -1597,14 +1596,12 @@ def from_dict(cls, dct: dict) -> Self: return cls(dct["element"], dct["oxidation_state"], spin=dct.get("spin")) -@functools.total_ordering class Specie(Species): """This maps the historical grammatically inaccurate Specie to Species to maintain backwards compatibility. """ -@functools.total_ordering class DummySpecie(DummySpecies): """This maps the historical grammatically inaccurate DummySpecie to DummySpecies to maintain backwards compatibility. diff --git a/src/pymatgen/core/structure.py b/src/pymatgen/core/structure.py index 8e46793c837..9fe9408027e 100644 --- a/src/pymatgen/core/structure.py +++ b/src/pymatgen/core/structure.py @@ -19,9 +19,7 @@ import warnings from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import MutableSequence from fnmatch import fnmatch -from io import StringIO from typing import TYPE_CHECKING, Literal, cast, get_args import numpy as np @@ -245,7 +243,7 @@ def sites(self) -> list[PeriodicSite] | tuple[PeriodicSite, ...]: def sites(self, sites: Sequence[PeriodicSite]) -> None: """Set the sites in the Structure.""" # If self is mutable Structure or Molecule, set _sites as list - is_mutable = isinstance(self._sites, MutableSequence) + is_mutable = isinstance(self._sites, collections.abc.MutableSequence) self._sites: list[PeriodicSite] | tuple[PeriodicSite, ...] = list(sites) if is_mutable else tuple(sites) @abstractmethod @@ -1098,9 +1096,8 @@ def __init__( self._properties = properties or {} def __eq__(self, other: object) -> bool: + """Define equality by comparing all three attributes: lattice, sites, properties.""" needed_attrs = ("lattice", "sites", "properties") - - # Return NotImplemented as in https://docs.python.org/3/library/functools.html#functools.total_ordering if not all(hasattr(other, attr) for attr in needed_attrs): return NotImplemented @@ -1109,8 +1106,10 @@ def __eq__(self, other: object) -> bool: if other is self: return True + if len(self) != len(other): return False + if self.lattice != other.lattice: return False if self.properties != other.properties: @@ -2982,7 +2981,7 @@ def to(self, filename: PathLike = "", fmt: FileFormats = "", **kwargs) -> str: return Prismatic(self).to_str() elif fmt in ("yaml", "yml") or fnmatch(filename, "*.yaml*") or fnmatch(filename, "*.yml*"): yaml = YAML() - str_io = StringIO() + str_io = io.StringIO() yaml.dump(self.as_dict(), str_io) yaml_str = str_io.getvalue() if filename: @@ -3923,7 +3922,7 @@ def to(self, filename: str = "", fmt: str = "") -> str | None: return json_str elif fmt in {"yaml", "yml"} or fnmatch(filename, "*.yaml*") or fnmatch(filename, "*.yml*"): yaml = YAML() - str_io = StringIO() + str_io = io.StringIO() yaml.dump(self.as_dict(), str_io) yaml_str = str_io.getvalue() if filename: