-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsample_dist.py
64 lines (49 loc) · 2.27 KB
/
sample_dist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def _dirichlet_draw(sampler, params):
    """Call a Dirichlet-style sampler on ``params`` laid out as a column vector.

    The column layout mirrors the MATLAB ``[...]'`` calls this code was
    ported from.  NOTE(review): the project samplers ``randdirichlet`` /
    ``randdirichlet_unnorm`` are not visible here; confirm they expect a
    (K, 1) column.

    Returns the draw flattened to a 1-D array of length K, whatever
    orientation the sampler returns it in.
    """
    column = np.asarray(params, dtype=float).reshape(-1, 1)
    return np.ravel(sampler(column))


def sample_dist(stateCounts, hyperparams, Kextra, *, dirichlet=None,
                dirichlet_unnorm=None):
    """Sample transition, initial-state and mixture weights for every object.

    For each object ``ii`` (the last axis of ``stateCounts['Ns']``) this draws:

    * ``pi_z`` (Kz x Kz): row ``j`` from ``dirichlet_unnorm`` with parameters
      ``alpha0*beta + Ntemp[j, :]`` plus ``kappa0`` on entry ``j``
      (the self-transition term),
    * ``pi_init`` (1 x Kz): from ``dirichlet_unnorm`` with parameters
      ``alpha0*beta + <initial-state counts>``,
    * ``pi_s`` (Kz x Ks): row ``j`` from ``dirichlet`` with parameters
      ``Ns[j, :] + sigma0/Ks``,
    * ``pi_r`` (length Kr), only when ``stateCounts`` contains ``'Nr'``:
      from ``dirichlet`` with parameters ``Nr[ii, :] + eta0/Kr``.

    ``beta`` is fixed to all ones.  ``Kextra`` extra states with zero counts
    are appended, so the outputs cover ``Kz = Kz_prev + Kextra`` states.

    Args:
        stateCounts: dict with
            ``'N'``  - array indexed ``[i, j, obj]``; rows ``0..Kz_prev-1``
                       hold z_t=i -> z_{t+1}=j transition counts and row
                       ``Kz_prev`` holds the initial-state counts,
            ``'Ns'`` - array of shape ``(Kz_prev, Ks, numObj)``;
                       ``Ns[i, j, obj]`` = # s_t=j given z_t=i,
            ``'Nr'`` - optional, indexed ``[obj, :]``.
        hyperparams: dict with ``'alpha0'``, ``'kappa0'``, ``'sigma0'`` and,
            when ``'Nr'`` is present, ``'eta0'``.
        Kextra: number of additional (count-free) states to append.
        dirichlet: sampler for ``pi_s`` / ``pi_r``; defaults to the module's
            ``randdirichlet``.  Receives a (K, 1) column of parameters.
        dirichlet_unnorm: sampler for ``pi_z`` / ``pi_init``; defaults to the
            module's ``randdirichlet_unnorm``.  Receives a (K, 1) column.

    Returns:
        list with one dict per object, with keys ``'pi_z'``, ``'pi_init'``,
        ``'pi_s'`` and, if ``'Nr'`` was given, ``'pi_r'``.
    """
    # Resolve defaults at call time so callers can inject (e.g. seeded) samplers.
    if dirichlet is None:
        dirichlet = randdirichlet
    if dirichlet_unnorm is None:
        dirichlet_unnorm = randdirichlet_unnorm

    N = np.asarray(stateCounts['N'])
    Ns = np.asarray(stateCounts['Ns'])
    Kz_prev = Ns.shape[0]
    Ks = Ns.shape[1]
    numObj = Ns.shape[2]
    Kz = Kz_prev + Kextra

    alpha0 = hyperparams['alpha0']
    kappa0 = hyperparams['kappa0']
    sigma0 = hyperparams['sigma0']

    # beta is all ones, so the base-measure term is identical for every row.
    alpha_beta = alpha0 * np.ones(Kz)
    # Symmetric prior on the Ks mixture weights (sigma0/Ks equals sigma0 when
    # Ks == 1); the guard avoids dividing by zero when Ks == 0.
    sigma_vec = sigma0 / Ks if Ks > 1 else sigma0

    dist_struct = []
    for ii in range(numObj):
        # Pad counts out to Kz states; the Kextra new states have no counts.
        # Row Kz (the last row) of Ntemp holds the initial-state counts.
        Ntemp = np.zeros((Kz + 1, Kz))
        Ntemp[:Kz_prev, :Kz_prev] = N[:Kz_prev, :, ii]
        Ntemp[Kz, :Kz_prev] = N[Kz_prev, :, ii]
        Nstemp = np.zeros((Kz, Ks))
        Nstemp[:Kz_prev, :] = Ns[:, :, ii]

        pi_z = np.zeros((Kz, Kz))
        pi_s = np.zeros((Kz, Ks))
        for j in range(Kz):
            # Sticky prior: extra kappa0 mass on the self-transition j -> j.
            params = alpha_beta + Ntemp[j, :]
            params[j] += kappa0
            pi_z[j, :] = _dirichlet_draw(dirichlet_unnorm, params)
            pi_s[j, :] = _dirichlet_draw(dirichlet, Nstemp[j, :] + sigma_vec)
        # Kept as (1, Kz), matching the shape of the original zero placeholder.
        pi_init = _dirichlet_draw(
            dirichlet_unnorm, alpha_beta + Ntemp[Kz, :]).reshape(1, Kz)

        obj_dist = {'pi_z': pi_z, 'pi_init': pi_init, 'pi_s': pi_s}
        if 'Nr' in stateCounts:
            Nr = np.asarray(stateCounts['Nr'])[ii, :]  # Nr[i] = # r_t = i
            Kr = len(Nr)
            eta0 = hyperparams['eta0']
            obj_dist['pi_r'] = _dirichlet_draw(dirichlet, Nr + eta0 / Kr)
        dist_struct.append(obj_dist)
    return dist_struct