From 09c1c01ee0dea2f53228eacc4e0d4238ded98040 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 23 Feb 2024 15:20:24 +0100 Subject: [PATCH 1/2] `Kpoint.__eq__` and `PhononBandStructureSymmLine.__eq__` methods + tests (#3650) * plotter.py and phonon band structure: improve doc strings + refactor use list concat over append * Add Kpoint.__eq__ method * Add test for Kpoint equality * Add PhononBandStructureSymmLine.__eq__ method --- .../electronic_structure/bandstructure.py | 61 ++++---- pymatgen/phonon/bandstructure.py | 134 ++++++++++-------- pymatgen/phonon/plotter.py | 64 ++++----- .../test_bandstructure.py | 8 ++ tests/phonon/test_bandstructure.py | 7 + 5 files changed, 152 insertions(+), 122 deletions(-) diff --git a/pymatgen/electronic_structure/bandstructure.py b/pymatgen/electronic_structure/bandstructure.py index 87e9517d4e9..9ef3d64377d 100644 --- a/pymatgen/electronic_structure/bandstructure.py +++ b/pymatgen/electronic_structure/bandstructure.py @@ -108,6 +108,16 @@ def __str__(self) -> str: """Returns a string with fractional, Cartesian coordinates and label.""" return f"{self.frac_coords} {self.cart_coords} {self.label}" + def __eq__(self, other: object) -> bool: + """Check if two kpoints are equal.""" + if not isinstance(other, Kpoint): + return NotImplemented + return ( + np.allclose(self.frac_coords, other.frac_coords) + and self.lattice == other.lattice + and self.label == other.label + ) + def as_dict(self) -> dict[str, Any]: """JSON-serializable dict representation of a kpoint.""" return { @@ -142,14 +152,14 @@ class BandStructure: lattice_rec (Lattice): The reciprocal lattice of the band structure. efermi (float): The Fermi energy. is_spin_polarized (bool): True if the band structure is spin-polarized. - bands (dict): The energy eigenvalues as a {spin: ndarray}. Note that the use of an - ndarray is necessary for computational as well as memory efficiency due to the large - amount of numerical data. The indices of the ndarray are [band_index, kpoint_index]. + bands (dict): The energy eigenvalues as a {spin: array}. Note that the use of an + array is necessary for computational as well as memory efficiency due to the large + amount of numerical data. The indices of the array are [band_index, kpoint_index]. nb_bands (int): Returns the number of bands in the band structure. structure (Structure): Returns the structure. - projections (dict): The projections as a {spin: ndarray}. Note that the use of an - ndarray is necessary for computational as well as memory efficiency due to the large - amount of numerical data. The indices of the ndarray are [band_index, kpoint_index, + projections (dict): The projections as a {spin: array}. Note that the use of an + array is necessary for computational as well as memory efficiency due to the large + amount of numerical data. The indices of the array are [band_index, kpoint_index, orbital_index, ion_index]. """ @@ -184,8 +194,8 @@ def __init__( structure: The crystal structure (as a pymatgen Structure object) associated with the band structure. This is needed if we provide projections to the band structure - projections: dict of orbital projections as {spin: ndarray}. The - indices of the ndarrayare [band_index, kpoint_index, orbital_index, + projections: dict of orbital projections as {spin: array}. The + indices of the array are [band_index, kpoint_index, orbital_index, ion_index].If the band structure is not spin polarized, we only store one data set under Spin.up. """ @@ -363,22 +373,21 @@ def get_cbm(self): """Returns data about the CBM. Returns: - {"band_index","kpoint_index","kpoint","energy"} - - "band_index": A dict with spin keys pointing to a list of the - indices of the band containing the CBM (please note that you - can have several bands sharing the CBM) {Spin.up:[], - Spin.down:[]} - - "kpoint_index": The list of indices in self.kpoints for the - kpoint CBM. Please note that there can be several - kpoint_indices relating to the same kpoint (e.g., Gamma can - occur at different spots in the band structure line plot) - - "kpoint": The kpoint (as a kpoint object) - - "energy": The energy of the CBM - - "projections": The projections along sites and orbitals of the - CBM if any projection data is available (else it is an empty - dictionary). The format is similar to the projections field in - BandStructure: {spin:{'Orbital': [proj]}} where the array - [proj] is ordered according to the sites in structure + dict[str, Any]: with keys band_index, kpoint_index, kpoint, energy. + - "band_index": A dict with spin keys pointing to a list of the + indices of the band containing the CBM (please note that you + can have several bands sharing the CBM) {Spin.up:[], Spin.down:[]} + - "kpoint_index": The list of indices in self.kpoints for the + kpoint CBM. Please note that there can be several + kpoint_indices relating to the same kpoint (e.g., Gamma can + occur at different spots in the band structure line plot) + - "kpoint": The kpoint (as a kpoint object) + - "energy": The energy of the CBM + - "projections": The projections along sites and orbitals of the + CBM if any projection data is available (else it is an empty + dictionary). The format is similar to the projections field in + BandStructure: {spin:{'Orbital': [proj]}} where the array + [proj] is ordered according to the sites in structure """ if self.is_metal(): return { @@ -710,8 +719,8 @@ def __init__( structure: The crystal structure (as a pymatgen Structure object) associated with the band structure. This is needed if we provide projections to the band structure. - projections: dict of orbital projections as {spin: ndarray}. The - indices of the ndarray are [band_index, kpoint_index, orbital_index, + projections: dict of orbital projections as {spin: array}. The + indices of the array are [band_index, kpoint_index, orbital_index, ion_index].If the band structure is not spin polarized, we only store one data set under Spin.up. """ diff --git a/pymatgen/phonon/bandstructure.py b/pymatgen/phonon/bandstructure.py index 1f82b20e4d1..90dac271be3 100644 --- a/pymatgen/phonon/bandstructure.py +++ b/pymatgen/phonon/bandstructure.py @@ -3,7 +3,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np from monty.json import MSONable @@ -33,10 +33,9 @@ def get_reasonable_repetitions(n_atoms: int) -> tuple[int, int, int]: return 1, 1, 1 -def eigenvectors_from_displacements(disp, masses) -> np.ndarray: +def eigenvectors_from_displacements(disp: np.ndarray, masses: np.ndarray) -> np.ndarray: """Calculate the eigenvectors from the atomic displacements.""" - sqrt_masses = np.sqrt(masses) - return np.einsum("nax,a->nax", disp, sqrt_masses) + return np.einsum("nax,a->nax", disp, masses**0.5) def estimate_band_connection(prev_eigvecs, eigvecs, prev_band_order) -> list[int]: @@ -45,13 +44,13 @@ def estimate_band_connection(prev_eigvecs, eigvecs, prev_band_order) -> list[int connection_order = [] for overlaps in metric: max_val = 0 - for i in reversed(range(len(metric))): - val = overlaps[i] - if i in connection_order: + for idx in reversed(range(len(metric))): + val = overlaps[idx] + if idx in connection_order: continue if val > max_val: max_val = val - max_idx = i + max_idx = idx connection_order.append(max_idx) return [connection_order[x] for x in prev_band_order] @@ -66,7 +65,7 @@ class PhononBandStructure(MSONable): def __init__( self, - qpoints: list[Kpoint], + qpoints: Sequence[Kpoint], frequencies: ArrayLike, lattice: Lattice, nac_frequencies: Sequence[Sequence] | None = None, @@ -102,15 +101,15 @@ def __init__( A list of tuples. The first element of each tuple should be a list defining the direction. The second element containing a numpy array of complex numbers with shape (3*len(structure), len(structure), 3). - labels_dict: (dict) of {} this links a qpoint (in frac coords or + labels_dict: (dict[str, Kpoint]): this links a qpoint (in frac coords or Cartesian coordinates depending on the coords) to a label. - coords_are_cartesian: Whether the qpoint coordinates are Cartesian. + coords_are_cartesian (bool): Whether the qpoint coordinates are Cartesian. Defaults to False. structure: The crystal structure (as a pymatgen Structure object) - associated with the band structure. This is needed if we - provide projections to the band structure. + associated with the band structure. This is needed to calculate element/orbital + projections of the band structure. """ self.lattice_rec = lattice - self.qpoints = [] + self.qpoints: list[Kpoint] = [] self.labels_dict = {} self.structure = structure if eigendisplacements is None: @@ -127,8 +126,8 @@ def __init__( self.labels_dict[label] = Kpoint( q_pt, lattice, label=label, coords_are_cartesian=coords_are_cartesian ) - self.qpoints.append(Kpoint(q_pt, lattice, label=label, coords_are_cartesian=coords_are_cartesian)) - self.bands = frequencies + self.qpoints += [Kpoint(q_pt, lattice, label=label, coords_are_cartesian=coords_are_cartesian)] + self.bands = np.asarray(frequencies) self.nb_bands = len(self.bands) self.nb_qpoints = len(self.qpoints) @@ -246,7 +245,7 @@ def get_nac_eigendisplacements_along_dir(self, direction) -> np.ndarray | None: return None - def asr_breaking(self, tol_eigendisplacements: float = 1e-5): + def asr_breaking(self, tol_eigendisplacements: float = 1e-5) -> np.ndarray | None: """Returns the breaking of the acoustic sum rule for the three acoustic modes, if Gamma is present. None otherwise. If eigendisplacements are available they are used to determine the acoustic @@ -255,36 +254,34 @@ def asr_breaking(self, tol_eigendisplacements: float = 1e-5): identified or eigendisplacements are missing the first 3 modes will be used (indices [0:3]). """ - for i in range(self.nb_qpoints): - if np.allclose(self.qpoints[i].frac_coords, (0, 0, 0)): + for idx in range(self.nb_qpoints): + if np.allclose(self.qpoints[idx].frac_coords, (0, 0, 0)): if self.has_eigendisplacements: acoustic_modes_index = [] for j in range(self.nb_bands): - eig = self.eigendisplacements[j][i] + eig = self.eigendisplacements[j][idx] if np.max(np.abs(eig[1:] - eig[:1])) < tol_eigendisplacements: acoustic_modes_index.append(j) # if acoustic modes are not correctly identified return use # the first three modes if len(acoustic_modes_index) != 3: acoustic_modes_index = [0, 1, 2] - return self.bands[acoustic_modes_index, i] + return self.bands[acoustic_modes_index, idx] - return self.bands[:3, i] + return self.bands[:3, idx] return None - def as_dict(self): + def as_dict(self) -> dict[str, Any]: """MSONable dict.""" - dct = { + dct: dict[str, Any] = { "@module": type(self).__module__, "@class": type(self).__name__, "lattice_rec": self.lattice_rec.as_dict(), - "qpoints": [], + # qpoints are not Kpoint objects dicts but are frac coords. This makes + # the dict smaller and avoids the repetition of the lattice + "qpoints": [q_pt.as_dict()["fcoords"] for q_pt in self.qpoints], } - # qpoints are not Kpoint objects dicts but are frac coords. This makes - # the dict smaller and avoids the repetition of the lattice - for q in self.qpoints: - dct["qpoints"].append(q.as_dict()["fcoords"]) dct["bands"] = self.bands.tolist() dct["labels_dict"] = {} for kpoint_letter, kpoint_object in self.labels_dict.items(): @@ -307,10 +304,10 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, dct) -> PhononBandStructure: + def from_dict(cls, dct: dict[str, Any]) -> PhononBandStructure: """ Args: - dct (dict): Dict representation. + dct (dict): Dict representation of PhononBandStructure. Returns: PhononBandStructure @@ -345,7 +342,7 @@ class PhononBandStructureSymmLine(PhononBandStructure): def __init__( self, - qpoints: list[Kpoint], + qpoints: Sequence[Kpoint], frequencies: ArrayLike, lattice: Lattice, has_nac: bool = False, @@ -397,7 +394,7 @@ def __repr__(self) -> str: return f"{type(self).__name__}({bands=}, {labels=})" def _reuse_init( - self, eigendisplacements: ArrayLike, frequencies: ArrayLike, has_nac: bool, qpoints: list[Kpoint] + self, eigendisplacements: ArrayLike, frequencies: ArrayLike, has_nac: bool, qpoints: Sequence[Kpoint] ) -> None: self.distance = [] self.branches = [] @@ -410,29 +407,29 @@ def _reuse_init( for idx in range(self.nb_qpoints): label = self.qpoints[idx].label if label is not None and previous_label is not None: - self.distance.append(previous_distance) + self.distance += [previous_distance] else: - self.distance.append( + self.distance += [ np.linalg.norm(self.qpoints[idx].cart_coords - previous_qpoint.cart_coords) + previous_distance - ) + ] previous_qpoint = self.qpoints[idx] previous_distance = self.distance[idx] if label and previous_label: if len(one_group) != 0: - branches_tmp.append(one_group) + branches_tmp += [one_group] one_group = [] previous_label = label - one_group.append(idx) + one_group += [idx] if len(one_group) != 0: - branches_tmp.append(one_group) + branches_tmp += [one_group] for branch in branches_tmp: - self.branches.append( + self.branches += [ { "start_index": branch[0], "end_index": branch[-1], "name": f"{self.qpoints[branch[0]].label}-{self.qpoints[branch[-1]].label}", } - ) + ] # extract the frequencies with non-analytical contribution at gamma if has_nac: naf = [] @@ -462,10 +459,10 @@ def get_equivalent_qpoints(self, index: int) -> list[int]: same frac coords) to the given one. Args: - index: the qpoint index + index (int): the qpoint index Returns: - a list of equivalent indices + list[int]: equivalent indices TODO: now it uses the label we might want to use coordinates instead (in case there was a mislabel) @@ -544,15 +541,15 @@ def as_phononwebsite(self) -> dict: # get qpoints qpoints = [] - for q in self.qpoints: - qpoints.append(list(q.frac_coords)) + for q_pt in self.qpoints: + qpoints.append(list(q_pt.frac_coords)) dct["qpoints"] = qpoints # get labels hsq_dict = {} - for nq, q in enumerate(self.qpoints): - if q.label is not None: - hsq_dict[nq] = q.label + for nq, q_pt in enumerate(self.qpoints): + if q_pt.label is not None: + hsq_dict[nq] = q_pt.label # get distances dist = 0 @@ -583,18 +580,18 @@ def as_phononwebsite(self) -> dict: dct["eigenvalues"] = bands.T.tolist() # eigenvectors - eigenvectors = self.eigendisplacements.copy() - eigenvectors /= np.linalg.norm(eigenvectors[0, 0]) - eigenvectors = eigenvectors.swapaxes(0, 1) - eigenvectors = np.array([eigenvectors.real, eigenvectors.imag]) - eigenvectors = np.rollaxis(eigenvectors, 0, 5) - dct["vectors"] = eigenvectors.tolist() + eigen_vecs = self.eigendisplacements.copy() + eigen_vecs /= np.linalg.norm(eigen_vecs[0, 0]) + eigen_vecs = eigen_vecs.swapaxes(0, 1) + eigen_vecs = np.array([eigen_vecs.real, eigen_vecs.imag]) + eigen_vecs = np.rollaxis(eigen_vecs, 0, 5) + dct["vectors"] = eigen_vecs.tolist() return dct def band_reorder(self) -> None: """Re-order the eigenvalues according to the similarity of the eigenvectors.""" - eiv = self.eigendisplacements + eigen_displacements = self.eigendisplacements eig = self.bands n_phonons, n_qpoints = self.bands.shape @@ -607,19 +604,19 @@ def band_reorder(self) -> None: # get order for nq in range(1, n_qpoints): - old_eiv = eigenvectors_from_displacements(eiv[:, nq - 1], atomic_masses) - new_eiv = eigenvectors_from_displacements(eiv[:, nq], atomic_masses) + old_eig_vecs = eigenvectors_from_displacements(eigen_displacements[:, nq - 1], atomic_masses) + new_eig_vecs = eigenvectors_from_displacements(eigen_displacements[:, nq], atomic_masses) order[nq] = estimate_band_connection( - old_eiv.reshape([n_phonons, n_phonons]).T, - new_eiv.reshape([n_phonons, n_phonons]).T, + old_eig_vecs.reshape([n_phonons, n_phonons]).T, + new_eig_vecs.reshape([n_phonons, n_phonons]).T, order[nq - 1], ) # reorder for nq in range(1, n_qpoints): - eivq = eiv[:, nq] + eivq = eigen_displacements[:, nq] eigq = eig[:, nq] - eiv[:, nq] = eivq[order[nq]] + eigen_displacements[:, nq] = eivq[order[nq]] eig[:, nq] = eigq[order[nq]] def as_dict(self) -> dict: @@ -645,7 +642,6 @@ def from_dict(cls, dct: dict) -> PhononBandStructureSymmLine: eigendisplacements = ( np.array(dct["eigendisplacements"]["real"]) + np.array(dct["eigendisplacements"]["imag"]) * 1j ) - struct = Structure.from_dict(dct["structure"]) if "structure" in dct else None return cls( dct["qpoints"], np.array(dct["bands"]), @@ -653,5 +649,17 @@ def from_dict(cls, dct: dict) -> PhononBandStructureSymmLine: dct["has_nac"], eigendisplacements, dct["labels_dict"], - structure=struct, + structure=Structure.from_dict(dct["structure"]) if "structure" in dct else None, + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PhononBandStructureSymmLine): + return NotImplemented + return ( + self.bands.shape == other.bands.shape + and np.allclose(self.bands, other.bands) + and self.lattice_rec == other.lattice_rec + # and self.qpoints == other.qpoints + and self.labels_dict == other.labels_dict + and self.structure == other.structure ) diff --git a/pymatgen/phonon/plotter.py b/pymatgen/phonon/plotter.py index 5bb57a8fa8a..e756a91a671 100644 --- a/pymatgen/phonon/plotter.py +++ b/pymatgen/phonon/plotter.py @@ -4,7 +4,7 @@ import logging from collections import namedtuple -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable import matplotlib.pyplot as plt import numpy as np @@ -125,8 +125,7 @@ def get_dos_dict(self) -> dict: be the smeared densities, not the original densities. Returns: - Dict of dos data. Generally of the form, {label: {'frequencies':.., - 'densities': ...}} + dict: DOS data. Generally of the form {label: {'frequencies':.., 'densities': ...}} """ return jsanitize(self._doses) @@ -141,15 +140,13 @@ def get_plot( """Get a matplotlib plot showing the DOS. Args: - xlim: Specifies the x-axis limits. Set to None for automatic - determination. + xlim: Specifies the x-axis limits. Set to None for automatic determination. ylim: Specifies the y-axis limits. - units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1. + units (thz | ev | mev | ha | cm-1 | cm^-1): units for the frequencies. Defaults to "thz". legend: dict with legend options. For example, {"loc": "upper right"} - will place the legend in the upper right corner. Defaults to - {"fontsize": 30}. - ax (Axes): An existing axes object onto which the plot will be - added. If None, a new figure will be created. + will place the legend in the upper right corner. Defaults to {"fontsize": 30}. + ax (Axes): An existing axes object onto which the plot will be added. + If None, a new figure will be created. """ legend = legend or {} legend.setdefault("fontsize", 30) @@ -390,7 +387,7 @@ def _get_weight(self, vec: np.ndarray, indices: list[list[int]]) -> np.ndarray: @staticmethod def _make_color(colors: Sequence[int]) -> Sequence[int]: - """Convert the eigendisplacements to rgb colors.""" + """Convert the eigen-displacements to rgb colors.""" # if there are two groups, use red and blue if len(colors) == 2: return [colors[0], 0, colors[1]] @@ -601,22 +598,23 @@ def plot_compare( other_kwargs: dict | None = None, **kwargs, ) -> Axes: - """Plot two band structure for comparison. One is in red the other in blue. + """Plot two band structure for comparison. self in blue, other in red. The two band structures need to be defined on the same symmetry lines! - and the distance between symmetry lines is the one of the band structure - used to build the PhononBSPlotter. + The distance between symmetry lines is determined by the band structure used to + initialize PhononBSPlotter (self). Args: - other_plotter (PhononBSPlotter): another PhononBSPlotter object defined along the same symmetry lines + other_plotter (PhononBSPlotter): another PhononBSPlotter object defined along the + same symmetry lines units (str): units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1. Defaults to 'thz'. - labels (tuple[str, str] | None): labels for the two band structures. Defaults to None, which will use the - label of the two PhononBSPlotter objects if present. + labels (tuple[str, str] | None): labels for the two band structures. Defaults to None, + which will use the label of the two PhononBSPlotter objects if present. Label order is (self_label, other_label), i.e. the label of the PhononBSPlotter on which plot_compare() is called must come first. legend_kwargs: dict[str, Any]: kwargs passed to ax.legend(). - on_incompatible ('raise' | 'warn' | 'ignore'): What to do if the two band structures are not compatible. - Defaults to 'raise'. + on_incompatible ('raise' | 'warn' | 'ignore'): What to do if the two band structures + are not compatible. Defaults to 'raise'. other_kwargs: dict[str, Any]: kwargs passed to other_plotter ax.plot(). **kwargs: passed to ax.plot(). @@ -663,15 +661,14 @@ def plot_compare( def plot_brillouin(self) -> None: """Plot the Brillouin zone.""" + q_pts = self._bs.qpoints # get labels and lines - labels = {} - for q_pt in self._bs.qpoints: - if q_pt.label: - labels[q_pt.label] = q_pt.frac_coords + labels = {q_pt.label: q_pt.frac_coords for q_pt in q_pts if q_pt.label} - lines = [] - for b in self._bs.branches: - lines.append([self._bs.qpoints[b["start_index"]].frac_coords, self._bs.qpoints[b["end_index"]].frac_coords]) + lines = [ + [q_pts[branch["start_index"]].frac_coords, q_pts[branch["end_index"]].frac_coords] + for branch in self._bs.branches + ] plot_brillouin_zone(self._bs.lattice_rec, lines=lines, labels=labels) @@ -693,8 +690,8 @@ def __init__(self, dos: PhononDos, structure: Structure = None) -> None: def _plot_thermo( self, - func, - temperatures: Sequence, + func: Callable[[float, Structure | None], float], + temperatures: Sequence[float], factor: float = 1, ax: Axes = None, ylabel: str | None = None, @@ -705,10 +702,11 @@ def _plot_thermo( """Plots a thermodynamic property for a generic function from a PhononDos instance. Args: - func: the thermodynamic function to be used to calculate the property - temperatures: a list of temperatures + func (Callable[[float, Structure | None], float]): Takes a temperature and structure (in that order) + and returns a thermodynamic property (e.g., heat capacity, entropy, etc.). + temperatures (list[float]): temperatures (in K) at which to evaluate func. factor: a multiplicative factor applied to the thermodynamic property calculated. Used to change - the units. + the units. Defaults to 1. ax: matplotlib Axes or None if a new figure should be created. ylabel: label for the y axis label: label of the plot @@ -722,8 +720,8 @@ def _plot_thermo( values = [] - for t in temperatures: - values.append(func(t, structure=self.structure) * factor) + for temp in temperatures: + values.append(func(temp, self.structure) * factor) ax.plot(temperatures, values, label=label, **kwargs) diff --git a/tests/electronic_structure/test_bandstructure.py b/tests/electronic_structure/test_bandstructure.py index c2415bde72d..7f3461edf3d 100644 --- a/tests/electronic_structure/test_bandstructure.py +++ b/tests/electronic_structure/test_bandstructure.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import json import unittest @@ -27,6 +28,13 @@ def setUp(self): self.lattice = Lattice.cubic(10.0) self.kpoint = Kpoint([0.1, 0.4, -0.5], self.lattice, label="X") + def test_eq(self): + assert self.kpoint == self.kpoint + assert self.kpoint == copy.deepcopy(self.kpoint) + assert self.kpoint != Kpoint([0.1, 0.4, -0.5], self.lattice, label="Y") + assert self.kpoint != Kpoint([0.1, 0.4, -0.6], self.lattice, label="X") + assert self.kpoint != Kpoint([0.1, 0.4, -0.5], Lattice.cubic(20.0), label="X") + def test_properties(self): assert list(self.kpoint.frac_coords) == [0.1, 0.4, -0.5] assert self.kpoint.a == 0.1 diff --git a/tests/phonon/test_bandstructure.py b/tests/phonon/test_bandstructure.py index 99299746b98..9e2ec8cb469 100644 --- a/tests/phonon/test_bandstructure.py +++ b/tests/phonon/test_bandstructure.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import json from numpy.testing import assert_allclose, assert_array_equal @@ -26,6 +27,12 @@ def test_repr(self): r"PhononBandStructureSymmLine(bands=(6, 130), labels=['$\\Gamma$', 'X', 'W', 'K', 'L', 'U'])" ) + def test_eq(self): + assert self.bs == self.bs + assert self.bs == copy.deepcopy(self.bs) + assert self.bs2 == self.bs2 + assert self.bs != self.bs2 + def test_basic(self): assert self.bs.bands[1][10] == approx(0.7753555184) assert self.bs.bands[5][100] == approx(5.2548379776) From 40afffb1e7f4ba863a6244b8417ac41c9102a363 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 23 Feb 2024 17:45:14 +0100 Subject: [PATCH 2/2] Fix `BSPlotterProjected.get_projected_plots_dots_patom_pmorb` fix set & list intersect (#3651) * fix set & list intersect in BSPlotterProjected by converting list->set * slightly improve BSPlotterProjected coverage --- pymatgen/electronic_structure/plotter.py | 2 +- tests/electronic_structure/test_plotter.py | 27 +++++++++++++++------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/pymatgen/electronic_structure/plotter.py b/pymatgen/electronic_structure/plotter.py index 69d70bbed56..f75c6fdfbd1 100644 --- a/pymatgen/electronic_structure/plotter.py +++ b/pymatgen/electronic_structure/plotter.py @@ -1737,7 +1737,7 @@ def _Orbitals_SumOrbitals(cls, dictio, sum_morbs): ) if orb not in all_orbitals: raise ValueError(f"The invalid name of orbital in 'sum_morbs[{elt}]' is given.") - if orb in individual_orbs and len(set(sum_morbs[elt]) & individual_orbs[orb]) != 0: + if orb in individual_orbs and len(set(sum_morbs[elt]) & set(individual_orbs[orb])) != 0: raise ValueError(f"The 'sum_morbs[{elt}]' contains orbitals repeated.") nelems = Counter(sum_morbs[elt]).values() if sum(nelems) > len(nelems): diff --git a/tests/electronic_structure/test_plotter.py b/tests/electronic_structure/test_plotter.py index a40c8f987a9..54b7e17e5a0 100644 --- a/tests/electronic_structure/test_plotter.py +++ b/tests/electronic_structure/test_plotter.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import numpy as np +import pytest from matplotlib import rc from numpy.testing import assert_allclose from pytest import approx @@ -182,19 +183,29 @@ class TestBSPlotterProjected(unittest.TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/Cu2O_361_bandstructure.json") as file: dct = json.load(file) - self.bs = BandStructureSymmLine.from_dict(dct) - self.plotter = BSPlotterProjected(self.bs) + self.bs_Cu2O = BandStructureSymmLine.from_dict(dct) + self.plotter_Cu2O = BSPlotterProjected(self.bs_Cu2O) + + with open(f"{TEST_FILES_DIR}/boltztrap2/PbTe_bandstructure.json") as file: + dct = json.load(file) + self.bs_PbTe = BandStructureSymmLine.from_dict(dct) - # Minimal baseline testing for get_plot. not a true test. Just checks that - # it can actually execute. def test_methods(self): - self.plotter.get_elt_projected_plots() - self.plotter.get_elt_projected_plots_color() - self.plotter.get_projected_plots_dots({"Cu": ["d", "s"], "O": ["p"]}) - self.plotter.get_projected_plots_dots_patom_pmorb( + # Minimal baseline testing for get_plot. not a true test. Just checks that + # it can actually execute. + self.plotter_Cu2O.get_elt_projected_plots() + self.plotter_Cu2O.get_elt_projected_plots_color() + self.plotter_Cu2O.get_projected_plots_dots({"Cu": ["d", "s"], "O": ["p"]}) + ax = self.plotter_Cu2O.get_projected_plots_dots_patom_pmorb( {"Cu": ["dxy", "s", "px"], "O": ["px", "py", "pz"]}, {"Cu": [3, 5], "O": [1]}, ) + assert isinstance(ax, plt.Axes) + assert len(ax.get_lines()) == 44_127 + assert ax.get_ylim() == pytest.approx((-4.0, 4.5047)) + + with pytest.raises(ValueError, match="try to plot projections on a band structure without any"): + self.plotter_PbTe = BSPlotterProjected(self.bs_PbTe) class TestBSDOSPlotter(unittest.TestCase):