-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbishop_generative_model.py
49 lines (34 loc) · 1.2 KB
/
bishop_generative_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
""" Reproduce figure 12.9 Bishop's PRML
"""
import torch
from torch import nn
from torch.distributions import constraints
import pyro
from pyro import poutine
from pyro.distributions import Normal, LogNormal, Dirichlet, Categorical, Gamma
from pyro.contrib.autoguide import *
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate
from matplotlib import pyplot as plt
import seaborn as sns
from common.plot.scatter import config_font_size
# Directory that plot_dist() saves its PNG figures into.
PLOT_DIR = "./plots/esann2019"
# NOTE: all distributions in this file are worked with in 1D and 2D.
def plot_dist(samples, label="p(?)", name=None):
    """Plot the distribution of ``samples`` and save it as a PNG in PLOT_DIR.

    The plot is drawn with ``sns.distplot`` on the current matplotlib axes,
    a legend is added, and the figure is written to
    ``{PLOT_DIR}/{name or label}.png``.

    Args:
        samples: Values handed directly to ``sns.distplot``.
        label: Legend label for the plotted series; also used as the file
            stem when ``name`` is falsy.
        name: File stem of the saved image (without the ``.png`` suffix).
    """
    import os  # function-scope import keeps this fix self-contained

    sns.distplot(samples, label=label)
    plt.legend()

    # savefig does not create missing parent directories, so make sure the
    # output directory exists before writing into it.
    os.makedirs(PLOT_DIR, exist_ok=True)
    plt.savefig(f"{PLOT_DIR}/{name or label}.png", bbox_inches="tight")
def plot_ppca_model_2D(N=1000, D=2, M=1):
    """Sample z and w from standard-normal priors and plot the samples of z.

    Draws ``N`` scalar samples of z, draws a single ``[M, D]`` matrix w,
    prints w, and saves a plot of z through ``plot_dist`` under the name
    ``z_hist1D``.

    Args:
        N: Number of samples of z to draw.
        D: Number of columns of w.
        M: Number of rows of w.
    """
    # z: [N] -- N independent scalar draws from a standard normal.
    p_z = Normal(loc=0.0, scale=1.0)
    z = p_z.sample(sample_shape=(N,))

    # w: [M, D]
    # to_event() takes the NUMBER of rightmost batch dims to reinterpret as
    # event dims, not the size of a dimension. The batch shape [M, D] always
    # has two dims, so the count is 2 regardless of the values of M and D
    # (passing D only happened to work for the default D == 2).
    p_w = Normal(loc=torch.zeros([M, D]), scale=torch.ones([M, D])).to_event(2)
    w = p_w.sample()
    print(w)

    plot_dist(z, label="p(z)", name="z_hist1D")
if __name__ == "__main__":
    # Seed Pyro's RNG so the sampled values are reproducible across runs.
    pyro.set_rng_seed(2019)
    n_samples = 1000

    # Project helper from common.plot.scatter; presumably adjusts plot font
    # sizes -- confirm against its definition.
    config_font_size(min_size=12)

    # Pass the sample count explicitly instead of leaving n_samples unused;
    # it equals the function's default, so the output is unchanged.
    plot_ppca_model_2D(N=n_samples)