Skip to content

Commit

Permalink
Added '_BaseEnsemble.recaller' attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Sep 20, 2023
1 parent ab7bdb4 commit 47ec417
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions sklearn2pmml/ensemble/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def predict_proba(self, X, **predict_proba_params):

class _BaseEnsemble(_BaseComposition):

def __init__(self, steps):
def __init__(self, steps, recaller):
for step in steps:
if type(step) is not tuple:
raise TypeError("Step is not a tuple")
Expand All @@ -191,6 +191,7 @@ def __init__(self, steps):
if not isinstance(predicate, (str, Predicate)):
raise TypeError()
self.steps = steps
self.recaller = recaller

@property
def _steps(self):
Expand All @@ -207,6 +208,11 @@ def set_params(self, **kwargs):
self._set_params("_steps", **kwargs)
return self

def _to_evaluation_dataset(self, X):
if self.recaller is not None:
return self.recaller.transform(X)
return X

def _to_sparse(X, step_mask, step_result):
# Make array
if len(step_result.shape) == 1:
Expand Down Expand Up @@ -256,18 +262,19 @@ def augment(self, X):

class EstimatorChain(_BaseEnsemble):

def __init__(self, steps, multioutput = True):
super(EstimatorChain, self).__init__(steps)
def __init__(self, steps, recaller = None, multioutput = True):
super(EstimatorChain, self).__init__(steps, recaller)
self.multioutput = multioutput

def fit(self, X, y, **fit_params):
if len(y.shape) > 1:
if len(self.steps) != y.shape[1]:
raise ValueError()
y = numpy.asarray(y)
X_eval = self._to_evaluation_dataset(X)
i = 0
for name, estimator, predicate in self.steps:
step_mask = eval_expr_rows(X, predicate, dtype = bool)
step_mask = eval_expr_rows(X_eval, predicate, dtype = bool)
if numpy.sum(step_mask) < 1:
raise ValueError(predicate)
step_X = X[step_mask]
Expand All @@ -286,8 +293,9 @@ def fit(self, X, y, **fit_params):

def _predict(self, X, predict_method):
result = None
X_eval = self._to_evaluation_dataset(X)
for name, estimator, predicate in self.steps:
step_mask = eval_expr_rows(X, predicate, dtype = bool)
step_mask = eval_expr_rows(X_eval, predicate, dtype = bool)
if numpy.sum(step_mask) < 1:
continue
step_X = X[step_mask]
Expand Down Expand Up @@ -316,13 +324,14 @@ def predict_proba(self, X):

class SelectFirstEstimator(_BaseEnsemble):

def __init__(self, steps):
super(SelectFirstEstimator, self).__init__(steps)
def __init__(self, steps, recaller):
super(SelectFirstEstimator, self).__init__(steps, recaller)

def fit(self, X, y, **fit_params):
X_eval = self._to_evaluation_dataset(X)
mask = numpy.zeros(X.shape[0], dtype = bool)
for name, estimator, predicate in self.steps:
step_mask = eval_expr_rows(X, predicate, dtype = bool)
step_mask = eval_expr_rows(X_eval, predicate, dtype = bool)
step_mask[mask] = False
if numpy.sum(step_mask) < 1:
raise ValueError(predicate)
Expand All @@ -335,9 +344,10 @@ def fit(self, X, y, **fit_params):

def _predict(self, X, predict_method):
result = None
X_eval = self._to_evaluation_dataset(X)
mask = numpy.zeros(X.shape[0], dtype = bool)
for name, estimator, predicate in self.steps:
step_mask = eval_expr_rows(X, predicate, dtype = bool)
step_mask = eval_expr_rows(X_eval, predicate, dtype = bool)
step_mask[mask] = False
if numpy.sum(step_mask) < 1:
continue
Expand All @@ -359,13 +369,13 @@ def predict(self, X):

class SelectFirstRegressor(SelectFirstEstimator, RegressorMixin):

def __init__(self, steps):
super(SelectFirstRegressor, self).__init__(steps)
def __init__(self, steps, recaller = None):
super(SelectFirstRegressor, self).__init__(steps, recaller)

class SelectFirstClassifier(SelectFirstEstimator, ClassifierMixin):

def __init__(self, steps):
super(SelectFirstClassifier, self).__init__(steps)
def __init__(self, steps, recaller = None):
super(SelectFirstClassifier, self).__init__(steps, recaller)

def predict_proba(self, X):
return self._predict(X, "predict_proba")

0 comments on commit 47ec417

Please sign in to comment.