Skip to content

Commit

Permalink
Priors visualizer (#278)
Browse files Browse the repository at this point in the history
* checkpoint, adding overview plot done in plt instead of plotly

* checkpoint fixing the overview plots and adding the pairwise code

* updating comments to match line_length limits, adding mcmc chain plot

* checkpoint, integrating vis_utils into default behavior of abstract_azure_runner

* increasing fig size of overview

* changing size of the correlation_pairs plot

* tight bounding boxes to avoid text cutoff

* adding plotly back since it is still used by the azure visualizer for now

* checkpoint, first draft

* adding a viz for prior distributions

* lowering number of samples for prior distributions
  • Loading branch information
arik-shurygin authored Nov 6, 2024
1 parent 25b3c58 commit 8eb45f3
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 4 deletions.
29 changes: 26 additions & 3 deletions src/mechanistic_azure/shiny_visualizers/azure_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
# this will reduce the time it takes to load the azure connection, but only shows
# one experiment worth of data, which may be what you want...
# leave empty ("") to explore all experiments
PRE_FILTER_EXPERIMENTS = (
"example_azure_experiment" # fifty_state_season2_5strain_2202_2404
)
PRE_FILTER_EXPERIMENTS = ""
# when loading the overview timelines csv for each run, columns
# are expected to have names corresponding to the type of plot they create
# vaccination_0_17 specifies the vaccination_ plot type, multiple columns may share
Expand Down Expand Up @@ -170,6 +168,12 @@
"Sample Violin Plots",
output_widget("plot_sample_violins"),
),
ui.nav_panel(
"Config Visualizer",
ui.output_plot(
"plot_prior_distributions", width=1600, height=1600
),
),
),
),
)
Expand Down Expand Up @@ -369,6 +373,25 @@ def plot_sample_correlations():
print("displaying correlations plot")
return fig

@output(id="plot_prior_distributions")
@render.plot
@reactive.event(input.action_button)
def plot_prior_distributions():
exp = input.experiment()
job_id = input.job_id()
states = input.states()
scenario = input.scenario()
theme = input.dark_mode()
theme = sutils.shiny_to_matplotlib_theme(theme)
cache_paths = sutils.get_azure_files(
exp, job_id, states, scenario, azure_client, SHINY_CACHE_PATH
)
# we have the figure, now update the light/dark mode depending on the switch
fig = sutils.load_prior_distributions_plot(cache_paths[0], theme)
# we have the figure, now update the light/dark mode depending on the switch
print("displaying prior distributions")
return fig

@output(id="plot_sample_violins")
@render_widget
@reactive.event(input.action_button)
Expand Down
34 changes: 34 additions & 0 deletions src/mechanistic_azure/shiny_visualizers/shiny_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tqdm import tqdm

from mechanistic_azure.azure_utilities import download_directory_from_azure
from resp_ode import Config, vis_utils
from resp_ode.utils import drop_keys_with_substring, flatten_list_parameters


