diff --git a/sklearn2pmml/decoration/__init__.py b/sklearn2pmml/decoration/__init__.py index e593e81..b77eab6 100644 --- a/sklearn2pmml/decoration/__init__.py +++ b/sklearn2pmml/decoration/__init__.py @@ -1,4 +1,5 @@ from pandas import DataFrame +from pandas.api.types import is_object_dtype from sklearn.base import clone, BaseEstimator, TransformerMixin from sklearn2pmml import _is_pandas_categorical from sklearn2pmml.util import cast, common_dtype, is_1d, to_numpy @@ -112,6 +113,7 @@ def is_missing(X, missing_value): if isinstance(missing_value, float) and numpy.isnan(missing_value): return pandas.isnull(X) return X == missing_value + if type(self.missing_values) is list: mask = numpy.full(X.shape, fill_value = False) for missing_value in self.missing_values: @@ -129,8 +131,14 @@ def _transform_missing_values(self, X, where): if self.missing_value_treatment == "return_invalid": if numpy.any(where) > 0: raise ValueError("Data contains {0} missing values".format(numpy.count_nonzero(where))) - if self.missing_value_replacement is not None: - X[where] = self.missing_value_replacement + elif self.missing_value_treatment in ["as_is", "as_mean", "as_mode", "as_median", "as_value"]: + if self.missing_value_replacement is not None: + X[where] = self.missing_value_replacement + # Special case for object data type columns: replacing non-None values with None values + elif is_object_dtype(self.dtype_) and (self.missing_value_replacement is None): + X[where] = self.missing_value_replacement + else: + raise ValueError() def _transform_valid_values(self, X, where): pass @@ -146,6 +154,8 @@ def _transform_invalid_values(self, X, where): elif self.invalid_value_treatment == "as_value": if self.invalid_value_replacement is not None: X[where] = self.invalid_value_replacement + else: + raise ValueError() def _compute_masks(self, X): X = to_numpy(X) diff --git a/sklearn2pmml/decoration/tests/__init__.py b/sklearn2pmml/decoration/tests/__init__.py index 7706ee6..333b212 100644 --- a/sklearn2pmml/decoration/tests/__init__.py +++ b/sklearn2pmml/decoration/tests/__init__.py @@ -145,7 +145,7 @@ def test_fit_int64(self): self.assertEqual([False, True, False, False, False], domain._missing_value_mask(X).tolist()) Xt = domain.fit_transform(X) self.assertEqual([-1, 1, 2], domain.data_values_.tolist()) - self.assertEqual([-1, pandas.NA, 1, 2, -1], Xt.tolist()) + self.assertEqual([-1, None, 1, 2, -1], Xt.tolist()) def test_fit_string(self): domain = clone(CategoricalDomain(with_data = False, with_statistics = False)) @@ -200,11 +200,11 @@ def test_fit_string_missing(self): self.assertEqual(["one", "two", "0"], Xt[:, 1].tolist()) def test_fit_string_valid(self): - domain = clone(CategoricalDomain(data_values = [["1", "2", "3"], ["zero", "one", "two"]], invalid_value_treatment = "as_value", invalid_value_replacement = "X")) + domain = clone(CategoricalDomain(data_values = [["1", "2", "3"], ["zero", "one", "two"]], invalid_value_treatment = "as_missing")) self.assertEqual("as_is", domain.missing_value_treatment) self.assertIsNone(domain.missing_value_replacement) - self.assertEqual("as_value", domain.invalid_value_treatment) - self.assertEqual("X", domain.invalid_value_replacement) + self.assertEqual("as_missing", domain.invalid_value_treatment) + self.assertIsNone(domain.invalid_value_replacement) self.assertTrue(hasattr(domain, "data_values")) self.assertFalse(hasattr(domain, "data_values_")) X = DataFrame([["-1", None], ["0", "zero"], ["3", None], ["1", "two"], ["2", "one"], ["0", "three"]]) @@ -212,8 +212,8 @@ def test_fit_string_valid(self): self.assertIsInstance(Xt, DataFrame) self.assertTrue(hasattr(domain, "data_values")) self.assertTrue(hasattr(domain, "data_values_")) - self.assertEqual(["X", "X", "3", "1", "2", "X"], Xt.iloc[:, 0].tolist()) - self.assertEqual([None, "zero", None, "two", "one", "X"], Xt[1].tolist()) + self.assertEqual([None, None, "3", "1", "2", None], Xt.iloc[:, 0].tolist()) + self.assertEqual([None, "zero", None, "two", "one", None], Xt[1].tolist()) self.assertEqual(2, len(domain.counts_)) self.assertEqual({"totalFreq" : 6, "missingFreq" : 0, "invalidFreq" : 3}, domain.counts_[0]) self.assertEqual({"totalFreq" : 6, "missingFreq" : 2, "invalidFreq" : 1}, domain.counts_[1])