From ed842fb28ecac7ecc2caf84ef5cc1ba4b3924f10 Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Thu, 4 Jan 2024 20:58:14 +0200 Subject: [PATCH] Refactored 'DiscreteDomain.fit(X, y)' method --- sklearn2pmml/decoration/__init__.py | 68 ++++++++++++++--------------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/sklearn2pmml/decoration/__init__.py b/sklearn2pmml/decoration/__init__.py index a1d2fbf..9e8c0cc 100644 --- a/sklearn2pmml/decoration/__init__.py +++ b/sklearn2pmml/decoration/__init__.py @@ -185,13 +185,12 @@ def _isin_mask(x, values): mask[where] = _isin_mask(X[where], data_values) return mask else: - if isinstance(data_values, list): - if X.shape[1] != len(data_values): - raise ValueError() + if hasattr(data_values, "__len__"): + _check_cols(X, data_values) mask = numpy.full(X.shape, fill_value = False) for col in range(X.shape[1]): col_where = where[:, col] - mask[col_where, col] = _isin_mask(X[col_where, col], data_values[col] if isinstance(data_values, list) else data_values) + mask[col_where, col] = _isin_mask(X[col_where, col], data_values[col] if hasattr(data_values, "__len__") else data_values) return mask return super(DiscreteDomain, self)._valid_value_mask(X, where) @@ -202,50 +201,47 @@ def fit(self, X, y = None): if self._empty_fit(): return self X = to_numpy(X) - missing_mask = self._missing_value_mask(X) - nonmissing_mask = ~missing_mask - if is_1d(X): - if self.with_statistics: - values, counts = numpy.unique(X[nonmissing_mask], return_counts = True) - else: - values = numpy.unique(X[nonmissing_mask]) - if self.with_data: + if self.with_data: + missing_mask = self._missing_value_mask(X) + nonmissing_mask = ~missing_mask + if is_1d(X): if _is_pandas_categorical(self.dtype_): data_values = self.dtype_.categories else: - data_values = values + data_values = numpy.unique(X[nonmissing_mask]) if (self.missing_value_replacement is not None) and numpy.any(missing_mask) > 0: if _is_pandas_categorical(self.dtype_): raise ValueError() data_values = numpy.unique(numpy.append(data_values, self.missing_value_replacement)) self.data_values_ = data_values - if self.with_statistics: - self.counts_ = _count(missing_mask, nonmissing_mask, None) - self.discr_stats_ = (values, counts) - else: - if self.with_data: + else: + if _is_pandas_categorical(self.dtype_): + raise ValueError() self.data_values_ = [] - if self.with_statistics: - self.counts_ = [] - self.discr_stats_ = [] - for col in range(X.shape[1]): - col_X = X[:, col] - col_missing_mask = missing_mask[:, col] - col_nonmissing_mask = nonmissing_mask[:, col] - if self.with_statistics: - values, counts = numpy.unique(col_X[col_nonmissing_mask], return_counts = True) - else: - values = numpy.unique(col_X[col_nonmissing_mask]) - if self.with_data: - if _is_pandas_categorical(self.dtype_): - raise ValueError() - else: - data_values = values + for col in range(X.shape[1]): + col_X = X[:, col] + col_missing_mask = missing_mask[:, col] + col_nonmissing_mask = nonmissing_mask[:, col] + data_values = numpy.unique(col_X[col_nonmissing_mask]) if (self.missing_value_replacement is not None) and numpy.any(col_missing_mask) > 0: data_values = numpy.unique(numpy.append(data_values, self.missing_value_replacement)) self.data_values_.append(data_values) - if self.with_statistics: - self.counts_.append(_count(col_missing_mask, ~col_missing_mask, None)) + if self.with_statistics: + missing_mask, valid_mask, invalid_mask = self._compute_masks(X) + if is_1d(X): + values, counts = numpy.unique(X[valid_mask], return_counts = True) + self.counts_ = _count(missing_mask, valid_mask, invalid_mask) + self.discr_stats_ = (values, counts) + else: + self.counts_ = [] + self.discr_stats_ = [] + for col in range(X.shape[1]): + col_X = X[:, col] + col_missing_mask = missing_mask[:, col] + col_valid_mask = valid_mask[:, col] + col_invalid_mask = invalid_mask[:, col] + values, counts = numpy.unique(col_X[col_valid_mask], return_counts = True) + self.counts_.append(_count(col_missing_mask, col_valid_mask, col_invalid_mask)) self.discr_stats_.append((values, counts)) return self