Expand Down Expand Up @@ -309,6 +310,22 @@ def load_checkpoint_inference_chains(
return fig


def load_prior_distributions_plot(cache_path, matplotlib_theme):
path = os.path.join(cache_path, "config_inferer_used.json")
if os.path.exists(path):
config = Config(open(path).read())
styles = ["seaborn-v0_8-colorblind", matplotlib_theme]
fig = vis_utils.plot_prior_distributions(
config.asdict(), matplotlib_style=styles
)
else:
raise FileNotFoundError(
"%s does not exist, either the experiment did "
"not save a config used or loading files failed" % path
)
return fig


def load_checkpoint_inference_correlations(
cache_path,
overview_subplot_size: int,
Expand Down Expand Up @@ -855,3 +872,20 @@ def shiny_to_plotly_theme(shiny_theme: str):
plotly theme as str, used in `fig.update_layout(template=theme)`
"""
return "plotly_%s" % (shiny_theme if shiny_theme == "dark" else "white")


def shiny_to_matplotlib_theme(shiny_theme: str):
"""shiny themes are "dark" and "light", plotly themes are
"plotly_dark" and "plotly_white", this function converts from shiny to plotly theme names
Parameters
----------
shiny_theme : str
shiny theme as str
Returns
-------
str
plotly theme as str, used in `fig.update_layout(template=theme)`
"""
return "dark_background" if shiny_theme == "dark" else "ggplot"
3 changes: 3 additions & 0 deletions src/resp_ode/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def add_file(self, config_json_str):
self.set_downstream_parameters()
return self

def asdict(self):
return self.__dict__

def convert_types(self, config):
"""
takes a dictionary of config parameters, consults the PARAMETERS global list and attempts to convert the type
Expand Down
109 changes: 108 additions & 1 deletion src/resp_ode/vis_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
"""A series of utility functions for generating visualizations for the model"""

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from jax.random import PRNGKey
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap

from .utils import drop_keys_with_substring, flatten_list_parameters
from .utils import (
drop_keys_with_substring,
flatten_list_parameters,
identify_distribution_indexes,
)


class VisualizationError(Exception):
pass


def _cleanup_and_normalize_timelines(
Expand Down Expand Up @@ -460,3 +470,100 @@ def plot_mcmc_chains(
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc="outside upper center")
return fig


def plot_prior_distributions(
priors: dict[str],
matplotlib_style: list[str]
| str = [
"seaborn-v0_8-colorblind",
],
num_samples=5000,
hist_kwargs={"bins": 50, "density": True},
) -> plt.Figure:
"""Given a dictionary of parameter keys and possibly values of
numpyro.distribution objects, samples them a number of times
and returns a plot of those samples to help
visualize the range of values taken by that prior distribution.
Parameters
----------
priors : dict[str: Any]
a dictionary with str keys possibly containing distribution
objects as values. Each key with a distribution object type
key will be included in the plot
matplotlib_style : list[str] | str, optional
matplotlib style to plot in by default ["seaborn-v0_8-colorblind"]
num_samples: int, optional
the number of times to sample each distribution, mild impact on
figure performance. By default 50000
hist_kwargs: dict[str: Any]
additional kwargs passed to plt.hist(), by default {"bins": 50}
Returns
-------
plt.Figure
matplotlib figure that is roughly square containing all distribution
keys found within priors.
"""
dist_only = {}
d = identify_distribution_indexes(priors)
# filter down to just the distribution objects
for dist_name, locator_dct in d.items():
parameter_name = locator_dct["sample_name"]
parameter_idx = locator_dct["sample_idx"]
# if the sample is on its own, not nested in a list, sample_idx is none
if parameter_idx is None:
dist_only[parameter_name] = priors[parameter_name]
# otherwise this sample is nested in a list and should be retrieved
else:
# go in index by index to access multi-dimensional lists
temp = priors[parameter_name]
for i in parameter_idx:
temp = temp[i]
dist_only[dist_name] = temp
param_names = list(dist_only.keys())
num_params = len(param_names)
if num_params == 0:
raise VisualizationError(
"Attempted to visualize a config without any distributions"
)
# Calculate the number of rows and columns for a square-ish layout
num_cols = int(np.ceil(np.sqrt(num_params)))
num_rows = int(np.ceil(num_params / num_cols))
with plt.style.context(matplotlib_style):
fig, axs = plt.subplots(
num_rows,
num_cols,
figsize=(3 * num_cols, 3 * num_rows),
squeeze=False,
)
# Flatten the axis array for easy indexing
axs_flat = axs.flatten()
# Loop over each parameter and sample
for i, param_name in enumerate(param_names):
ax: Axes = axs_flat[i]
ax.set_title(param_name)
dist = dist_only[param_name]
samples = dist.sample(PRNGKey(0), sample_shape=(num_samples,))
ax.hist(samples, **hist_kwargs)
ax.axvline(
samples.mean(),
linestyle="dashed",
linewidth=1,
label="mean",
)
ax.axvline(
jnp.median(samples),
linestyle="dotted",
linewidth=3,
label="median",
)
# Turn off any unused subplots
for j in range(i + 1, len(axs_flat)):
axs_flat[j].axis("off")
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc="outside upper right")
fig.suptitle("Prior Distributions Visualized, n=%s" % num_samples)
plt.tight_layout()
return fig

0 comments on commit 8eb45f3

Please sign in to comment.