Skip to content

Commit

Permalink
Changes to Formation Energy Diagram Plotting (#140)
Browse files Browse the repository at this point in the history
* make element_change more general

* grouping and plotting

* minor changes to plotting

* fixed test
  • Loading branch information
jmmshn authored Aug 18, 2023
1 parent ec4f507 commit eb8e254
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 45 deletions.
10 changes: 10 additions & 0 deletions pymatgen/analysis/defects/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,16 @@ def defect_type(self) -> int:
"""
return getattr(DefectType, self.__class__.__name__)

@property
def latex_name(self) -> str:
"""Get the latex name of the defect.
Returns:
str: The latex name of the defect.
"""
root, suffix = self.name.split("_")
return rf"{root} $_{{\rm {suffix}}}$"


class Vacancy(Defect):
"""Class representing a vacancy defect."""
Expand Down
111 changes: 66 additions & 45 deletions pymatgen/analysis/defects/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

import logging
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from monty.json import MSONable
from numpy.typing import ArrayLike, NDArray
from pymatgen.analysis.chempot_diagram import ChemicalPotentialDiagram
Expand Down Expand Up @@ -721,7 +720,9 @@ def _get_name(entry):


def group_formation_energy_diagrams(
feds: list[FormationEnergyDiagram], sm: StructureMatcher = None
feds: list[FormationEnergyDiagram],
sm: StructureMatcher = None,
combine_diagrams: bool = True,
):
"""Group formation energy diagrams by their representation.
Expand All @@ -730,9 +731,12 @@ def group_formation_energy_diagrams(
Args:
feds: list of formation energy diagrams
sm: StructureMatcher to use for grouping
combine_diagrams: whether to combine matching diagrams into a single diagram
Returns:
Generator of (name, list of formation energy diagrams) tuples
If combine_diagrams is True, generator of (name, combined formation energy diagram) tuples.
If combine_diagrams is False, generator of (name, list of formation energy diagrams) tuples.
"""
if sm is None:
sm = StructureMatcher(comparator=ElementComparator())
Expand All @@ -746,8 +750,15 @@ def _get_name(fed):
fed_groups = group_docs(
feds, sm=sm, get_structure=_get_structure, get_hash=_get_name
)
for g_name, g_entries in fed_groups:
yield g_name, g_entries
for g_name, f_group in fed_groups:
if combine_diagrams:
fed = f_group[0]
fed_d = fed.as_dict()
dents = [dfed.defect_entries for dfed in f_group]
fed_d["defect_entries"] = list(chain.from_iterable(dents))
yield g_name, FormationEnergyDiagram.from_dict(fed_d)
else:
yield g_name, f_group


def ensure_stable_bulk(
Expand Down Expand Up @@ -960,6 +971,7 @@ def plot_formation_energy_diagrams(
band_edge_color="k",
filterfunction: Callable | None = None,
legend_loc: str = "lower center",
show_legend: bool = True,
axis=None,
):
"""Plot the formation energy diagram.
Expand Down Expand Up @@ -1021,25 +1033,19 @@ def plot_formation_energy_diagrams(
else:
xmin, xmax = xlim
ymin, ymax = np.inf, -np.inf
legends_txt = []
artists = []
legends_txt: list = []
artists: list = []
fontwidth = 12
ax_fontsize = 1.3
lg_fontsize = 10

colors = (
colors
if colors
else cm.Dark2(np.linspace(0, 1, len(formation_energy_diagrams)))
if len(formation_energy_diagrams) <= 8
else cm.gist_rainbow(np.linspace(0, 1, len(formation_energy_diagrams)))
)
named_feds = []
for name_, feds_ in group_formation_energy_diagrams(formation_energy_diagrams):
for fed_ in feds_:
named_feds.append((name_, fed_))
for name_, fed_ in group_formation_energy_diagrams(formation_energy_diagrams):
named_feds.append((name_, fed_))

color_line_gen = _get_line_color_and_style(colors, linestyle)
for fid, (fed_name, single_fed) in enumerate(named_feds):
cur_color, cur_style = next(color_line_gen)
chempots_ = (
chempots
if chempots
Expand All @@ -1055,34 +1061,35 @@ def plot_formation_energy_diagrams(
trans_y = trans[:, 1]
ymin = min(ymin, min(trans_y))
ymax = max(ymax, max(trans_y))
axis.plot(

dfct: Defect = single_fed.defect_entries[0].defect
latexname = dfct.latex_name
if legend_prefix is not None:
latexname = f"{legend_prefix} {latexname}"

if ":" in fed_name:
latexname += f" ({fed_name.split(':')[1]})"

(l,) = axis.plot(
np.subtract(trans[:, 0], alignment),
trans_y,
color=colors[fid],
ls=linestyle,
color=cur_color,
ls=cur_style,
lw=linewidth,
alpha=envelope_alpha,
label=fed_name,
label=latexname,
marker=transition_marker,
markersize=transition_markersize,
)
if not only_lower_envelope:
cur_color = l.get_color()
for line in lines:
x = np.linspace(xmin, xmax)
y = line[0] * x + line[1]
axis.plot(
np.subtract(x, alignment), y, color=colors[fid], alpha=line_alpha
np.subtract(x, alignment), y, color=cur_color, alpha=line_alpha
)

# get latex-like legend titles
dfct = single_fed.defect_entries[0].defect
flds = dfct.name.split("_")
latexname = f"${flds[0]}_{{{flds[1]}}}$"
if legend_prefix:
latexname = f"{legend_prefix} {latexname}"
legends_txt.append(latexname)
artists.append(Line2D([0], [0], color=colors[fid], lw=4))

axis.set_xlim(xmin, xmax)
axis.set_ylim(ylim[0] if ylim else ymin - 0.1, ylim[1] if ylim else ymax + 0.1)
axis.set_xlabel("Fermi energy (eV)", size=ax_fontsize * fontwidth)
Expand Down Expand Up @@ -1122,19 +1129,20 @@ def plot_formation_energy_diagrams(
alpha=0.8,
)

lg = axis.get_legend()
if lg:
handle, leg = lg.legendHandles, [txt._text for txt in lg.texts]
else:
handle, leg = [], []

axis.legend(
handles=artists + handle,
labels=legends_txt + leg,
fontsize=lg_fontsize * ax_fontsize,
ncol=3,
loc=legend_loc,
)
if show_legend:
lg = axis.get_legend()
if lg:
handle, leg = lg.legendHandles, [txt._text for txt in lg.texts]
else:
handle, leg = [], []

axis.legend(
handles=artists + handle,
labels=legends_txt + leg,
fontsize=lg_fontsize * ax_fontsize,
ncol=3,
loc=legend_loc,
)

if save:
save = save if isinstance(save, str) else "formation_energy_diagram.png"
Expand All @@ -1143,3 +1151,16 @@ def plot_formation_energy_diagrams(
plt.show()

return axis


def _get_line_color_and_style(colors=None, styles=None):
if colors is None:
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
if styles is None:
styles = ["-", "--", "-.", ":"]
else:
styles = [styles] if isinstance(styles, str) else styles

for style in styles:
for color in colors:
yield color, style

0 comments on commit eb8e254

Please sign in to comment.