Skip to content

Commit

Permalink
chore: precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
andrzejnovak committed Jul 11, 2022
1 parent c74e16b commit 9a1346e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 22 deletions.
42 changes: 21 additions & 21 deletions src/cabinetry/visualize/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""High-level entry point for visualizing fit models and inference results."""

import fnmatch
import glob
import logging
import pathlib
Expand All @@ -16,6 +15,7 @@
from cabinetry.templates import builder
from cabinetry.visualize import plot_model
from cabinetry.visualize import plot_result
from cabinetry.visualize.utils import _exclude_matching


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -387,6 +387,8 @@ def correlation_matrix(
*,
figure_folder: Union[str, pathlib.Path] = "figures",
pruning_threshold: float = 0.0,
exclude: Optional[Union[str, List[str], Tuple[str, ...]]] = None,
exclude_by_type: Optional[List[str]] = None,
close_figure: bool = True,
save_figure: bool = True,
) -> mpl.figure.Figure:
Expand All @@ -399,6 +401,12 @@ def correlation_matrix(
figures in, defaults to "figures"
pruning_threshold (float, optional): minimum correlation for a parameter to
have with any other parameters to not get pruned, defaults to 0.0
exclude (Optional[Union[str, List[str], Tuple[str, ...]]], optional): parameter
or parameters to exclude from plot, defaults to None (nothing excluded),
compatible with unix wildcards
exclude_by_type (Optional[Union[str, List[str], Tuple[str, ...]]], optional):
exclude parameters of the given type, defaults ``['staterror']`` filtering
out mc_stat uncertainties which are centered on 1
close_figure (bool, optional): whether to close figure, defaults to True
save_figure (bool, optional): whether to save figure, defaults to True
Expand All @@ -423,7 +431,14 @@ def correlation_matrix(
fixed_parameter = np.all(np.equal(fit_results.corr_mat, 0.0), axis=0)
# get indices of rows/columns where everything is below threshold, or the parameter
# is fixed
delete_indices = np.where(np.logical_or(all_below_threshold, fixed_parameter))
exclude_set = _exclude_matching(
fit_results, exclude=exclude, exclude_by_type=exclude_by_type
)
exclude_indices = np.array(
[1 if lab in exclude_set else 0 for lab in fit_results.labels]
).astype(bool)
#
delete_indices = np.where(all_below_threshold | fixed_parameter | exclude_indices)
# delete rows and columns where all correlations are below threshold / parameter is
# fixed
corr_mat = np.delete(
Expand Down Expand Up @@ -472,13 +487,9 @@ def pulls(
[True if ty in ["normfactor"] else False for ty in fit_results.types]
)

if exclude is None:
exclude_set = set()
elif isinstance(exclude, str):
exclude_set = set(fnmatch.filter(fit_results.labels, exclude))
else:
exclude_set = set(exclude)

exclude_set = _exclude_matching(
fit_results, exclude=exclude, exclude_by_type=exclude_by_type
)
# exclude fixed parameters from pull plot
exclude_set.update(
[
Expand All @@ -488,19 +499,8 @@ def pulls(
]
)

# exclude by type
if exclude_by_type is None:
exclude_by_type = ["staterror"]
exclude_set.update(
[
label
for label, kind in zip(labels_np, fit_results.types)
if kind in exclude_by_type
]
)

# filter out user-specified parameters
mask = [True if label not in exclude_set else False for label in labels_np]
mask = [label not in exclude_set for label in labels_np]
bestfit = fit_results.bestfit[mask]
uncertainty = fit_results.uncertainty[mask]
labels_np = labels_np[mask]
Expand Down
36 changes: 35 additions & 1 deletion src/cabinetry/visualize/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Provides visualization utilities."""

import fnmatch
import logging
import pathlib
from typing import Optional
from typing import List, Optional, Set, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt

from cabinetry import fit


log = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,3 +49,34 @@ def _log_figure_path(path: Optional[pathlib.Path]) -> Optional[pathlib.Path]:
if path is not None:
return path.with_name(path.stem + "_log" + path.suffix)
return None


def _exclude_matching(
fit_results: fit.FitResults,
*,
exclude: Optional[Union[str, List[str], Tuple[str, ...]]] = None,
exclude_by_type: Optional[List[str]] = None,
) -> Set[str]:

labels = fit_results.labels
types = fit_results.types

if exclude is None:
exclude_set = set()
elif isinstance(exclude, str):
exclude_set = set(fnmatch.filter(labels, exclude))
elif isinstance(exclude, (list, tuple)):
exclude_set = set().union(
*[set(fnmatch.filter(labels, match_str)) for match_str in exclude]
)
else:
raise TypeError("exclude must be a string, list, or tuple")

# exclude by type
if exclude_by_type is None:
exclude_by_type = ["staterror"]

exclude_set.update(
[label for label, kind in zip(labels, types) if kind in exclude_by_type]
)
return exclude_set

0 comments on commit 9a1346e

Please sign in to comment.