Skip to content

Commit

Permalink
Added 'DiscreteDomain.data_values' attribute. Fixes #300
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jan 5, 2024
1 parent ed842fb commit b5a362b
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 32 deletions.
61 changes: 38 additions & 23 deletions sklearn2pmml/decoration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,12 @@ def transform(self, X):

class DiscreteDomain(Domain):

def __init__(self, missing_values = None, missing_value_treatment = "as_is", missing_value_replacement = None, invalid_value_treatment = "return_invalid", invalid_value_replacement = None, with_data = True, with_statistics = True, dtype = None, display_name = None):
def __init__(self, missing_values = None, missing_value_treatment = "as_is", missing_value_replacement = None, invalid_value_treatment = "return_invalid", invalid_value_replacement = None, with_data = True, with_statistics = True, dtype = None, display_name = None, data_values = None):
    """Configure a discrete domain, optionally with pre-declared valid values.

    Every argument except ``data_values`` is forwarded unchanged to the
    parent ``Domain`` constructor.

    data_values: optional explicit valid values. When left as ``None``,
        ``fit`` collects the valid values from the training data instead.
        Otherwise ``fit`` converts it with ``numpy.asarray`` for
        one-dimensional input, and indexes it by column position
        (``data_values[col]``) for two-dimensional input.

    Raises ValueError if ``data_values`` is given while ``with_data`` is
    falsy.
    """
    super(DiscreteDomain, self).__init__(missing_values = missing_values, missing_value_treatment = missing_value_treatment, missing_value_replacement = missing_value_replacement, invalid_value_treatment = invalid_value_treatment, invalid_value_replacement = invalid_value_replacement, with_data = with_data, with_statistics = with_statistics, dtype = dtype, display_name = display_name)
    # Test against None instead of truthiness: a multi-element NumPy array
    # has no unambiguous truth value, and an empty sequence would otherwise
    # bypass the with_data consistency check.
    if (data_values is not None) and (not with_data):
        raise ValueError("Valid values require with_data attribute")
    # Stored unconditionally because fit() tests `self.data_values is None`.
    self.data_values = data_values

def _valid_value_mask(self, X, where):
if hasattr(self, "data_values_"):
Expand All @@ -190,7 +194,8 @@ def _isin_mask(x, values):
mask = numpy.full(X.shape, fill_value = False)
for col in range(X.shape[1]):
col_where = where[:, col]
mask[col_where, col] = _isin_mask(X[col_where, col], data_values[col] if hasattr(data_values, "__len__") else data_values)
col_data_values = data_values[col] if hasattr(data_values, "__len__") else data_values
mask[col_where, col] = _isin_mask(X[col_where, col], col_data_values)
return mask
return super(DiscreteDomain, self)._valid_value_mask(X, where)

Expand All @@ -202,29 +207,39 @@ def fit(self, X, y = None):
return self
X = to_numpy(X)
if self.with_data:
missing_mask = self._missing_value_mask(X)
nonmissing_mask = ~missing_mask
if self.data_values is None:
missing_mask = self._missing_value_mask(X)
nonmissing_mask = ~missing_mask
else:
_check_cols(X, self.data_values)
if is_1d(X):
if _is_pandas_categorical(self.dtype_):
data_values = self.dtype_.categories
else:
data_values = numpy.unique(X[nonmissing_mask])
if (self.missing_value_replacement is not None) and numpy.any(missing_mask) > 0:
if self.data_values is None:
if _is_pandas_categorical(self.dtype_):
raise ValueError()
data_values = numpy.unique(numpy.append(data_values, self.missing_value_replacement))
data_values = self.dtype_.categories
else:
data_values = numpy.unique(X[nonmissing_mask])
if (self.missing_value_replacement is not None) and numpy.any(missing_mask) > 0:
if _is_pandas_categorical(self.dtype_):
raise ValueError()
data_values = numpy.unique(numpy.append(data_values, self.missing_value_replacement))
else:
data_values = numpy.asarray(self.data_values)
self.data_values_ = data_values
else:
if _is_pandas_categorical(self.dtype_):
raise ValueError()
if self.data_values is None:
if _is_pandas_categorical(self.dtype_):
raise ValueError()
self.data_values_ = []
for col in range(X.shape[1]):
col_X = X[:, col]
col_missing_mask = missing_mask[:, col]
col_nonmissing_mask = nonmissing_mask[:, col]
data_values = numpy.unique(col_X[col_nonmissing_mask])
if (self.missing_value_replacement is not None) and numpy.any(col_missing_mask) > 0:
data_values = numpy.unique(numpy.append(data_values, self.missing_value_replacement))
if self.data_values is None:
col_X = X[:, col]
col_missing_mask = missing_mask[:, col]
col_nonmissing_mask = nonmissing_mask[:, col]
data_values = numpy.unique(col_X[col_nonmissing_mask])
if (self.missing_value_replacement is not None) and numpy.any(col_missing_mask) > 0:
data_values = numpy.unique(numpy.append(data_values, self.missing_value_replacement))
else:
data_values = numpy.asarray(self.data_values[col])
self.data_values_.append(data_values)
if self.with_statistics:
missing_mask, valid_mask, invalid_mask = self._compute_masks(X)
Expand All @@ -247,13 +262,13 @@ def fit(self, X, y = None):

