[QUESTION] I would like to fix the transition matrix upon running. Can I have some help pointing to what update steps need to be commented? #1092
Specifically, when running Baum-Welch via the fit method, I'd like to iterate through any given sample without updating the parameters of the transition edges.
If you pass in […]
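For context on that suggestion: pomegranate's `inertia` parameter blends the old parameter estimate with the proposed update (to my understanding, `new = inertia * old + (1 - inertia) * update`), so `inertia=1.0` effectively freezes a parameter during fitting. A minimal NumPy sketch of that convention (the helper name here is illustrative, not pomegranate's internals):

```python
import numpy as np

def inertia_update(old, proposed, inertia):
    """Blend an old parameter estimate with a freshly proposed one.

    Convention (as I understand pomegranate's `inertia`): inertia=0.0
    fully accepts the update; inertia=1.0 keeps the old value
    unchanged, effectively freezing the parameter.
    """
    return inertia * old + (1.0 - inertia) * proposed

old_T = np.array([[0.9, 0.1], [0.2, 0.8]])
proposed_T = np.array([[0.5, 0.5], [0.5, 0.5]])

frozen = inertia_update(old_T, proposed_T, inertia=1.0)
print(np.allclose(frozen, old_T))  # inertia=1.0 leaves the matrix untouched
```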
Thanks, I am trying to populate the SparseHMM object using a sparse (CSR) matrix and am unsure if I am using the newer functionality correctly. Right now I have several steps to input each of the values individually, as follows (where E is a sparse emission matrix with shape (20000, 1000) and T is a sparse transition matrix with shape (20000, 20000)):

```python
dists = [Categorical([row]) for row in E.toarray()]

ends = E.max(axis=1).data
ends /= np.sum(ends)
P0 /= np.sum(P0)

rows, cols = T.nonzero()
values = T.data

# Get corresponding distributions with transition prob
index_value_pairs = [(dists[row], dists[col], T[row, col])
                     for row, col in zip(rows, cols)]

model = SparseHMM(dists, edges=index_value_pairs, starts=P0, ends=ends,
                  max_iter=1, verbose=True, inertia=1.0)
```

I wasn't able to find anything detailing this scenario in the docs, but does this follow your design?
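Before wiring triplets like these into `SparseHMM`, it can help to sanity-check that `nonzero()` pairs line up with the stored values on a toy matrix. A standalone scipy sketch mirroring the extraction pattern above:

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy 3-state sparse transition matrix; each row sums to 1
P = csr_matrix(np.array([
    [0.0, 1.0, 0.0],
    [0.5, 0.0, 0.5],
    [0.0, 0.0, 1.0],
]))

rows, cols = P.nonzero()
edges = [(int(r), int(c), float(P[r, c])) for r, c in zip(rows, cols)]
print(edges)  # [(0, 1, 1.0), (1, 0, 0.5), (1, 2, 0.5), (2, 2, 1.0)]
```

For a CSR matrix, `nonzero()` walks the entries in row-major order, which matches the order of `P.data`, so the triplets and the stored values stay aligned.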
I think that […]
Err, right, the P0 is uniform. Now that I'm looking at it more carefully, I am setting them as the stationary distribution of the transition matrix. However, I noticed that for some problems where I do not initialize the […] I haven't seen BW implementations that rely on the ending states, so I'm curious about the inspiration for this design.
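For reference, the stationary distribution mentioned here can be computed by left power iteration; a generic sketch, independent of pomegranate, on a toy dense matrix `P`:

```python
import numpy as np

def stationary_distribution(P, tol=1e-12, max_iter=10000):
    """Left power iteration: find pi such that pi @ P == pi."""
    pi = np.full(P.shape[0], 1.0 / P.shape[0])
    for _ in range(max_iter):
        nxt = pi @ P
        nxt /= nxt.sum()            # renormalize to guard against drift
        if np.abs(nxt - pi).max() < tol:
            return nxt
        pi = nxt
    return pi

P = np.array([[0.9, 0.1],
              [0.5, 0.5]])
pi = stationary_distribution(P)
print(pi)  # ≈ [0.8333, 0.1667]
```

For this toy chain the fixed point satisfies 0.1·pi1 = 0.5·pi2, giving pi = [5/6, 1/6].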
I don't think BW here is dependent on the end states, but having them can add constraints that make optimization easier. It's hard to diagnose why you might observe the behavior you describe without knowing more about your data or model.
Upon closer inspection with a debugger, it appears the problem is the result of integer underflow during the forward-backward algorithm when the emission probabilities are also sparse.
Can you elaborate a little bit more? Where are you getting integer underflow? The probabilities should all be floats, right?
Err, not integer underflow, my mistake: just general arithmetic underflow, or computing an undefined log expression. I'll try to get you a script that can verifiably replicate this problem sometime this week, with inputs as pickled files. But yes, the emission probabilities (whose rows initialize the Categorical distribution class) contain multiple zero values. I suspect this is what causes FB to fail. For example, the t, f, and b variables that get defined between lines 534-542 in sparse_hmm.py (the function […]
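The underflow mode described here is easy to reproduce in isolation: exponentiating very negative log-probabilities collapses to exactly zero, and the standard fix is the max-subtraction (log-sum-exp) trick. A small NumPy demonstration:

```python
import numpy as np

logp = np.array([-1000.0, -1000.0, -1001.0])  # log-probabilities of three paths

# Naive: exp underflows to exactly 0.0, so the log becomes -inf
with np.errstate(divide="ignore"):
    naive = np.log(np.sum(np.exp(logp)))       # -inf, even though the true sum is finite

# Stabilized (log-sum-exp): subtract the max before exponentiating
m = logp.max()
stable = m + np.log(np.sum(np.exp(logp - m)))  # ≈ -999.138, the true value
print(naive, stable)
```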
Totally forgot to update this ticket, but I solved this issue with the following modifications to the forward and backward passes:

```python
import torch

def forward(self, X=None, emissions=None, priors=None):
    emissions = self._check_inputs(X, emissions, priors)
    n, l, _ = emissions.shape

    f = torch.full((l, n, self.n_distributions), torch.finfo(torch.float64).min,
                   dtype=torch.float64, device=self.device)
    f[0] = self.starts + emissions[:, 0]

    for i in range(1, l):
        p = f[i-1, :, self._edge_idx_starts]
        p += self._edge_log_probs.expand(n, -1)

        alpha = torch.max(p, dim=1, keepdim=True).values
        p = p - alpha                      # stabilize before exp to prevent underflow
        p = torch.exp(p).clamp(min=1e-10)  # keep very negative values from collapsing to exact 0

        z = torch.zeros_like(f[i])
        z.scatter_add_(1, self._edge_idx_ends.expand(n, -1), p)
        z = z.clamp(min=1e-10)             # prevents log(0)

        f[i] = alpha.squeeze(1) + torch.log(z) + emissions[:, i]

    f = f.permute(1, 0, 2)
    return f
```
```python
def backward(self, X=None, emissions=None, priors=None):
    emissions = self._check_inputs(X, emissions, priors)
    n, l, _ = emissions.shape

    b = torch.full((l, n, self.n_distributions), torch.finfo(torch.float64).min,
                   dtype=torch.float64, device=self.device)
    b[-1] = self.ends + emissions[:, -1]

    for i in range(l-2, -1, -1):
        p = b[i+1, :, self._edge_idx_ends]
        # index emissions by edge end state so the shape matches (n, n_edges)
        p += emissions[:, i+1, self._edge_idx_ends]
        p += self._edge_log_probs.expand(n, -1)

        alpha = torch.max(p, dim=1, keepdim=True).values
        p = p - alpha                      # stabilize before exp to prevent underflow
        p = torch.exp(p).clamp(min=1e-10)  # keep very negative values from collapsing to exact 0

        z = torch.zeros_like(b[i])
        z.scatter_add_(1, self._edge_idx_starts.expand(n, -1), p)
        z = z.clamp(min=1e-10)             # prevents log(0)

        b[i] = alpha.squeeze(1) + torch.log(z)

    b = b.permute(1, 0, 2)
    return b
```
If you think this is worth a PR, I can set that up.
Sorry for the delay in responding. I'm not sure that I agree with clamping to a small value, because it prevents true negative infinities, right?
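That concern can be made concrete: clamping after `exp` turns a structurally impossible path (true probability 0, log-probability -inf) into `log(1e-10)`, a finite and therefore reachable value. A small NumPy sketch of the difference:

```python
import numpy as np

logp = np.array([-np.inf, -np.inf])  # two structurally impossible paths

with np.errstate(divide="ignore"):
    # Without clamping, the log-sum correctly stays -inf
    exact = np.log(np.exp(logp).sum())

    # Clamping replaces the zeros before the log, so the result is finite
    clamped = np.log(np.clip(np.exp(logp), 1e-10, None).sum())  # log(2e-10) ≈ -22.33

print(exact, clamped)
```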