
Commit c0dac7f

Summarize model as rich table
1 parent 95ddad8 commit c0dac7f

3 files changed: +289 -0 lines

docs/api_reference.rst  (+9 lines)

@@ -74,3 +74,12 @@ Model Transforms
    autoreparam.vip_reparametrize
    autoreparam.VIP
+
+Printing
+========
+
+.. currentmodule:: pymc_experimental.printing
+.. autosummary::
+   :toctree: generated/
+
+   model_table

pymc_experimental/printing.py  (new file, +182 lines)

import numpy as np

from pymc import Model
from pymc.printing import str_for_dist, str_for_potential_or_deterministic
from pytensor import Mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.type import Constant, Variable
from rich.box import SIMPLE_HEAD
from rich.table import Table


def variable_expression(
    model: Model,
    var: Variable,
    truncate_deterministic: int | None,
) -> str:
    """Get the expression of a variable in a human-readable format."""
    if var in model.data_vars:
        var_expr = "Data"
    elif var in model.deterministics:
        str_repr = str_for_potential_or_deterministic(var, dist_name="")
        _, var_expr = str_repr.split(" ~ ")
        var_expr = var_expr[1:-1]  # Remove outer parentheses (f(...))
        if truncate_deterministic is not None and len(var_expr) > truncate_deterministic:
            contents = var_expr[2:-1].split(", ")
            str_len = 0
            for show_n, content in enumerate(contents):
                str_len += len(content) + 2
                if str_len > truncate_deterministic:
                    break
            var_expr = f"f({', '.join(contents[:show_n])}, ...)"
    elif var in model.potentials:
        var_expr = str_for_potential_or_deterministic(var, dist_name="Potential").split(" ~ ")[1]
    else:  # basic_RVs
        var_expr = str_for_dist(var).split(" ~ ")[1]
    return var_expr


def _extract_dim_value(var: SharedVariable | Constant) -> np.ndarray:
    if isinstance(var, SharedVariable):
        return var.get_value(borrow=True)
    else:
        return var.data


def dims_expression(model: Model, var: Variable) -> str:
    """Get the dimensions of a variable in a human-readable format."""
    if (dims := model.named_vars_to_dims.get(var.name)) is not None:
        dim_sizes = {dim: _extract_dim_value(model.dim_lengths[dim]) for dim in dims}
        return " × ".join(f"{dim}[{dim_size}]" for dim, dim_size in dim_sizes.items())
    else:
        dim_sizes = list(var.shape.eval(mode=Mode(linker="py", optimizer="fast_compile")))
        return f"[{', '.join(map(str, dim_sizes))}]" if dim_sizes else ""


def model_parameter_count(model: Model) -> int:
    """Count the number of parameters in the model."""
    rv_shapes = model.eval_rv_shapes()  # Includes transformed variables
    return np.sum([np.prod(rv_shapes[free_rv.name]).astype(int) for free_rv in model.free_RVs])


def model_table(
    model: Model,
    *,
    split_groups: bool = True,
    truncate_deterministic: int | None = None,
    parameter_count: bool = True,
) -> Table:
    """Create a rich table with a summary of the model's variables and their expressions.

    Parameters
    ----------
    model : Model
        The PyMC model to summarize.
    split_groups : bool
        If True, each group of variables (data, free_RVs, deterministics, potentials,
        observed_RVs) is separated into its own section.
    truncate_deterministic : int | None
        If not None, truncate the expression of deterministic variables that exceed this length.
    parameter_count : bool
        If True, add a row with the total number of parameters in the model.

    Returns
    -------
    Table
        A rich table with the model's variables, their expressions and dims.

    Examples
    --------
    .. code-block:: python

        import numpy as np
        import pymc as pm

        from pymc_experimental.printing import model_table

        coords = {"subject": range(20), "param": ["a", "b"]}
        with pm.Model(coords=coords) as m:
            x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param"))
            y = pm.Data("y", np.random.normal(size=(20,)), dims="subject")

            beta = pm.Normal("beta", mu=0, sigma=1, dims="param")
            mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject")
            sigma = pm.HalfNormal("sigma", sigma=1)

            y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject")

        table = model_table(m)
        table  # Displays the following table in an interactive environment
        '''
                Variable   Expression          Dimensions
         ─────────────────────────────────────────────────────
               x =   Data                subject[20] × param[2]
               y =   Data                subject[20]

            beta ~   Normal(0, 1)        param[2]
           sigma ~   HalfNormal(0, 1)
                                         Parameter count = 3

              mu =   f(beta)             subject[20]

           y_obs ~   Normal(mu, sigma)   subject[20]
        '''

    Output can be explicitly rendered in a rich console or exported to text, html or svg.

    .. code-block:: python

        from rich.console import Console

        console = Console(record=True)
        console.print(table)
        text_export = console.export_text()
        html_export = console.export_html()
        svg_export = console.export_svg()

    """
    table = Table(
        show_header=True,
        show_edge=False,
        box=SIMPLE_HEAD,
        highlight=False,
        collapse_padding=True,
    )
    table.add_column("Variable", justify="right")
    table.add_column("Expression", justify="left")
    table.add_column("Dimensions")

    if split_groups:
        groups = (
            model.data_vars,
            model.free_RVs,
            model.deterministics,
            model.potentials,
            model.observed_RVs,
        )
    else:
        # Show variables in the order they were defined
        groups = (model.named_vars.values(),)

    for group in groups:
        if not group:
            continue

        for var in group:
            var_name = var.name
            sep = f'[b]{" ~" if (var in model.basic_RVs) else " ="}[/b]'
            var_expr = variable_expression(model, var, truncate_deterministic)
            dims_expr = dims_expression(model, var)
            if dims_expr == "[]":
                dims_expr = ""
            table.add_row(var_name + sep, var_expr, dims_expr)

        if parameter_count and (not split_groups or group == model.free_RVs):
            n_parameters = model_parameter_count(model)
            table.add_row("", "", f"[i]Parameter count = {n_parameters}[/i]")

        table.add_section()

    return table
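
The "Parameter count" row comes from model_parameter_count, which sums the sizes of the free random variables only (Data containers, deterministics and observed variables are not counted). As a quick illustration, here is a minimal sketch (not part of the commit) that reproduces the count of 3 from the docstring example, where beta contributes 2 parameters and the scalar sigma contributes 1:

import numpy as np
import pymc as pm

from pymc_experimental.printing import model_parameter_count

# Same model as the docstring example above; only free RVs are counted,
# including automatically transformed ones such as the log of sigma.
coords = {"subject": range(20), "param": ["a", "b"]}
with pm.Model(coords=coords) as m:
    x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param"))
    y = pm.Data("y", np.random.normal(size=(20,)), dims="subject")

    beta = pm.Normal("beta", mu=0, sigma=1, dims="param")
    mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject")
    sigma = pm.HalfNormal("sigma", sigma=1)

    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject")

assert model_parameter_count(m) == 3  # beta (param[2]) + sigma (scalar)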

tests/test_printing.py  (new file, +98 lines)

import numpy as np
import pymc as pm

from rich.console import Console

from pymc_experimental.printing import model_table


def get_text(table) -> str:
    console = Console(width=80)
    with console.capture() as capture:
        console.print(table)
    return capture.get()


def test_model_table():
    with pm.Model(coords={"trial": range(6), "subject": range(20)}) as model:
        x_data = pm.Data("x_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
        y_data = pm.Data("y_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))

        mu = pm.Normal("mu", mu=0, sigma=1)
        sigma = pm.HalfNormal("sigma", sigma=1)
        global_intercept = pm.Normal("global_intercept", mu=0, sigma=1)
        intercept_subject = pm.Normal("intercept_subject", mu=0, sigma=1, shape=(20, 1))
        beta_subject = pm.Normal("beta_subject", mu=mu, sigma=sigma, dims="subject")

        mu_trial = pm.Deterministic(
            "mu_trial",
            global_intercept.squeeze() + intercept_subject + beta_subject * x_data,
            dims=["trial", "subject"],
        )
        noise = pm.Exponential("noise", lam=1)
        y = pm.Normal("y", mu=mu_trial, sigma=noise, observed=y_data, dims=("trial", "subject"))

        pm.Potential("beta_subject_penalty", -pm.math.abs(beta_subject), dims="subject")

    table_txt = get_text(model_table(model))
    expected = """              Variable   Expression                    Dimensions
 ────────────────────────────────────────────────────────────────────────────────
              x_data =   Data                          trial[6] × subject[20]
              y_data =   Data                          trial[6] × subject[20]

                  mu ~   Normal(0, 1)
               sigma ~   HalfNormal(0, 1)
    global_intercept ~   Normal(0, 1)
   intercept_subject ~   Normal(0, 1)                  [20, 1]
        beta_subject ~   Normal(mu, sigma)             subject[20]
               noise ~   Exponential(f())
                                                       Parameter count = 44

            mu_trial =   f(intercept_subject,          trial[6] × subject[20]
                          beta_subject,
                          global_intercept)

beta_subject_penalty =   Potential(f(beta_subject))    subject[20]

                   y ~   Normal(mu_trial, noise)       trial[6] × subject[20]
"""
    assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]

    table_txt = get_text(model_table(model, split_groups=False))
    expected = """              Variable   Expression                    Dimensions
 ────────────────────────────────────────────────────────────────────────────────
              x_data =   Data                          trial[6] × subject[20]
              y_data =   Data                          trial[6] × subject[20]
                  mu ~   Normal(0, 1)
               sigma ~   HalfNormal(0, 1)
    global_intercept ~   Normal(0, 1)
   intercept_subject ~   Normal(0, 1)                  [20, 1]
        beta_subject ~   Normal(mu, sigma)             subject[20]
            mu_trial =   f(intercept_subject,          trial[6] × subject[20]
                          beta_subject,
                          global_intercept)
               noise ~   Exponential(f())
                   y ~   Normal(mu_trial, noise)       trial[6] × subject[20]
beta_subject_penalty =   Potential(f(beta_subject))    subject[20]
                                                       Parameter count = 44
"""
    assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]

    table_txt = get_text(
        model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False)
    )
    expected = """              Variable   Expression                    Dimensions
 ────────────────────────────────────────────────────────────────────────────
              x_data =   Data                          trial[6] × subject[20]
              y_data =   Data                          trial[6] × subject[20]
                  mu ~   Normal(0, 1)
               sigma ~   HalfNormal(0, 1)
    global_intercept ~   Normal(0, 1)
   intercept_subject ~   Normal(0, 1)                  [20, 1]
        beta_subject ~   Normal(mu, sigma)             subject[20]
            mu_trial =   f(intercept_subject, ...)     trial[6] × subject[20]
               noise ~   Exponential(f())
                   y ~   Normal(mu_trial, noise)       trial[6] × subject[20]
beta_subject_penalty =   Potential(f(beta_subject))    subject[20]
"""
    assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]
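
The test captures the rendered output with a fixed-width console; outside of pytest, the same table can instead be recorded and exported, as the model_table docstring shows. A minimal sketch (not part of the test suite; the toy demo_model here is only for illustration):

import pymc as pm

from rich.console import Console

from pymc_experimental.printing import model_table

# Hypothetical toy model, only to have something to render
with pm.Model(coords={"subject": range(20)}) as demo_model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    pm.Normal("y", mu=mu, sigma=1, dims="subject")

console = Console(record=True, width=80)
console.print(model_table(demo_model))
text_export = console.export_text()  # export_html() and export_svg() work the same way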
