Skip to content

Commit

Permalink
Excluded the missing value replacement value from the list of valid v…
Browse files Browse the repository at this point in the history
…alues
  • Loading branch information
vruusmann committed Jan 5, 2024
1 parent b5a362b commit 4941ef9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 10 deletions.
6 changes: 0 additions & 6 deletions sklearn2pmml/decoration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,6 @@ def fit(self, X, y = None):
data_values = self.dtype_.categories
else:
data_values = numpy.unique(X[nonmissing_mask])
if (self.missing_value_replacement is not None) and numpy.any(missing_mask) > 0:
if _is_pandas_categorical(self.dtype_):
raise ValueError()
data_values = numpy.unique(numpy.append(data_values, self.missing_value_replacement))
else:
data_values = numpy.asarray(self.data_values)
self.data_values_ = data_values
Expand All @@ -236,8 +232,6 @@ def fit(self, X, y = None):
col_missing_mask = missing_mask[:, col]
col_nonmissing_mask = nonmissing_mask[:, col]
data_values = numpy.unique(col_X[col_nonmissing_mask])
if (self.missing_value_replacement is not None) and numpy.any(col_missing_mask) > 0:
data_values = numpy.unique(numpy.append(data_values, self.missing_value_replacement))
else:
data_values = numpy.asarray(self.data_values[col])
self.data_values_.append(data_values)
Expand Down
8 changes: 4 additions & 4 deletions sklearn2pmml/decoration/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def test_fit_int_missing(self):
Xt = domain.fit_transform(X)
self.assertIsInstance(Xt, DataFrame)
self.assertEqual(2, len(domain.data_values_))
self.assertEqual([0, 1, 2, 3], domain.data_values_[0].tolist())
self.assertEqual([0, 1, 2], domain.data_values_[1].tolist())
self.assertEqual([1, 2, 3], domain.data_values_[0].tolist())
self.assertEqual([1, 2], domain.data_values_[1].tolist())
self.assertEqual(2, len(domain.counts_))
self.assertEqual({"totalFreq" : 6, "missingFreq" : 2, "invalidFreq" : 0}, domain.counts_[0])
self.assertEqual({"totalFreq" : 6, "missingFreq" : 3, "invalidFreq" : 0}, domain.counts_[1])
Expand Down Expand Up @@ -182,8 +182,8 @@ def test_fit_string_missing(self):
Xt = domain.fit_transform(X)
self.assertIsInstance(Xt, DataFrame)
self.assertEqual(2, len(domain.data_values_))
self.assertEqual(["0", "1", "2", "3"], domain.data_values_[0].tolist())
self.assertEqual(["0", "one", "three", "two"], domain.data_values_[1].tolist())
self.assertEqual(["1", "2", "3"], domain.data_values_[0].tolist())
self.assertEqual(["one", "three", "two"], domain.data_values_[1].tolist())
self.assertEqual(2, len(domain.counts_))
self.assertEqual({"totalFreq" : 6, "missingFreq" : 2, "invalidFreq" : 0}, domain.counts_[0])
self.assertEqual({"totalFreq" : 6, "missingFreq" : 3, "invalidFreq" : 0}, domain.counts_[1])
Expand Down

0 comments on commit 4941ef9

Please sign in to comment.