Skip to content

Commit

Permalink
FIX more categorical
Browse files Browse the repository at this point in the history
  • Loading branch information
jmschrei committed Aug 26, 2023
1 parent 1ca070f commit 1faa346
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pomegranate/distributions/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def log_probability(self, X):
for i in range(self.d):
if isinstance(X, torch.masked.MaskedTensor):
logp_ = self._log_probs[i][X[:, i]._masked_data]
logp_[logp_ == float("-inf")] = 0
_inplace_add(logps, logp_)
logp_[~X[:, i]._masked_mask] = 0
logps += logp_
else:
logps += self._log_probs[i][X[:, i]]

Expand Down
6 changes: 4 additions & 2 deletions tests/distributions/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,8 @@ def test_masked_probability(probs, X, X_masked):

assert_array_almost_equal(y, d.probability(X_))

y = [0.042 , 0.0075, 0.001 , 0.003 , 0.007 , 0.001 , 0.014 ]
y = [2.1000e-01, 1.5000e-01, 1.0000e+00, 3.0000e-03, 1.0000e-01, 1.0000e-03,
1.4000e-01]
assert_array_almost_equal(y, d.probability(X_masked))


Expand All @@ -842,7 +843,8 @@ def test_masked_log_probability(probs, X, X_masked):

assert_array_almost_equal(y, d.log_probability(X_))

y = numpy.log([0.042 , 0.0075, 0.001 , 0.003 , 0.007 , 0.001 , 0.014 ])
y = numpy.log([2.1000e-01, 1.5000e-01, 1.0000e+00, 3.0000e-03, 1.0000e-01,
1.0000e-03, 1.4000e-01])
assert_array_almost_equal(y, d.log_probability(X_masked))

def test_masked_summarize(X, X_masked, w):
Expand Down

0 comments on commit 1faa346

Please sign in to comment.