Adding a modified class for a GLM-HMM with different inputs/covariates for the observation GLM and the transition GLM #166

Open · wants to merge 43 commits into base: master

Commits (43) — all by Zeinab-Mohammadi
d9b05c4  Update .gitignore (Oct 3, 2023)
c55c8de  Update README.md (Oct 3, 2023)
5498558  Initial commit of glmhmm transitions observations for git pull request (Dec 11, 2023)
f2df3a7  github test (Dec 26, 2023)
3ed0773  reverse_github_test (Dec 26, 2023)
4e0a3d4  init_file_finished (Dec 26, 2023)
e893d69  editting hmm file (Dec 27, 2023)
c095138  Editing files for pull request (Dec 28, 2023)
6006632  separating hmm_TO (Dec 29, 2023)
7c3d49d  Edits for the notebook (Jan 2, 2024)
7d412a6  Some comments for hmm and adding my notebook/ adding scott new notebooks (Jan 4, 2024)
34ca0d5  update my notebooks for .py format (Jan 5, 2024)
6a7ee72  edit .py (Jan 5, 2024)
29b2ec1  edit .py (Jan 5, 2024)
e69e0c5  another modification (Jan 5, 2024)
7e4700f  add the title in pycharm (Jan 5, 2024)
7d0ba14  title edit in pycharm (Jan 5, 2024)
172b0fd  edit grammar errors (Jan 5, 2024)
a09a322  edit file (Jan 23, 2024)
a08fcf7  edits (Jan 23, 2024)
5751e1c  final edits by comparing code with ssm (Jan 23, 2024)
abb463b  edits based on ssm in observations file (Jan 23, 2024)
2d37412  edits ssm observations (Jan 23, 2024)
c8ffc96  more edits (Jan 23, 2024)
2362483  edits (Jan 23, 2024)
aa90e14  final edits for ssm (Jan 23, 2024)
d1bcbce  workspace (Jan 24, 2024)
51f131b  edit setup file (Jan 25, 2024)
f360c2a  add setup (Feb 2, 2024)
43c63fc  remove copy of setup (Feb 2, 2024)
72b614f  edits for observation (Feb 2, 2024)
ef69c84  Update 2c-Input-Driven-Transitions-and-Observations-GLM-HMM.py (Feb 2, 2024)
d2a0d27  Final edits (Feb 2, 2024)
cc70078  Merge branch 'master' of https://github.com/Zeinab-Mohammadi/ssm (Feb 2, 2024)
a8c30f4  Update 2c-Input-Driven-Transitions-and-Observations-GLM-HMM.py (Feb 2, 2024)
cc6cc13  edits (Feb 2, 2024)
4fc7faf  Merge branch 'master' of https://github.com/Zeinab-Mohammadi/ssm (Feb 2, 2024)
8b796d3  editing the notebook (Feb 2, 2024)
69520c5  Update 2c-Input-Driven-Transitions-and-Observations-GLM-HMM.py (Feb 2, 2024)
57b2e35  edits (Feb 2, 2024)
b0b92c8  final edit (Feb 2, 2024)
a000eec  Adding paper link (Feb 5, 2024)
35377d4  New updates (Aug 14, 2024)
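
For orientation, here is a minimal NumPy sketch of the model this PR implements: a GLM-HMM in which the transition GLM and the observation GLM are driven by different covariates. All sizes and weight values below are made up for illustration, and the parameterization follows the standard input-driven GLM-HMM rather than this PR's exact API (the PR's own code lives in ssm/observations.py and ssm/hmm_TO.py below).

import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)
K, C, T = 3, 2, 100      # discrete states, observation classes, time bins
M_trans, M_obs = 2, 4    # separate covariate dimensions for the two GLMs

u_trans = rng.normal(size=(T, M_trans))   # covariates for the transition GLM
u_obs = rng.normal(size=(T, M_obs))       # covariates for the observation GLM

log_Ps = rng.normal(size=(K, K))          # baseline transition log-weights
Ws = rng.normal(size=(K, M_trans))        # transition GLM weights (one row per destination state)
Wk = rng.normal(size=(K, C - 1, M_obs))   # observation GLM weights (the shape used by the new class)

# Transition logits: baseline plus the input effect on the destination state,
# normalized over destinations -> log P(z_t = j | z_{t-1} = i, u_trans_t), shape (T, K, K)
trans_logits = log_Ps[None, :, :] + (u_trans @ Ws.T)[:, None, :]
log_trans = trans_logits - logsumexp(trans_logits, axis=2, keepdims=True)

# Observation logits: the C-th class is the reference, with its weights pinned at zero,
# normalized over classes -> log P(y_t = c | z_t = k, u_obs_t), shape (T, K, C)
W_full = np.concatenate([Wk, np.zeros((K, 1, M_obs))], axis=1)   # (K, C, M_obs)
obs_logits = np.einsum('kcm,tm->tkc', W_full, u_obs)
log_obs = obs_logits - logsumexp(obs_logits, axis=2, keepdims=True)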
2 changes: 2 additions & 0 deletions .gitignore
@@ -117,3 +117,5 @@ notebooks/*.html
.vscode

.DS_Store

.idea/
938 changes: 938 additions & 0 deletions notebooks/2c-Input-Driven-Transitions-and-Observations-GLM-HMM.ipynb

Large diffs are not rendered by default.

454 changes: 454 additions & 0 deletions notebooks/2c-Input-Driven-Transitions-and-Observations-GLM-HMM.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.cfg
@@ -27,7 +27,7 @@ doc =
memory_profiler # measuring memory during docs building
mkl
jupytext
myst-nb
myst-nb
myst-parser
numpydoc
sphinx
2 changes: 1 addition & 1 deletion ssm/__init__.py
@@ -1,4 +1,4 @@
# Default imports for SSM

from .hmm_TO import *
from .hmm import *
from .lds import *
551 changes: 551 additions & 0 deletions ssm/hmm_TO.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions ssm/init_state_distns.py
@@ -47,6 +47,9 @@ def m_step(self, expectations, datas, inputs, masks, tags, **kwargs):
        pi0 = sum([Ez[0] for Ez, _, _ in expectations]) + 1e-8
        self.log_pi0 = np.log(pi0 / pi0.sum())

    def m_step_modified(self, expectations, datas, transition_input, observation_input, masks, tags, **kwargs):
        # Same update as m_step: the initial-state distribution does not depend on the
        # covariates, so transition_input and observation_input are accepted only to
        # match the modified model's calling convention.
        pi0 = sum([Ez[0] for Ez, _, _ in expectations]) + 1e-8
        self.log_pi0 = np.log(pi0 / pi0.sum())

class FixedInitialStateDistribution(InitialStateDistribution):
    def __init__(self, K, D, pi0=None, M=0):
291 changes: 291 additions & 0 deletions ssm/observations.py
@@ -819,6 +819,297 @@ def smooth(self, expectations, data, input, tag):
"""
raise NotImplementedError

class InputDrivenObservationsDiffInputs(Observations):

    def __init__(self, K, D, M_obs=0, C=2, prior_mean=0, prior_sigma=1000):
        """
        @param K: number of states
        @param D: dimensionality of output
        @param M_obs: number of covariates driving the observation GLM
        @param C: number of distinct classes for each dimension of output
        @param prior_mean: mean of the prior on the GLM weights
        @param prior_sigma: parameter governing the strength of the prior. The prior on the GLM weights
            is a multivariate normal distribution with mean 'prior_mean' and a diagonal covariance
            matrix ('prior_sigma' on the diagonal)
        """
        super(InputDrivenObservationsDiffInputs, self).__init__(K, D, M_obs)
        self.C = C
        self.M_obs = M_obs
        self.D = D
        self.K = K
        self.prior_mean = prior_mean
        self.prior_sigma = prior_sigma
        # Parameters linking input to the distribution over output classes
        self.Wk = npr.randn(K, C - 1, M_obs)

    @property
    def params(self):
        return self.Wk

    @params.setter
    def params(self, value):
        self.Wk = value

    def permute(self, perm):
        self.Wk = self.Wk[perm]

    def log_prior(self):
        lp = 0
        for k in range(self.K):
            for c in range(self.C - 1):
                weights = self.Wk[k][c]
                lp += stats.multivariate_normal_logpdf(weights, mus=np.repeat(self.prior_mean, self.M_obs),
                                                       Sigmas=(self.prior_sigma ** 2) * np.identity(self.M_obs))
        return lp
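
    # Each of the K * (C - 1) weight vectors thus gets an independent
    # N(prior_mean, prior_sigma^2 * I) prior over its M_obs entries; with the
    # default prior_sigma=1000 this prior is effectively uninformative.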

    # Calculate time-dependent logits; the output is a matrix of size TxKxC.
    # The input is of size TxM_obs.
    def calculate_logits(self, observation_input):
        """
        Return array of size TxKxC containing log(pr(yt=c|zt=k))
        :param observation_input: array of covariates of size TxM_obs
        :return: array of size TxKxC containing log(pr(yt=c|zt=k, ut)) for all c in {1, ..., C} and k in {1, ..., K}
        """
        # Transpose array dimensions, so that the array is now of shape ((C-1) x K x M_obs)
        Wk_transpose = np.transpose(self.Wk, (1, 0, 2))
        # Stack a row of zeros to transform the array from size ((C-1) x K x M_obs) to (C x K x M_obs),
        # then transpose back to (K x C x M_obs)
        Wk = np.transpose(np.vstack([Wk_transpose, np.zeros((1, Wk_transpose.shape[1], Wk_transpose.shape[2]))]),
                          (1, 0, 2))
        # Input effect; transpose so that the output has dims TxKxC.
        # Note: this has an unexpected effect when both the input (and thus Wk) are empty arrays: it returns zeros.
        time_dependent_logits = np.transpose(np.dot(Wk, observation_input.T), (2, 0, 1))
        time_dependent_logits = time_dependent_logits - logsumexp(time_dependent_logits, axis=2, keepdims=True)
        return time_dependent_logits
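
    # Shape walk-through for calculate_logits (illustrative sizes): with K=3, C=2,
    # M_obs=4 and T=100, self.Wk is (3, 1, 4); transposing gives (1, 3, 4); stacking
    # the zero row for the reference class gives (2, 3, 4), transposed back to
    # Wk of shape (3, 2, 4). np.dot(Wk, observation_input.T) is then (3, 2, 100),
    # the final transpose yields logits of shape (100, 3, 2), and the logsumexp
    # over axis=2 normalizes them into log-probabilities over the C classes.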

    def log_likelihoods(self, data, observation_input, mask, tag):
        # If the input is a vector of size (M_obs,) (a single time point), expand dims to (1, M_obs)
        if observation_input.ndim == 1 and observation_input.shape == (self.M_obs,):
            observation_input = np.expand_dims(observation_input, axis=0)
        time_dependent_logits = self.calculate_logits(observation_input)
        assert self.D == 1, "InputDrivenObservationsDiffInputs written for D = 1!"
        mask = np.ones_like(data, dtype=bool) if mask is None else mask
        return stats.categorical_logpdf(data[:, None, :], time_dependent_logits[:, :, None, :], mask=mask[:, None, :])
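
    # Illustrative shapes: with T=100 time bins and D=1, data is (100, 1) with integer
    # entries in {0, ..., C-1}, and the returned log-likelihood array is (100, K).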

    def sample_x(self, z, xhist, observation_input=None, tag=None, with_noise=True):
        assert self.D == 1, "InputDrivenObservationsDiffInputs written for D = 1!"
        if observation_input.ndim == 1 and observation_input.shape == (self.M_obs,):
            observation_input = np.expand_dims(observation_input, axis=0)
        time_dependent_logits = self.calculate_logits(observation_input)  # size TxKxC
        ps = np.exp(time_dependent_logits)
        T = time_dependent_logits.shape[0]

        if T == 1:  # single time step: z is a scalar state
            sample = np.array([npr.choice(self.C, p=ps[t, z]) for t in range(T)])
        elif T > 1:  # multiple time steps: z is an array of states
            sample = np.array([npr.choice(self.C, p=ps[t, z[t]]) for t in range(T)])
        return sample

    def m_step(self, expectations, datas, observation_input, masks, tags, optimizer="bfgs", **kwargs):

        T = sum([data.shape[0] for data in datas])  # total number of data points (time bins)

        def _multisoftplus(X):
            '''
            Computes f(X) = log(1 + sum(exp(X), axis=1)) and its first derivative
            :param X: array of size Tx(C-1)
            :return: f(X) of size T and df of size Tx(C-1)
            '''
            # Append a column of zeros to X for the row-max calculation
            X_augmented = np.append(X, np.zeros((X.shape[0], 1)), 1)
            # Take the max along each row for the log-sum-exp trick; rowmax has size T
            rowmax = np.max(X_augmented, axis=1, keepdims=1)
            # compute f:
            f = np.log(np.exp(-rowmax[:, 0]) + np.sum(np.exp(X - rowmax), axis=1)) + rowmax[:, 0]
            # compute df:
            df = np.exp(X - rowmax) / np.expand_dims(np.exp(-rowmax[:, 0]) + np.sum(np.exp(X - rowmax), axis=1),
                                                     axis=1)
            return f, df
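
        # Quick sanity check of _multisoftplus (hypothetical values): for C=3 and a
        # single time bin, X = np.zeros((1, 2)) gives f = log(1 + e^0 + e^0) = log(3)
        # and df = [[1/3, 1/3]], i.e. the softmax probabilities of the two
        # non-reference classes.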

        def _objective(params, k):
            '''
            Computes the term in the negative expected complete log-likelihood that depends on the weights for state k
            :param params: vector of size (C-1)xM_obs
            :param k: state whose parameters we are currently optimizing
            :return: term in the negative expected complete LL that depends on the weights for state k; scalar value
            '''
            W = np.reshape(params, (self.C - 1, self.M_obs))
            obj = 0
            for data, input, mask, tag, (expected_states, _, _) \
                    in zip(datas, observation_input, masks, tags, expectations):
                xproj = input @ W.T  # projection of input onto the weight matrix for this state, size Tx(C-1)
                f, _ = _multisoftplus(xproj)
                assert data.shape[1] == 1, "InputDrivenObservationsDiffInputs written for D = 1!"
                data_one_hot = one_hot(data[:, 0], self.C)  # convert to one-hot representation of size TxC
                obj += (-np.sum(data_one_hot[:, :-1] * xproj, axis=1) + f) @ expected_states[:, k]

            # add contribution of the prior:
            if self.prior_sigma != 0:
                obj += 1 / (2 * self.prior_sigma ** 2) * np.sum(W ** 2)
            return obj / T

        def _gradient(params, k):
            '''
            Explicit calculation of the gradient of _objective w.r.t. the weight matrix for state k, W_{k}
            :param params: vector of size (C-1)xM_obs
            :param k: state whose parameters we are currently optimizing
            :return: gradient of the objective with respect to the parameters; vector of size (C-1)xM_obs
            '''
            W = np.reshape(params, (self.C - 1, self.M_obs))
            grad = np.zeros((self.C - 1, self.M_obs))
            for data, input, mask, tag, (expected_states, _, _) \
                    in zip(datas, observation_input, masks, tags, expectations):
                xproj = input @ W.T  # projection of input onto the weight matrix for this state, size Tx(C-1)
                _, df = _multisoftplus(xproj)
                assert data.shape[1] == 1, "InputDrivenObservationsDiffInputs written for D = 1!"
                data_one_hot = one_hot(data[:, 0], self.C)  # convert to one-hot representation of size TxC
                grad += (df - data_one_hot[:, :-1]).T @ (expected_states[:, [k]] * input)  # shape (C-1, M_obs)
            # Add the contribution of the prior to the gradient:
            if self.prior_sigma != 0:
                grad += (1 / self.prior_sigma ** 2) * W
            # Flatten grad into a vector:
            grad = grad.flatten()
            return grad / T

        def _hess(params, k):
            '''
            Explicit calculation of the Hessian of _objective w.r.t. the weight matrix for state k, W_{k}
            :param params: vector of size (C-1)xM_obs
            :param k: state whose parameters we are currently optimizing
            :return: Hessian of the objective with respect to the parameters; matrix of size ((C-1)xM_obs) x ((C-1)xM_obs)
            '''
            W = np.reshape(params, (self.C - 1, self.M_obs))
            hess = np.zeros(((self.C - 1) * self.M_obs, (self.C - 1) * self.M_obs))
            for data, input, mask, tag, (expected_states, _, _) \
                    in zip(datas, observation_input, masks, tags, expectations):
                xproj = input @ W.T  # projection of input onto the weight matrix for this state
                _, df = _multisoftplus(xproj)
                # center blocks:
                dftensor = np.expand_dims(df, axis=2)  # dims are now (T, (C-1), 1)
                # Multiply every input covariate with every class-derivative term at a given
                # time step; dims are now (T, (C-1), M_obs)
                Xdf = np.expand_dims(input, axis=1) * dftensor
                # reshape Xdf to (T, (C-1)*M_obs)
                Xdf = np.reshape(Xdf, (Xdf.shape[0], -1))
                # weight Xdf by the posterior state probabilities
                pXdf = expected_states[:, [k]] * Xdf  # output has size (T, (C-1)*M_obs)
                # outer product with the input vector, size (M_obs, (C-1)*M_obs)
                XXdf = input.T @ pXdf
                # center blocks of the Hessian:
                temp_hess = np.zeros(((self.C - 1) * self.M_obs, (self.C - 1) * self.M_obs))
                for c in range(1, self.C):
                    inds = range((c - 1) * self.M_obs, c * self.M_obs)
                    temp_hess[np.ix_(inds, inds)] = XXdf[:, inds]
                # off-diagonal entries:
                hess += temp_hess - Xdf.T @ pXdf
            # Add the contribution of the prior to the Hessian: the quadratic prior term
            # (1 / (2 * prior_sigma^2)) * ||W||^2 contributes (1 / prior_sigma^2) times the identity
            if self.prior_sigma != 0:
                hess += (1 / self.prior_sigma ** 2) * np.identity((self.C - 1) * self.M_obs)
            return hess / T

        from scipy.optimize import minimize
        # Optimize the weights for each state separately:
        for k in range(self.K):
            def _objective_k(params):
                return _objective(params, k)

            def _gradient_k(params):
                return _gradient(params, k)

            def _hess_k(params):
                return _hess(params, k)

            sol = minimize(_objective_k, self.params[k].reshape((self.C - 1) * self.M_obs), hess=_hess_k,
                           jac=_gradient_k, method="trust-ncg")
            # Comment out the next line to stop the observation weights from being updated
            self.params[k] = np.reshape(sol.x, (self.C - 1, self.M_obs))

    def smooth(self, expectations, data, observation_input, tag):
        """
        Compute the mean observation under the posterior distribution
        of latent discrete states.
        """
        raise NotImplementedError
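
# A minimal usage sketch of the class above, called standalone (sizes and data are
# hypothetical; within this PR the class is presumably driven through the HMM wrapper
# in ssm/hmm_TO.py rather than used directly):
#
#     obs = InputDrivenObservationsDiffInputs(K=3, D=1, M_obs=4, C=2)
#     u_obs = npr.randn(100, 4)                    # T=100 observation covariates
#     logits = obs.calculate_logits(u_obs)         # (100, 3, 2): log P(y_t=c | z_t=k, u_t)
#     z = np.zeros(100, dtype=int)                 # stay in state 0 at every time bin
#     y = obs.sample_x(z, xhist=None, observation_input=u_obs)          # (100,) labels
#     ll = obs.log_likelihoods(y[:, None], u_obs, mask=None, tag=None)  # (100, 3)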


class _AutoRegressiveObservationsBase(Observations):
    """
    Base class for autoregressive observations of the form,

    E[x_t | x_{t-1}, z_t=k, u_t]
        = \sum_{l=1}^{L} A_k^{(l)} x_{t-l} + b_k + V_k u_t.

    where L is the number of lags and u_t is the input.
    """

    def __init__(self, K, D, M=0, lags=1):
        super(_AutoRegressiveObservationsBase, self).__init__(K, D, M)

        # Distribution over initial point
        self.mu_init = np.zeros((K, D))

        # AR parameters
        assert lags > 0
        self.lags = lags
        self.bs = npr.randn(K, D)
        self.Vs = npr.randn(K, D, M)

        # Inheriting classes may treat _As differently
        self._As = None

    @property
    def As(self):
        return self._As

    @As.setter
    def As(self, value):
        self._As = value

    @property
    def params(self):
        return self.As, self.bs, self.Vs

    @params.setter
    def params(self, value):
        self.As, self.bs, self.Vs = value

    def permute(self, perm):
        self.mu_init = self.mu_init[perm]
        self.As = self.As[perm]
        self.bs = self.bs[perm]
        self.Vs = self.Vs[perm]

    def _compute_mus(self, data, input, mask, tag):
        # assert np.all(mask), "ARHMM cannot handle missing data"
        K, M = self.K, self.M
        T, D = data.shape
        As, bs, Vs, mu0s = self.As, self.bs, self.Vs, self.mu_init

        # Instantaneous inputs
        mus = []
        for k, (A, b, V, mu0) in enumerate(zip(As, bs, Vs, mu0s)):
            # Initial condition
            mus_k_init = mu0 * np.ones((self.lags, D))

            # Subsequent means are determined by the AR process
            mus_k_ar = np.dot(input[self.lags:, :M], V.T)
            for l in range(self.lags):
                Al = A[:, l * D:(l + 1) * D]
                mus_k_ar = mus_k_ar + np.dot(data[self.lags - l - 1:-l - 1], Al.T)
            mus_k_ar = mus_k_ar + b

            # Append concatenated mean
            mus.append(np.vstack((mus_k_init, mus_k_ar)))

        return np.array(mus)

    def smooth(self, expectations, data, input, tag):
        """
        Compute the mean observation under the posterior distribution
        of latent discrete states.
        """
        T = expectations.shape[0]
        mask = np.ones((T, self.D), dtype=bool)
        mus = np.swapaxes(self._compute_mus(data, input, mask, tag), 0, 1)
        return (expectations[:, :, None] * mus).sum(1)

