Skip to content

Commit

Permalink
feat: refactor histplot + new get_plottables function (#534)
Browse files Browse the repository at this point in the history
  • Loading branch information
0ctagon authored Oct 31, 2024
1 parent a646154 commit 3077cde
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 159 deletions.
2 changes: 2 additions & 0 deletions src/mplhep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
yscale_legend,
)
from .styles import set_style
from .utils import get_plottables

# Configs
rcParams = Config(
Expand Down Expand Up @@ -76,4 +77,5 @@
"sort_legend",
"save_variations",
"set_style",
"get_plottables",
]
171 changes: 14 additions & 157 deletions src/mplhep/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from mpl_toolkits.axes_grid1 import axes_size, make_axes_locatable

from .utils import (
Plottable,
align_marker,
get_histogram_axes_title,
get_plottable_protocol_bins,
get_plottables,
hist_object_handler,
isLight,
process_histogram_parts,
Expand Down Expand Up @@ -198,86 +198,18 @@ def histplot(
else get_histogram_axes_title(hists[0].axes[0])
)

plottables = []
flow_bins = final_bins
for h in hists:
value, variance = np.copy(h.values()), h.variances()
if has_variances := variance is not None:
variance = np.copy(variance)
underflow, overflow = 0.0, 0.0
underflowv, overflowv = 0.0, 0.0
# One sided flow bins - hist (uproot hist does not have the over- or underflow traits)
if (
hasattr(h, "axes")
and (traits := getattr(h.axes[0], "traits", None)) is not None
and hasattr(traits, "underflow")
and hasattr(traits, "overflow")
):
if traits.overflow:
overflow = np.copy(h.values(flow=True))[-1]
if has_variances:
overflowv = np.copy(h.variances(flow=True))[-1]
if traits.underflow:
underflow = np.copy(h.values(flow=True))[0]
if has_variances:
underflowv = np.copy(h.variances(flow=True))[0]
# Both flow bins exist - uproot
elif hasattr(h, "values") and "flow" in inspect.getfullargspec(h.values).args:
if len(h.values()) + 2 == len(
h.values(flow=True)
): # easy case, both over/under
underflow, overflow = (
np.copy(h.values(flow=True))[0],
np.copy(h.values(flow=True))[-1],
)
if has_variances:
underflowv, overflowv = (
np.copy(h.variances(flow=True))[0],
np.copy(h.variances(flow=True))[-1],
)

# Set plottables
if flow in ("none", "hint"):
plottables.append(Plottable(value, edges=final_bins, variances=variance))
elif flow == "show":
_flow_bin_size: float = np.max(
[0.05 * (final_bins[-1] - final_bins[0]), np.mean(np.diff(final_bins))]
)
flow_bins = np.copy(final_bins)
if underflow > 0:
flow_bins = np.r_[flow_bins[0] - _flow_bin_size, flow_bins]
value = np.r_[underflow, value]
if has_variances:
variance = np.r_[underflowv, variance]
if overflow > 0:
flow_bins = np.r_[flow_bins, flow_bins[-1] + _flow_bin_size]
value = np.r_[value, overflow]
if has_variances:
variance = np.r_[variance, overflowv]
plottables.append(Plottable(value, edges=flow_bins, variances=variance))
elif flow == "sum":
if underflow > 0:
value[0] += underflow
if has_variances:
variance[0] += underflowv
if overflow > 0:
value[-1] += overflow
if has_variances:
variance[-1] += overflowv
plottables.append(Plottable(value, edges=final_bins, variances=variance))
else:
plottables.append(Plottable(value, edges=final_bins, variances=variance))

if w2 is not None:
for _w2, _plottable in zip(
w2.reshape(len(plottables), len(final_bins) - 1), plottables
):
_plottable.variances = _w2
_plottable.method = w2method

if w2 is not None and yerr is not None:
msg = "Can only supply errors or w2"
raise ValueError(msg)
plottables, flow_info = get_plottables(
hists,
bins=final_bins,
w2=w2,
w2method=w2method,
yerr=yerr,
stack=stack,
density=density,
binwnorm=binwnorm,
flow=flow,
)
flow_bins, underflow, overflow = flow_info

_labels: list[str | None]
if label is None:
Expand Down Expand Up @@ -311,52 +243,6 @@ def iterable_not_string(arg):
for i in range(len(_chunked_kwargs)):
_chunked_kwargs[i][kwarg] = kwargs[kwarg]

############################
# # yerr calculation
_yerr: np.ndarray | None
if yerr is not None:
# yerr is array
if hasattr(yerr, "__len__"):
_yerr = np.asarray(yerr)
# yerr is a number
elif isinstance(yerr, (int, float)) and not isinstance(yerr, bool):
_yerr = np.ones((len(plottables), len(final_bins) - 1)) * yerr
# yerr is automatic
else:
_yerr = None
else:
_yerr = None

if _yerr is not None:
assert isinstance(_yerr, np.ndarray)
if _yerr.ndim == 3:
# Already correct format
pass
elif _yerr.ndim == 2 and len(plottables) == 1:
# Broadcast ndim 2 to ndim 3
if _yerr.shape[-2] == 2: # [[1,1], [1,1]]
_yerr = _yerr.reshape(len(plottables), 2, _yerr.shape[-1])
elif _yerr.shape[-2] == 1: # [[1,1]]
_yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1])
else:
msg = "yerr format is not understood"
raise ValueError(msg)
elif _yerr.ndim == 2:
# Broadcast yerr (nh, N) to (nh, 2, N)
_yerr = np.tile(_yerr, 2).reshape(len(plottables), 2, _yerr.shape[-1])
elif _yerr.ndim == 1:
# Broadcast yerr (1, N) to (nh, 2, N)
_yerr = np.tile(_yerr, 2 * len(plottables)).reshape(
len(plottables), 2, _yerr.shape[-1]
)
else:
msg = "yerr format is not understood"
raise ValueError(msg)

assert _yerr is not None
for yrs, _plottable in zip(_yerr, plottables):
_plottable.fixed_errors(*yrs)

# Sorting
if sort is not None:
if isinstance(sort, str):
Expand All @@ -379,34 +265,6 @@ def iterable_not_string(arg):
_chunked_kwargs = [_chunked_kwargs[ix] for ix in order]
_labels = [_labels[ix] for ix in order]

# ############################
# # Stacking, norming, density
if density is True and binwnorm is not None:
msg = "Can only set density or binwnorm."
raise ValueError(msg)
if density is True:
if stack:
_total = np.sum(
np.array([plottable.values for plottable in plottables]), axis=0
)
for plottable in plottables:
plottable.flat_scale(1.0 / np.sum(np.diff(final_bins) * _total))
else:
for plottable in plottables:
plottable.density = True
elif binwnorm is not None:
for plottable, norm in zip(
plottables, np.broadcast_to(binwnorm, (len(plottables),))
):
plottable.flat_scale(norm)
plottable.binwnorm()

# Stack
if stack and len(plottables) > 1:
from .utils import stack as stack_fun

plottables = stack_fun(*plottables)

##########
# Plotting
return_artists: list[StairsArtists | ErrorBarArtists] = []
Expand Down Expand Up @@ -443,8 +301,7 @@ def iterable_not_string(arg):
if "step" in histtype:
for i in range(len(plottables)):
do_errors = yerr is not False and (
(yerr is not None or w2 is not None)
or (plottables[i].variances is not None)
(yerr is not None or w2 is not None) or plottables[i]._has_variances
)

_kwargs = _chunked_kwargs[i]
Expand Down
Loading

0 comments on commit 3077cde

Please sign in to comment.