Skip to content

Commit

Permalink
Fixed missing value replacement
Browse files (browse the repository at this point in the history)
  • Loading branch information
vruusmann committed Jan 5, 2024
1 parent 4941ef9 commit cb65508
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
14 changes: 12 additions & 2 deletions sklearn2pmml/decoration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pandas import DataFrame
from pandas.api.types import is_object_dtype
from sklearn.base import clone, BaseEstimator, TransformerMixin
from sklearn2pmml import _is_pandas_categorical
from sklearn2pmml.util import cast, common_dtype, is_1d, to_numpy
Expand Down Expand Up @@ -112,6 +113,7 @@ def is_missing(X, missing_value):
if isinstance(missing_value, float) and numpy.isnan(missing_value):
return pandas.isnull(X)
return X == missing_value

if type(self.missing_values) is list:
mask = numpy.full(X.shape, fill_value = False)
for missing_value in self.missing_values:
Expand All @@ -129,8 +131,14 @@ def _transform_missing_values(self, X, where):
if self.missing_value_treatment == "return_invalid":
if numpy.any(where) > 0:
raise ValueError("Data contains {0} missing values".format(numpy.count_nonzero(where)))
if self.missing_value_replacement is not None:
X[where] = self.missing_value_replacement
elif self.missing_value_treatment in ["as_is", "as_mean", "as_mode", "as_median", "as_value"]:
if self.missing_value_replacement is not None:
X[where] = self.missing_value_replacement
# Special case for object data type columns: replacing non-None values with None values
elif is_object_dtype(self.dtype_) and (self.missing_value_replacement is None):
X[where] = self.missing_value_replacement
else:
raise ValueError()

def _transform_valid_values(self, X, where):
pass
Expand All @@ -146,6 +154,8 @@ def _transform_invalid_values(self, X, where):
elif self.invalid_value_treatment == "as_value":
if self.invalid_value_replacement is not None:
X[where] = self.invalid_value_replacement
else:
raise ValueError()

def _compute_masks(self, X):
X = to_numpy(X)
Expand Down
12 changes: 6 additions & 6 deletions sklearn2pmml/decoration/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_fit_int64(self):
self.assertEqual([False, True, False, False, False], domain._missing_value_mask(X).tolist())
Xt = domain.fit_transform(X)
self.assertEqual([-1, 1, 2], domain.data_values_.tolist())
self.assertEqual([-1, pandas.NA, 1, 2, -1], Xt.tolist())
self.assertEqual([-1, None, 1, 2, -1], Xt.tolist())

def test_fit_string(self):
domain = clone(CategoricalDomain(with_data = False, with_statistics = False))
Expand Down Expand Up @@ -200,20 +200,20 @@ def test_fit_string_missing(self):
self.assertEqual(["one", "two", "0"], Xt[:, 1].tolist())

def test_fit_string_valid(self):
domain = clone(CategoricalDomain(data_values = [["1", "2", "3"], ["zero", "one", "two"]], invalid_value_treatment = "as_value", invalid_value_replacement = "X"))
domain = clone(CategoricalDomain(data_values = [["1", "2", "3"], ["zero", "one", "two"]], invalid_value_treatment = "as_missing"))
self.assertEqual("as_is", domain.missing_value_treatment)
self.assertIsNone(domain.missing_value_replacement)
self.assertEqual("as_value", domain.invalid_value_treatment)
self.assertEqual("X", domain.invalid_value_replacement)
self.assertEqual("as_missing", domain.invalid_value_treatment)
self.assertIsNone(domain.invalid_value_replacement)
self.assertTrue(hasattr(domain, "data_values"))
self.assertFalse(hasattr(domain, "data_values_"))
X = DataFrame([["-1", None], ["0", "zero"], ["3", None], ["1", "two"], ["2", "one"], ["0", "three"]])
Xt = domain.fit_transform(X)
self.assertIsInstance(Xt, DataFrame)
self.assertTrue(hasattr(domain, "data_values"))
self.assertTrue(hasattr(domain, "data_values_"))
self.assertEqual(["X", "X", "3", "1", "2", "X"], Xt.iloc[:, 0].tolist())
self.assertEqual([None, "zero", None, "two", "one", "X"], Xt[1].tolist())
self.assertEqual([None, None, "3", "1", "2", None], Xt.iloc[:, 0].tolist())
self.assertEqual([None, "zero", None, "two", "one", None], Xt[1].tolist())
self.assertEqual(2, len(domain.counts_))
self.assertEqual({"totalFreq" : 6, "missingFreq" : 0, "invalidFreq" : 3}, domain.counts_[0])
self.assertEqual({"totalFreq" : 6, "missingFreq" : 2, "invalidFreq" : 1}, domain.counts_[1])
Expand Down

0 comments on commit cb65508

Please sign in to comment.