DenseHMM in pomegranate 1.0.0 doesn't work with single-element sequences #1049

Open
vidraj opened this issue Jul 20, 2023 · 0 comments

vidraj commented Jul 20, 2023

In pomegranate 1.0.0, calling predict on a single-element sequence with DenseHMM fails with a RuntimeError raised by an invalid reshape in forward_backward. SparseHMM handles the same input without error.

Sample code that reproduces the problem (mostly copied from the docs):

import numpy as np
from pomegranate.distributions import Categorical
from pomegranate.hmm import DenseHMM

d1 = Categorical([[0.25, 0.25, 0.25, 0.25]])
d2 = Categorical([[0.10, 0.40, 0.40, 0.10]])

model = DenseHMM()
model.add_distributions([d1, d2])
model.add_edge(model.start, d1, 0.5)
model.add_edge(model.start, d2, 0.5)
model.add_edge(d1, d1, 0.9)
model.add_edge(d1, d2, 0.1)
model.add_edge(d2, d1, 0.1)
model.add_edge(d2, d2, 0.9)

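# A single-character sequence, encoded with shape (1, 1, 1):
# (n_sequences, sequence_length, n_features)
sequence = 'A'
X = np.array([[[['A', 'C', 'G', 'T'].index(char)] for char in sequence]])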

y_hat = model.predict(X)
print(y_hat)

Expected output:

tensor([[0]])
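For a length-1 sequence there are no transitions, so the expected prediction can be checked by hand: the posterior over hidden states is the start probability times the emission probability of the observed symbol, renormalized. The sketch below uses plain NumPy (not pomegranate) with the parameters copied from the reproduction above; the variable names are illustrative only.

```python
import numpy as np

# Parameters copied from the reproduction above.
start = np.array([0.5, 0.5])                     # edges model.start -> d1, d2
emissions = np.array([[0.25, 0.25, 0.25, 0.25],  # d1
                      [0.10, 0.40, 0.40, 0.10]]) # d2

symbol = ['A', 'C', 'G', 'T'].index('A')         # 'A' is encoded as 0

# With a single observation, no transition probabilities are involved:
# P(state | x) is proportional to P(start -> state) * P(x | state).
joint = start * emissions[:, symbol]             # [0.125, 0.05]
posterior = joint / joint.sum()
prediction = int(np.argmax(posterior))

print(posterior)   # approximately [0.714, 0.286]
print(prediction)  # 0
```

State 0 (d1) wins because 'A' is more likely under d1 (0.25) than under d2 (0.10) and both states start with equal probability, which matches the expected tensor([[0]]).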

Actual output:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "venv/lib/python3.11/site-packages/pomegranate/hmm/_base.py", line 518, in predict
    return torch.argmax(self.predict_log_proba(X, priors=priors), dim=-1)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.11/site-packages/pomegranate/hmm/_base.py", line 459, in predict_log_proba
    _, r, _, _, _ = self.forward_backward(X, priors=priors)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.11/site-packages/pomegranate/hmm/dense_hmm.py", line 486, in forward_backward
    t = t.reshape(n, l-1, -1)
        ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] because the unspecified dimension size -1 can be any value and is ambiguous
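The failing line is t.reshape(n, l-1, -1). With l = 1 the tensor has zero elements and one dimension is already 0, so the size of the -1 dimension cannot be inferred. The sketch below reproduces the same rule with NumPy (which raises ValueError where PyTorch raises the RuntimeError above) and shows one possible fix: spelling out the trailing size. It assumes, hypothetically, that t holds one k x k block per adjacent pair of positions; the real layout inside pomegranate's forward_backward may differ, and this is not necessarily the fix the maintainers would choose.

```python
import numpy as np

n, l, k = 1, 1, 2  # 1 sequence, length 1, 2 hidden states

# Hypothetical stand-in for the per-transition tensor: one k x k block per
# adjacent pair of positions. A length-1 sequence has l - 1 = 0 such pairs,
# so the array has zero elements.
t = np.zeros((n, l - 1, k, k))

# Inferring the last dimension with -1 fails: because another dimension is
# already 0, any value would fit, so the size is ambiguous.
try:
    t.reshape(n, l - 1, -1)
    failed = False
except ValueError:
    failed = True

# Giving the trailing dimension explicitly removes the ambiguity and yields
# an empty (1, 0, 4) array.
flat = t.reshape(n, l - 1, k * k)

# Summing over the empty position axis gives zero expected transition
# counts, which is correct when the sequence contains no transitions.
counts = flat.sum(axis=1).reshape(n, k, k)

print(failed)      # True
print(flat.shape)  # (1, 0, 4)
print(counts)      # all zeros, shape (1, 2, 2)
```

An explicit guard that skips the transition statistics when l == 1 would work equally well; the explicit-size reshape has the advantage of needing no special case.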