[BUG] pomegranate.distributions.LogNormal.summarize() fails to invoke super() #1122

Open
levon003 opened this issue Nov 22, 2024 · 0 comments

Bug Description
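LogNormal.summarize() raises AttributeError: 'NoneType' object has no attribute 'dtype' when fit() is called on a LogNormal created without parameters.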

This is because Normal() correctly calls _distribution.Distribution.summarize() via super() before any processing with X:

X, sample_weight = super().summarize(X, sample_weight=sample_weight)

LogNormal.summarize() instead reads self.means.dtype before super().summarize() is called, while self.means is still None.
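
The ordering problem can be reproduced without pomegranate. Below is a minimal sketch using hypothetical stand-in classes (`Base`, `Parent`, and `Child` are illustrative names, not pomegranate's actual classes), assuming parameters are created lazily on the first `summarize()` call:

```python
class Base:
    """Stand-in for Distribution: creates parameters lazily."""

    def __init__(self):
        self.means = None  # unset until the first summarize() call

    def summarize(self, X):
        if self.means is None:
            self.means = [0.0] * len(X[0])  # lazy initialization
        return X


class Parent(Base):
    """Stand-in for Normal: initializes via super() before using self.means."""

    def summarize(self, X):
        X = super().summarize(X)
        return len(self.means)  # safe: means exists at this point


class Child(Parent):
    """Stand-in for LogNormal: touches self.means before calling super()."""

    def summarize(self, X):
        _ = self.means.copy()  # fails: means is still None here
        return super().summarize(X)


try:
    Child().summarize([[1.0], [2.0]])
    error = None
except AttributeError as exc:
    error = exc  # 'NoneType' object has no attribute 'copy'
```

`Parent().summarize(...)` works on a fresh instance, while `Child().summarize(...)` fails in the same way as the traceback below, because the attribute access happens before any class in the chain has had a chance to initialize it.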

To Reproduce

import numpy as np
import pomegranate.distributions
x_train = np.array([0, 1, 2, 2, 3, 4, 20, 21, 22, 22, 23, 24]).reshape([-1, 1])
d = pomegranate.distributions.LogNormal()
d.fit(x_train)

produces:

File [.../site-packages/pomegranate/distributions/lognormal.py:174, in LogNormal.summarize(self, X, sample_weight)
    172 if self.frozen is True:
    173     return
--> 174 X = _cast_as_tensor(X, dtype=self.means.dtype)
    175 super().summarize(X.log(), sample_weight=sample_weight)

AttributeError: 'NoneType' object has no attribute 'dtype'

I would cut a PR for this, but I'm not actually sure what the most Pythonic resolution for this kind of inheritance issue is.
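
One possible resolution is for the subclass to make sure the parameters exist before it reads them, then delegate as usual. The sketch below shows the pattern with hypothetical toy classes; the names `_initialized` and `_initialize` are assumptions made for illustration, and whether pomegranate's internals allow the same guard has not been verified here:

```python
import math


class Base:
    """Toy base class with explicit lazy initialization."""

    def __init__(self):
        self.means = None
        self._initialized = False

    def _initialize(self, d):
        self.means = [0.0] * d
        self._initialized = True

    def summarize(self, X):
        if not self._initialized:
            self._initialize(len(X[0]))
        return X


class Parent(Base):
    """Toy parent: accumulates per-column averages into self.means."""

    def summarize(self, X):
        X = super().summarize(X)
        for row in X:
            for j, value in enumerate(row):
                self.means[j] += value / len(X)
        return X


class Child(Parent):
    """Toy child: log-transforms the data, then defers to Parent."""

    def summarize(self, X):
        # Guard: create the parameters before anything reads them.
        if not self._initialized:
            self._initialize(len(X[0]))
        dtype = type(self.means[0])  # safe: parameters exist now
        logged = [[math.log(dtype(v)) for v in row] for row in X]
        return super().summarize(logged)


child = Child()
child.summarize([[1.0], [math.e]])
```

An alternative with the same effect would be to avoid reading the parameter at all before delegating, for example by converting the input without a parameter-dependent dtype. Which option fits the library better is a design decision for the maintainers.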

A basic workaround is to invoke the base Distribution.summarize() on the distribution manually before calling fit():

import numpy as np
import pomegranate.distributions
import pomegranate.distributions._distribution

true_mean = 10
x_train = np.random.lognormal(mean=true_mean, sigma=5, size=10000).reshape([-1, 1])
d = pomegranate.distributions.LogNormal()
# Call the base-class summarize() directly so that self.means is set before fit()
X, sample_weight = pomegranate.distributions._distribution.Distribution.summarize(d, x_train, sample_weight=None)
d.fit(x_train)
assert np.isclose(d.means[0], true_mean, atol=0.1)

Edit: a better workaround:

from types import MethodType
import numpy as np
import pomegranate.gmm
import pomegranate.distributions
import pomegranate.distributions._distribution
import pomegranate._utils

def fixed_summarize(self, X, sample_weight=None):
    if self.frozen is True:
        return
    # Run the base-class summarize() first so that self.means is no longer None
    X, sample_weight = pomegranate.distributions._distribution.Distribution.summarize(self, X, sample_weight=sample_weight)
    X = pomegranate._utils._cast_as_tensor(X, dtype=self.means.dtype)
    # Accumulate Normal sufficient statistics on the log-transformed data
    pomegranate.distributions.Normal.summarize(self, X.log(), sample_weight=sample_weight)

x_train = np.array([1, 2, 2, 3, 4, 20, 21, 22, 22, 23, 24, 10000, 10001, 10002]).reshape([-1, 1])
components = []
for i in range(2):
    d = pomegranate.distributions.LogNormal()
    d.summarize = MethodType(fixed_summarize, d)
    components.append(d)
model = pomegranate.gmm.GeneralMixtureModel(components, tol=0.01).fit(x_train)
[d.means for d in model.distributions]