Skip to content

Commit

Permalink
Refactored 'DiscreteDomain.fit(X, y)' method
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jan 4, 2024
1 parent 3c65556 commit ed842fb
Showing 1 changed file with 32 additions and 36 deletions.
68 changes: 32 additions & 36 deletions sklearn2pmml/decoration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,12 @@ def _isin_mask(x, values):
mask[where] = _isin_mask(X[where], data_values)
return mask
else:
if isinstance(data_values, list):
if X.shape[1] != len(data_values):
raise ValueError()
if hasattr(data_values, "__len__"):
_check_cols(X, data_values)
mask = numpy.full(X.shape, fill_value = False)
for col in range(X.shape[1]):
col_where = where[:, col]
mask[col_where, col] = _isin_mask(X[col_where, col], data_values[col] if isinstance(data_values, list) else data_values)
mask[col_where, col] = _isin_mask(X[col_where, col], data_values[col] if hasattr(data_values, "__len__") else data_values)
return mask
return super(DiscreteDomain, self)._valid_value_mask(X, where)

Expand All @@ -202,50 +201,47 @@ def fit(self, X, y = None):
if self._empty_fit():
return self
X = to_numpy(X)
missing_mask = self._missing_value_mask(X)
nonmissing_mask = ~missing_mask
if is_1d(X):
if self.with_statistics:
values, counts = numpy.unique(X[nonmissing_mask], return_counts = True)
else:
values = numpy.unique(X[nonmissing_mask])
if self.with_data:
if self.with_data:
missing_mask = self._missing_value_mask(X)
nonmissing_mask = ~missing_mask
if is_1d(X):
if _is_pandas_categorical(self.dtype_):
data_values = self.dtype_.categories
else:
data_values = values
data_values = numpy.unique(X[nonmissing_mask])
if (self.missing_value_replacement is not None) and numpy.any(missing_mask) > 0:
if _is_pandas_categorical(self.dtype_):
raise ValueError()
data_values = numpy.unique(numpy.append(data_values, self.missing_value_replacement))
self.data_values_ = data_values
if self.with_statistics:
self.counts_ = _count(missing_mask, nonmissing_mask, None)
self.discr_stats_ = (values, counts)
else:
if self.with_data:
else:
if _is_pandas_categorical(self.dtype_):
raise ValueError()
self.data_values_ = []
if self.with_statistics:
self.counts_ = []
self.discr_stats_ = []
for col in range(X.shape[1]):
col_X = X[:, col]
col_missing_mask = missing_mask[:, col]
col_nonmissing_mask = nonmissing_mask[:, col]
if self.with_statistics:
values, counts = numpy.unique(col_X[col_nonmissing_mask], return_counts = True)
else:
values = numpy.unique(col_X[col_nonmissing_mask])
if self.with_data:
if _is_pandas_categorical(self.dtype_):
raise ValueError()
else:
data_values = values
for col in range(X.shape[1]):
col_X = X[:, col]
col_missing_mask = missing_mask[:, col]
col_nonmissing_mask = nonmissing_mask[:, col]
data_values = numpy.unique(col_X[col_nonmissing_mask])
if (self.missing_value_replacement is not None) and numpy.any(col_missing_mask) > 0:
data_values = numpy.unique(numpy.append(data_values, self.missing_value_replacement))
self.data_values_.append(data_values)
if self.with_statistics:
self.counts_.append(_count(col_missing_mask, ~col_missing_mask, None))
if self.with_statistics:
missing_mask, valid_mask, invalid_mask = self._compute_masks(X)
if is_1d(X):
values, counts = numpy.unique(X[valid_mask], return_counts = True)
self.counts_ = _count(missing_mask, valid_mask, invalid_mask)
self.discr_stats_ = (values, counts)
else:
self.counts_ = []
self.discr_stats_ = []
for col in range(X.shape[1]):
col_X = X[:, col]
col_missing_mask = missing_mask[:, col]
col_valid_mask = valid_mask[:, col]
col_invalid_mask = invalid_mask[:, col]
values, counts = numpy.unique(col_X[col_valid_mask], return_counts = True)
self.counts_.append(_count(col_missing_mask, col_valid_mask, col_invalid_mask))
self.discr_stats_.append((values, counts))
return self

Expand Down

0 comments on commit ed842fb

Please sign in to comment.