Skip to content

Commit

Permalink
Fixed 'LookupTransformer.transform(X)' method. Fixes #395
Browse files (browse the repository at this point in the history)
  • Loading branch information
vruusmann committed Oct 23, 2023
1 parent 51a937d commit f53450a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 19 deletions.
33 changes: 22 additions & 11 deletions sklearn2pmml/preprocessing/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -324,14 +324,14 @@ def __init__(self, mapping, default_value):
k_type = None
v_type = None
for k, v in mapping.items():
if k is None:
raise ValueError("Key is None")
if pandas.isnull(k):
raise ValueError("Key is a missing value")
if k_type is None:
k_type = type(k)
else:
if type(k) != k_type:
raise TypeError("Key is not a {0}".format(k_type.__name__))
if v is None:
if pandas.isnull(v):
continue
if v_type is None:
v_type = type(v)
Expand All @@ -357,8 +357,13 @@ def fit(self, X, y = None):
def transform(self, X):
X = ensure_1d(X)
transform_dict = self._transform_dict()
func = lambda k: transform_dict[k]
Xt = eval_rows(X, func)

def _eval_row(x):
if pandas.isnull(x):
return x
return transform_dict[x]

Xt = eval_rows(X, _eval_row)
return _col2d(Xt)

class FilterLookupTransformer(LookupTransformer):
Expand All @@ -375,8 +380,8 @@ class FilterLookupTransformer(LookupTransformer):
def __init__(self, mapping):
super(FilterLookupTransformer, self).__init__(mapping, default_value = None)
for k, v in mapping.items():
if v is None:
raise ValueError("Value is None")
if pandas.isnull(v):
raise ValueError("Value is a missing value")
if type(k) != type(v):
raise TypeError("Key and Value type mismatch")

Expand Down Expand Up @@ -416,10 +421,16 @@ def fit(self, X, y = None):

def transform(self, X):
transform_dict = self._transform_dict()
# See https://stackoverflow.com/a/3460747
# See https://stackoverflow.com/a/3338368
func = lambda k: transform_dict[tuple(k) if isinstance(k, Hashable) else tuple(numpy.squeeze(numpy.asarray(k)))]
Xt = eval_rows(X, func)

def _eval_row(x):
if (pandas.isnull(x)).any():
return None
# See https://stackoverflow.com/a/3460747
# See https://stackoverflow.com/a/3338368
x = x if isinstance(x, Hashable) else tuple(numpy.squeeze(numpy.asarray(x)))
return transform_dict[tuple(x)]

Xt = eval_rows(X, _eval_row)
return _col2d(Xt)

def _make_index(values):
Expand Down
22 changes: 14 additions & 8 deletions sklearn2pmml/preprocessing/tests/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -335,12 +335,14 @@ def test_transform_float(self):
with self.assertRaises(TypeError):
LookupTransformer(mapping, int(0))
transformer = LookupTransformer(mapping, float("NaN"))
X = numpy.array([[0.0], [90.0]])
self.assertEqual([[math.cos(0.0)], [math.cos(90.0)]], transformer.transform(X).tolist())
X = numpy.array([180.0])
self.assertTrue(math.isnan(transformer.transform(X)))
X = Series([0.0, 45.0, 90.0])
self.assertEqual([[math.cos(0.0)], [math.cos(45.0)], [math.cos(90.0)]], transformer.transform(X).tolist())
X = numpy.array([[0.0], [90.0]])
self.assertEqual([[math.cos(0.0)], [math.cos(90.0)]], transformer.transform(X).tolist())
X = numpy.array([float("NaN"), 180.0])
self.assertTrue(nan_eq([[float("NaN")], [float("NaN")]], transformer.transform(X)))
transformer = LookupTransformer(mapping, -999.0)
self.assertTrue(nan_eq([[float("NaN")], [-999.0]], transformer.transform(X)))

def test_transform_string(self):
mapping = {
Expand All @@ -353,10 +355,12 @@ def test_transform_string(self):
LookupTransformer(mapping, None)
mapping.pop(None)
transformer = LookupTransformer(mapping, None)
X = numpy.array([["zero"], ["one"]])
self.assertEqual([[None], ["ein"]], transformer.transform(X).tolist())
X = Series(["one", "two", "three"])
self.assertEqual([["ein"], ["zwei"], ["drei"]], transformer.transform(X).tolist())
X = numpy.array([[None], ["zero"]])
self.assertEqual([[None], [None]], transformer.transform(X).tolist())
transformer = LookupTransformer(mapping, "(other)")
self.assertEqual([[None], ["(other)"]], transformer.transform(X).tolist())

class FilterLookupTransformerTest(TestCase):

Expand Down Expand Up @@ -413,8 +417,10 @@ def test_transform_object(self):
transformer = MultiLookupTransformer(mapping, None)
X = DataFrame([["one", None], ["one", True], [None, True], ["two", True], ["three", True]])
self.assertEqual([[None], ["ein"], [None], ["zwei"], ["drei"]], transformer.transform(X).tolist())
X = numpy.matrix([["one", True], ["one", None], ["two", True]], dtype = "O")
self.assertEqual([["ein"], [None], ["zwei"]], transformer.transform(X).tolist())
X = numpy.matrix([["one", True], ["one", None], ["one", False], ["two", True]], dtype = "O")
self.assertEqual([["ein"], [None], [None], ["zwei"]], transformer.transform(X).tolist())
transformer = MultiLookupTransformer(mapping, "(other)")
self.assertEqual([["ein"], [None], ["(other)"], ["zwei"]], transformer.transform(X).tolist())

class PMMLLabelBinarizerTest(TestCase):

Expand Down

0 comments on commit f53450a

Please sign in to comment.