Skip to content

Commit

Permalink
Expanded commit 4625273
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Dec 5, 2023
1 parent af6d6c0 commit 92ea6a8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
3 changes: 2 additions & 1 deletion sklearn2pmml/decoration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def _valid_value_mask(self, X, where):
if hasattr(X, "isin"):
mask = X.isin(self.data_)
else:
mask = numpy.isin(X, self.data_)
mask = numpy.full(X.shape, fill_value = False)
mask[where] = numpy.isin(X[where], self.data_)
return numpy.logical_and(mask, where)
return super(DiscreteDomain, self)._valid_value_mask(X, where)

Expand Down
15 changes: 15 additions & 0 deletions sklearn2pmml/decoration/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from unittest import TestCase

import numpy
import pandas

class AliasTest(TestCase):

Expand Down Expand Up @@ -113,6 +114,20 @@ def test_fit_int_missing(self):
Xt = domain.transform(X)
self.assertEqual([0, 0, 2], Xt.tolist())

def test_fit_int64(self):
domain = clone(CategoricalDomain())
X = Series([-1, None, 1, 2, -1]).astype("Int64")
self.assertEqual([False, True, False, False, False], domain._missing_value_mask(X).tolist())
Xt = domain.fit_transform(X)
self.assertEqual([-1, 1, 2], domain.data_.tolist())
self.assertEqual([-1, pandas.NA, 1, 2, -1], Xt.tolist())
domain = clone(CategoricalDomain())
X = X.to_numpy()
self.assertEqual([False, True, False, False, False], domain._missing_value_mask(X).tolist())
Xt = domain.fit_transform(X)
self.assertEqual([-1, 1, 2], domain.data_.tolist())
self.assertEqual([-1, pandas.NA, 1, 2, -1], Xt.tolist())

def test_fit_int_categorical(self):
domain = clone(CategoricalDomain(dtype = CategoricalDtype()))
self.assertIsNone(domain.dtype.categories)
Expand Down

0 comments on commit 92ea6a8

Please sign in to comment.