-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinitializeStructs.py
131 lines (86 loc) · 6.19 KB
/
initializeStructs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def initializeStructs(F, model, data_struct, settings):
    """Allocate zero-filled parameter, statistic and count structures for a sampler run.

    Port of a MATLAB routine: MATLAB structs are represented as dicts and the
    struct array ``data_struct`` as a list of dicts.

    Args:
        F: 2-D array; its number of columns is taken as the number of modes
            ``Kz``.
        model: dict whose ``model['obsModel']`` provides ``'type'`` (one of
            ``'Gaussian'``, ``'Multinomial'``, ``'AR'``, ``'SLDS'``) and
            ``'params'`` (prior parameters). AR/SLDS models also need
            ``'priorType'``; SLDS needs ``'y_priorType'``; AR needs ``'r'``
            whenever a sequence lacks a precomputed ``'X'``. For SLDS,
            ``obsModel['r']`` is set to 1 and, if ``settings`` has no ``'Kr'``,
            ``model['HMMmodel']['params']['a_eta']`` and ``['b_eta']`` are set
            to 1 (creating the nested dicts if needed).
        data_struct: list of dicts, one per sequence, each holding an
            ``'obs'`` array whose columns are time steps. Mutated in place:
            ``'blockSize'`` defaults to one step per block and ``'blockEnd'``
            is its cumulative sum; AR/SLDS sequences without ``'X'`` get a
            design matrix from ``makeDesignMatrix`` and have ``'obs'``,
            ``'X'`` and ``'true_labels'`` restricted to the valid steps.
        settings: dict with ``'Ks'``, ``'saveEvery'``, ``'storeEvery'`` and,
            for SLDS, optionally ``'Kr'``.

    Returns:
        ``(theta, Ustats, stateCounts, data_struct, model, S)``: zero-filled
        parameter and sufficient-statistic dicts, a dict of zero count arrays,
        the updated inputs, and a list of ``saveEvery // storeEvery`` empty
        sample-record dicts (independent objects).

    Raises:
        ValueError: on non-scalar multinomial observations, ``Ks != 1`` for an
            AR/SLDS model, or an unknown observation/prior type.
    """
    import numpy as np

    Kz = F.shape[1]
    Ks = settings['Ks']
    obs_model = model['obsModel']
    prior_params = obs_model['params']
    obs_type = obs_model['type']
    stateCounts = {}

    def _is_empty(value):
        # Equivalent of MATLAB isempty(); a missing (None) value also counts.
        return value is None or np.size(value) == 0

    def _fill_block_info(ds):
        # Default to one time step per block, then record each block's end index.
        if _is_empty(ds['blockSize']):
            ds['blockSize'] = np.ones([1, ds['obs'].shape[1]])
        ds['blockEnd'] = np.cumsum(ds['blockSize'])

    # MATLAB adds a missing field to every element of a struct array at once;
    # with a list of dicts each sequence needs the key individually.
    for ds in data_struct:
        ds.setdefault('blockSize', [])

    if obs_type == 'Gaussian':
        dimu = data_struct[0]['obs'].shape[0]
        for ds in data_struct:
            _fill_block_info(ds)
        theta = {'invSigma': np.zeros([dimu, dimu, Kz, Ks]),
                 'mu': np.zeros([dimu, Kz, Ks])}
        Ustats = {'card': np.zeros([Kz, Ks]),
                  'YY': np.zeros([dimu, dimu, Kz, Ks]),
                  'sumY': np.zeros([dimu, Kz, Ks])}

    elif obs_type == 'Multinomial':
        for ds in data_struct:
            if ds['obs'].shape[0] > 1:
                raise ValueError('not multinomial obs')
            _fill_block_info(ds)
        numVocab = len(prior_params['alpha'])
        data_struct[0]['numVocab'] = numVocab
        theta = {'p': np.zeros([Kz, Ks, numVocab])}
        Ustats = {'card': np.zeros([numVocab, Kz, Ks])}

    elif obs_type in ('AR', 'SLDS'):
        if settings['Ks'] != 1:
            raise ValueError('Switching linear dynamical models only defined for Gaussian process noise, not MoG')

        prior_type = obs_model['priorType']
        if prior_type in ('MNIW', 'MNIW-N', 'N-IW-N', 'ARD'):
            dimu = prior_params['M'].shape[0]
            dimX = prior_params['M'].shape[1]
        elif prior_type == 'Afixed-IW-N':
            dimu = prior_params['A'].shape[0]
            dimX = prior_params['A'].shape[1]
        else:
            raise ValueError('no known prior type')

        theta = {'invSigma': np.zeros([dimu, dimu, Kz, Ks]),
                 'A': np.zeros([dimu, dimX, Kz, Ks])}
        if prior_type == 'Afixed-IW-N':
            # The dynamics matrix is fixed: replicate the prior's (dimu, dimX)
            # matrix into every (Kz, Ks) slot, keeping the same 4-D layout as
            # the other prior types.
            fixed_A = np.asarray(prior_params['A'])
            theta['A'] = np.tile(fixed_A[:, :, np.newaxis, np.newaxis], (1, 1, Kz, Ks))
        if prior_type != 'MNIW':
            theta['mu'] = np.zeros([dimu, Kz, Ks])
        if prior_type == 'ARD':
            # NOTE(review): key name looks like a typo for 'ARDhypers'; kept
            # unchanged because other code may read it -- confirm.
            theta['ARDypers'] = np.zeros([dimX, Kz, Ks])

        Ustats = {'card': np.zeros((Kz, Ks)),
                  'XX': np.zeros((dimX, dimX, Kz, Ks)),
                  'YX': np.zeros((dimu, dimX, Kz, Ks)),
                  'YY': np.zeros((dimu, dimu, Kz, Ks)),
                  'sumY': np.zeros((dimu, Kz, Ks)),
                  'sumX': np.zeros((dimX, Kz, Ks))}

        if obs_type == 'SLDS':
            obs_model['r'] = 1
            if 'Kr' not in settings:
                Kr = 1
                hmm_params = model.setdefault('HMMmodel', {}).setdefault('params', {})
                hmm_params['a_eta'] = 1
                hmm_params['b_eta'] = 1
                print('Using single Gaussian measurement noise model')
            else:
                Kr = settings['Kr']
                print('Using mixture of Gaussian measurement noise model')

            dimy = prior_params['C'].shape[0]
            y_prior_type = obs_model['y_priorType']
            if y_prior_type == 'IW':
                theta['theta_r'] = {'invSigma': np.zeros([dimy, dimy, Kr])}
            elif y_prior_type in ('NIW', 'IW-N'):
                theta['theta_r'] = {'invSigma': np.zeros([dimy, dimy, Kr]),
                                    'mu': np.zeros([dimy, Kr])}
            else:
                raise ValueError('no known prior type for measurement noise')
            Ustats['Ustats_r'] = {'card': np.zeros([1, Kr]),
                                  'YY': np.zeros([dimy, dimy, Kr]),
                                  'sumY': np.zeros([dimy, Kr])}
            stateCounts['Nr'] = np.zeros([1, Kr])

        for ds in data_struct:
            if _is_empty(ds.get('X')):
                # makeDesignMatrix is expected to be provided elsewhere in the
                # package; it returns the design matrix and a validity mask.
                X, valid = makeDesignMatrix(ds['obs'], obs_model['r'])
                # Integer column indices of valid steps (a flat 1-D array, so
                # indexing does not introduce an extra axis).
                keep = np.flatnonzero(valid)
                ds['obs'] = ds['obs'][:, keep]
                ds['X'] = X[:, keep]
                _fill_block_info(ds)
                if 'true_labels' in ds:
                    # Index the last axis so 1-D and row-vector labels both work.
                    ds['true_labels'] = np.asarray(ds['true_labels'])[..., keep]

    else:
        raise ValueError('no known observation model type')

    numObj = len(data_struct)
    stateCounts['N'] = np.zeros([Kz + 1, Kz, numObj])
    stateCounts['Ns'] = np.zeros([Kz, Ks, numObj])

    # One independent (empty) record per stored sample.
    numSaves = int(settings['saveEvery'] // settings['storeEvery'])
    S = [{'F': [], 'config_log_likelihood': [], 'theta': [],
          'dist_struct': [], 'hyperparams': [], 'stateSeq': []}
         for _ in range(numSaves)]
    return theta, Ustats, stateCounts, data_struct, model, S