class CategoricalDomain(DiscreteDomain):

def __init__(self, missing_values = None, missing_value_treatment = "as_is", missing_value_replacement = None, invalid_value_treatment = "return_invalid", invalid_value_replacement = None, with_data = True, with_statistics = True, dtype = None, display_name = None):
super(CategoricalDomain, self).__init__(missing_values = missing_values, missing_value_treatment = missing_value_treatment, missing_value_replacement = missing_value_replacement, invalid_value_treatment = invalid_value_treatment, invalid_value_replacement = invalid_value_replacement, with_data = with_data, with_statistics = with_statistics, dtype = dtype, display_name = display_name)
def __init__(self, missing_values = None, missing_value_treatment = "as_is", missing_value_replacement = None, invalid_value_treatment = "return_invalid", invalid_value_replacement = None, with_data = True, with_statistics = True, dtype = None, display_name = None, data_values = None):
    """Create a categorical domain; every argument is passed to the parent constructor as-is."""
    super(CategoricalDomain, self).__init__(
        missing_values = missing_values,
        missing_value_treatment = missing_value_treatment,
        missing_value_replacement = missing_value_replacement,
        invalid_value_treatment = invalid_value_treatment,
        invalid_value_replacement = invalid_value_replacement,
        with_data = with_data,
        with_statistics = with_statistics,
        dtype = dtype,
        display_name = display_name,
        data_values = data_values
    )

class OrdinalDomain(DiscreteDomain):

def __init__(self, missing_values = None, missing_value_treatment = "as_is", missing_value_replacement = None, invalid_value_treatment = "return_invalid", invalid_value_replacement = None, with_data = True, with_statistics = True, dtype = None, display_name = None):
super(OrdinalDomain, self).__init__(missing_values = missing_values, missing_value_treatment = missing_value_treatment, missing_value_replacement = missing_value_replacement, invalid_value_treatment = invalid_value_treatment, invalid_value_replacement = invalid_value_replacement, with_data = with_data, with_statistics = with_statistics, dtype = dtype, display_name = display_name)
def __init__(self, missing_values = None, missing_value_treatment = "as_is", missing_value_replacement = None, invalid_value_treatment = "return_invalid", invalid_value_replacement = None, with_data = True, with_statistics = True, dtype = None, display_name = None, data_values = None):
    """Create an ordinal domain; every argument is passed to the parent constructor as-is."""
    super(OrdinalDomain, self).__init__(
        missing_values = missing_values,
        missing_value_treatment = missing_value_treatment,
        missing_value_replacement = missing_value_replacement,
        invalid_value_treatment = invalid_value_treatment,
        invalid_value_replacement = invalid_value_replacement,
        with_data = with_data,
        with_statistics = with_statistics,
        dtype = dtype,
        display_name = display_name,
        data_values = data_values
    )

def _interquartile_range(X, axis):
quartiles = numpy.nanpercentile(X, [25, 75], axis = axis)
Expand Down
40 changes: 31 additions & 9 deletions sklearn2pmml/decoration/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,15 @@ def test_fit_int_missing(self):
self.assertEqual([0, 0, 2], Xt[:, 0].tolist())
self.assertEqual([1, 2, 0], Xt[:, 1].tolist())

def test_fit_int_categorical(self):
    """Fitting resolves categories on ``dtype_`` while the ``dtype`` parameter stays unresolved."""
    unresolved_dtype = CategoricalDtype()
    domain = clone(CategoricalDomain(dtype = unresolved_dtype))
    # Before fitting: no categories on the parameter, no fitted attribute yet.
    self.assertIsNone(domain.dtype.categories)
    self.assertFalse(hasattr(domain, "dtype_"))
    domain.fit_transform(Series([-1, 0, 1, 0, -1]))
    # After fitting: the parameter is untouched, the fitted dtype carries the categories.
    self.assertIsNone(domain.dtype.categories)
    fitted_categories = domain.dtype_.categories.tolist()
    self.assertEqual([-1, 0, 1], fitted_categories)

def test_fit_int64(self):
domain = clone(CategoricalDomain())
X = Series([-1, None, 1, 2, -1]).astype("Int64")
Expand All @@ -138,15 +147,6 @@ def test_fit_int64(self):
self.assertEqual([-1, 1, 2], domain.data_values_.tolist())
self.assertEqual([-1, pandas.NA, 1, 2, -1], Xt.tolist())

def test_fit_int_categorical(self):
domain = clone(CategoricalDomain(dtype = CategoricalDtype()))
self.assertIsNone(domain.dtype.categories)
self.assertFalse(hasattr(domain, "dtype_"))
X = Series([-1, 0, 1, 0, -1])
Xt = domain.fit_transform(X)
self.assertIsNone(domain.dtype.categories)
self.assertEqual([-1, 0, 1], domain.dtype_.categories.tolist())

def test_fit_string(self):
domain = clone(CategoricalDomain(with_data = False, with_statistics = False))
self.assertTrue(domain._empty_fit())
Expand Down Expand Up @@ -199,6 +199,28 @@ def test_fit_string_missing(self):
self.assertEqual(["0", "0", "1"], Xt[:, 0].tolist())
self.assertEqual(["one", "two", "0"], Xt[:, 1].tolist())

def test_fit_string_valid(self):
    """Explicit ``data_values`` decide which cells of a two-column frame count as invalid."""
    declared_values = [["1", "2", "3"], ["zero", "one", "two"]]
    domain = clone(CategoricalDomain(data_values = declared_values, invalid_value_treatment = "as_value", invalid_value_replacement = "X"))

    # Constructor state: defaults for missing values, custom invalid-value handling.
    self.assertEqual("as_is", domain.missing_value_treatment)
    self.assertIsNone(domain.missing_value_replacement)
    self.assertEqual("as_value", domain.invalid_value_treatment)
    self.assertEqual("X", domain.invalid_value_replacement)
    self.assertTrue(hasattr(domain, "data_values"))
    self.assertFalse(hasattr(domain, "data_values_"))

    rows = [
        ["-1", None],
        ["0", "zero"],
        ["3", None],
        ["1", "two"],
        ["2", "one"],
        ["0", "three"],
    ]
    Xt = domain.fit_transform(DataFrame(rows))

    # Transformed output: undeclared values become "X", missing cells stay missing.
    self.assertIsInstance(Xt, DataFrame)
    self.assertTrue(hasattr(domain, "data_values"))
    self.assertTrue(hasattr(domain, "data_values_"))
    self.assertEqual(["X", "X", "3", "1", "2", "X"], Xt.iloc[:, 0].tolist())
    self.assertEqual([None, "zero", None, "two", "one", "X"], Xt[1].tolist())

    # Per-column frequency counts.
    expected_counts = [
        {"totalFreq" : 6, "missingFreq" : 0, "invalidFreq" : 3},
        {"totalFreq" : 6, "missingFreq" : 2, "invalidFreq" : 1},
    ]
    self.assertEqual(2, len(domain.counts_))
    for col, expected in enumerate(expected_counts):
        self.assertEqual(expected, domain.counts_[col])

    # Per-column discrete statistics cover the declared valid values only.
    expected_value_counts = [
        {"1" : 1, "2" : 1, "3" : 1},
        {"zero" : 1, "one" : 1, "two" : 1},
    ]
    self.assertEqual(2, len(domain.discr_stats_))
    for col, expected in enumerate(expected_value_counts):
        self.assertEqual(expected, _value_count(domain.discr_stats_[col]))

def test_fit_string_categorical(self):
domain = clone(CategoricalDomain())
X = Categorical(["a", "b", "c", "b", "a"])
Expand Down

0 comments on commit b5a362b

Please sign in to comment.